Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/agent/agent_reduce_by_key.cuh view on Meta::CPAN
// Smem needed for loading values
typename BlockLoadValuesT::TempStorage load_values;
// Smem needed for compacting key value pairs(allows non POD items in this union)
Uninitialized<KeyValuePairT[TILE_ITEMS + 1]> raw_exchange;
};
// Alias wrapper allowing storage to be unioned.
// Adds no members of its own: it only derives from Uninitialized<_TempStorage>,
// and the agent's constructor obtains the underlying _TempStorage via Alias().
struct TempStorage : Uninitialized<_TempStorage> {};
//---------------------------------------------------------------------
// Per-thread fields
//---------------------------------------------------------------------
_TempStorage& temp_storage; ///< Reference to the aliased temp storage (bound from TempStorage::Alias() in the constructor)
WrappedKeysInputIteratorT d_keys_in; ///< Input keys (initialized from the constructor's KeysInputIteratorT argument)
UniqueOutputIteratorT d_unique_out; ///< Unique output keys (one per segment)
WrappedValuesInputIteratorT d_values_in; ///< Input values (initialized from the constructor's ValuesInputIteratorT argument)
AggregatesOutputIteratorT d_aggregates_out; ///< Output value aggregates (one per segment, same index as d_unique_out)
NumRunsOutputIteratorT d_num_runs_out; ///< Output pointer for total number of segments identified
EqualityOpT equality_op; ///< KeyT equality operator
ReductionOpT reduction_op; ///< ValueT reduction operator
ReduceBySegmentOpT scan_op; ///< Reduce-by-segment scan operator (constructed from reduction_op)
//---------------------------------------------------------------------
// Constructor
//---------------------------------------------------------------------
/**
 * Constructor
 *
 * Binds the agent to the caller-provided temp storage (through its Alias())
 * and copies the input/output iterators and functors into the per-thread
 * fields. The reduce-by-segment scan operator is built from reduction_op.
 */
__device__ __forceinline__
AgentReduceByKey(
TempStorage& temp_storage, ///< Reference to temp_storage
KeysInputIteratorT d_keys_in, ///< Input keys
UniqueOutputIteratorT d_unique_out, ///< Unique output keys
ValuesInputIteratorT d_values_in, ///< Input values
AggregatesOutputIteratorT d_aggregates_out, ///< Output value aggregates
NumRunsOutputIteratorT d_num_runs_out, ///< Output pointer for total number of segments identified
EqualityOpT equality_op, ///< KeyT equality operator
ReductionOpT reduction_op) ///< ValueT reduction operator
:
temp_storage(temp_storage.Alias()),
d_keys_in(d_keys_in),
d_unique_out(d_unique_out),
d_values_in(d_values_in),
d_aggregates_out(d_aggregates_out),
d_num_runs_out(d_num_runs_out),
equality_op(equality_op),
reduction_op(reduction_op),
scan_op(reduction_op)
{}
//---------------------------------------------------------------------
// Scatter utility methods
//---------------------------------------------------------------------
/**
 * One-phase scatter: every thread writes its own flagged items straight to
 * the output iterators.
 *
 * For each item whose segment flag is set, the item's key is stored to
 * d_unique_out and its value to d_aggregates_out, both at the item's
 * segment index. Unflagged items are ignored.
 */
__device__ __forceinline__ void ScatterDirect(
    KeyValuePairT   (&scatter_items)[ITEMS_PER_THREAD],
    OffsetT         (&segment_flags)[ITEMS_PER_THREAD],
    OffsetT         (&segment_indices)[ITEMS_PER_THREAD])
{
    #pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; ++i)
    {
        if (segment_flags[i])
        {
            // Key and aggregate of one segment share the same output slot
            const OffsetT           out_slot = segment_indices[i];
            const KeyValuePairT&    flagged  = scatter_items[i];

            d_unique_out[out_slot]      = flagged.key;
            d_aggregates_out[out_slot]  = flagged.value;
        }
    }
}
/**
 * 2-phase scatter of flagged items to output offsets
 *
 * Phase 1: each thread stores its flagged key-value pairs into the
 * raw_exchange staging array, at the pair's segment index rebased to the
 * tile (segment index minus num_tile_segments_prefix).
 *
 * Phase 2: the threads of the block stride over the num_tile_segments
 * staged pairs and write each one to d_unique_out / d_aggregates_out at
 * num_tile_segments_prefix plus its staging position.
 */
__device__ __forceinline__ void ScatterTwoPhase(
    KeyValuePairT   (&scatter_items)[ITEMS_PER_THREAD],
    OffsetT         (&segment_flags)[ITEMS_PER_THREAD],
    OffsetT         (&segment_indices)[ITEMS_PER_THREAD],
    OffsetT         num_tile_segments,
    OffsetT         num_tile_segments_prefix)
{
    // Barrier before the staging array is written
    CTA_SYNC();

    KeyValuePairT (&staged)[TILE_ITEMS + 1] = temp_storage.raw_exchange.Alias();

    // Phase 1: compact the flagged pairs into the staging array
    #pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; ++i)
    {
        if (segment_flags[i])
        {
            staged[segment_indices[i] - num_tile_segments_prefix] = scatter_items[i];
        }
    }

    // Barrier so every staged pair is visible before it is read back
    CTA_SYNC();

    // Phase 2: block-strided copy of the staged pairs to the outputs
    for (int slot = threadIdx.x; slot < num_tile_segments; slot += BLOCK_THREADS)
    {
        KeyValuePairT staged_pair = staged[slot];

        d_unique_out[num_tile_segments_prefix + slot]       = staged_pair.key;
        d_aggregates_out[num_tile_segments_prefix + slot]   = staged_pair.value;
    }
}
/**
 * Scatter flagged items, selecting between the two scatter strategies.
 *
 * The two-phase path is taken only when TWO_PHASE_SCATTER is enabled and the
 * tile produced more than BLOCK_THREADS segments; in every other case the
 * items are scattered directly.
 */
__device__ __forceinline__ void Scatter(
    KeyValuePairT   (&scatter_items)[ITEMS_PER_THREAD],
    OffsetT         (&segment_flags)[ITEMS_PER_THREAD],
    OffsetT         (&segment_indices)[ITEMS_PER_THREAD],
    OffsetT         num_tile_segments,
    OffsetT         num_tile_segments_prefix)
{
    const bool use_two_phase = TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS);

    if (!use_two_phase)
    {
        // Two-phase disabled, or at most one selected item per thread on average
        ScatterDirect(scatter_items, segment_flags, segment_indices);
        return;
    }

    ScatterTwoPhase(
        scatter_items,
        segment_flags,
        segment_indices,
        num_tile_segments,
        num_tile_segments_prefix);
}
//---------------------------------------------------------------------
// Cooperatively scan a device-wide sequence of tiles with other CTAs
//---------------------------------------------------------------------
/**
 * Process a tile of input (dynamic chained scan)
 *
 * Steps performed for the tile:
 *   1. Load keys and values (guarded by num_remaining on the last tile).
 *   2. Flag segment heads by comparing each key with its predecessor; the
 *      first thread supplies the key preceding the tile.
 *   3. Exclusive-scan the (head flag, value) pairs with the reduce-by-segment
 *      operator, seeded with the prefix of earlier tiles for non-first tiles.
 *   4. Scatter, for every flagged head, the previous key together with the
 *      scanned value at the scanned segment index.
 *   5. On the last tile, the last thread writes the total number of runs and,
 *      when the tile is full, the final segment's key and aggregate.
 */
template <bool IS_LAST_TILE>                ///< Whether the current tile is the last tile
__device__ __forceinline__ void ConsumeTile(
    OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
    int                 tile_idx,           ///< Tile index
    OffsetT             tile_offset,        ///< Tile offset
    ScanTileStateT&     tile_state)         ///< Global tile state descriptor
{
    KeyOutputT          keys[ITEMS_PER_THREAD];             // Tile keys
    KeyOutputT          prev_keys[ITEMS_PER_THREAD];        // Tile keys shuffled up
    ValueOutputT        values[ITEMS_PER_THREAD];           // Tile values
    OffsetT             head_flags[ITEMS_PER_THREAD];       // Segment head flags
    OffsetT             segment_indices[ITEMS_PER_THREAD];  // Segment indices
    OffsetValuePairT    scan_items[ITEMS_PER_THREAD];       // Zipped values and segment flags|indices
    KeyValuePairT       scatter_items[ITEMS_PER_THREAD];    // Zipped key value pairs for scattering

    // Load keys
    if (IS_LAST_TILE)
        BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys, num_remaining);
    else
        BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys);

    // Load tile predecessor key in first thread (only that thread's value is assigned)
    KeyOutputT tile_predecessor;
    if (threadIdx.x == 0)
    {
        tile_predecessor = (tile_idx == 0) ?
            keys[0] :                       // First tile gets repeat of first item (thus first item will not be flagged as a head)
            d_keys_in[tile_offset - 1];     // Subsequent tiles get last key from previous tile
    }

    CTA_SYNC();

    // Load values
    if (IS_LAST_TILE)
        BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values, num_remaining);
    else
        BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values);

    CTA_SYNC();

    // Initialize head-flags and shuffle up the previous keys
    if (IS_LAST_TILE)
    {
        // Use custom flag operator to additionally flag the first out-of-bounds item
        GuardedInequalityWrapper<EqualityOpT> flag_op(equality_op, num_remaining);
        BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads(
            head_flags, keys, prev_keys, flag_op, tile_predecessor);
    }
    else
    {
        InequalityWrapper<EqualityOpT> flag_op(equality_op);
        BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads(
            head_flags, keys, prev_keys, flag_op, tile_predecessor);
    }

    // Zip values and head flags
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
        scan_items[ITEM].value  = values[ITEM];
        scan_items[ITEM].key    = head_flags[ITEM];
    }

    // Perform exclusive tile scan
    OffsetValuePairT    block_aggregate;        // Inclusive block-wide scan aggregate
    OffsetT             num_segments_prefix;    // Number of segments prior to this tile
    ValueOutputT        total_aggregate;        // Running aggregate of the segment still open at the end of this tile
    if (tile_idx == 0)
    {
        // Scan first tile
        BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, block_aggregate);
        num_segments_prefix     = 0;
        total_aggregate         = block_aggregate.value;

        // Update tile status if there are successor tiles
        if ((!IS_LAST_TILE) && (threadIdx.x == 0))
            tile_state.SetInclusive(0, block_aggregate);
    }
    else
    {
        // Scan non-first tile
        TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
        BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, prefix_op);

        block_aggregate         = prefix_op.GetBlockAggregate();
        num_segments_prefix     = prefix_op.GetExclusivePrefix().key;

        // Fold the predecessor tiles' running value into the total only when
        // this tile contains no segment head. A head inside the tile starts a
        // new segment, so block_aggregate.value already covers exactly the
        // trailing segment and the prefix value belongs to an earlier one.
        if (block_aggregate.key != 0)
        {
            total_aggregate = block_aggregate.value;
        }
        else
        {
            total_aggregate = reduction_op(
                prefix_op.GetExclusivePrefix().value,
                block_aggregate.value);
        }
    }

    // Rezip scatter items and segment indices
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
        scatter_items[ITEM].key     = prev_keys[ITEM];
        scatter_items[ITEM].value   = scan_items[ITEM].value;
        segment_indices[ITEM]       = scan_items[ITEM].key;
    }

    // At this point, each flagged segment head has:
    //  - The key for the previous segment
    //  - The reduced value from the previous segment
    //  - The segment index for the reduced value

    // Scatter flagged keys and values
    OffsetT num_tile_segments = block_aggregate.key;
    Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix);

    // Last thread in last tile will output final count (and last pair, if necessary)
    if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1))
    {
        OffsetT num_segments = num_segments_prefix + num_tile_segments;

        // If the last tile is a whole tile, no out-of-bounds item was flagged,
        // so the final segment has not been scattered yet: output it here
        if (num_remaining == TILE_ITEMS)
        {
            d_unique_out[num_segments]      = keys[ITEMS_PER_THREAD - 1];
            d_aggregates_out[num_segments]  = total_aggregate;
            num_segments++;
        }

        // Output the total number of items selected
        *d_num_runs_out = num_segments;
    }
}
/**
 * Scan tiles of items as part of a dynamic chained scan
 *
 * Each thread block handles exactly one tile: tile index
 * start_tile + blockIdx.x. Blocks whose tile lies entirely past the end of
 * the input (no remaining items) do nothing.
 *
 * num_items is taken as OffsetT (rather than int) so that the item count is
 * not truncated when OffsetT is wider than int; all quantities derived from
 * it below are already OffsetT.
 */
__device__ __forceinline__ void ConsumeRange(
    OffsetT             num_items,          ///< Total number of input items
    ScanTileStateT&     tile_state,         ///< Global tile state descriptor
    int                 start_tile)         ///< The starting tile for the current grid
{
    // Blocks are launched in increasing order, so just assign one tile per block
    int     tile_idx        = start_tile + blockIdx.x;          // Current tile index
    OffsetT tile_offset     = OffsetT(TILE_ITEMS) * tile_idx;   // Global offset for the current tile
    OffsetT num_remaining   = num_items - tile_offset;          // Remaining items (including this tile)

    if (num_remaining > TILE_ITEMS)
    {
        // Not last tile
        ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state);
    }
    else if (num_remaining > 0)
    {
        // Last tile
        ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state);
    }
}
};
} // CUB namespace
CUB_NS_POSTFIX // Optional outer namespace(s)
( run in 1.178 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )