Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/agent/agent_segment_fixup.cuh view on Meta::CPAN
pairs[ITEM].value = reduction_op(pairs[ITEM - 1].value, pairs[ITEM].value);
}
// Flush last item if valid
ValueT* d_scatter = d_aggregates_out + pairs[ITEMS_PER_THREAD - 1].key;
if ((!IS_LAST_TILE) || (pairs[ITEMS_PER_THREAD - 1].key >= 0))
atomicAdd(d_scatter, pairs[ITEMS_PER_THREAD - 1].value);
}
/**
 * Process one input tile of key-value pairs. Specialized for reduce-by-key fixup:
 * this overload is selected by the Int2Type<false> tag, i.e. the path that does
 * NOT use atomicAdd.
 *
 * Steps performed, in order:
 *   1. Load ITEMS_PER_THREAD pairs per thread from d_pairs_in + tile_offset.
 *   2. Run a block-wide exclusive scan of the pairs with scan_op. Tile 0 scans
 *      without a prefix; every other tile obtains its prefix through
 *      TilePrefixCallbackOpT and tile_state.
 *   3. Wherever the scanned pair's key differs from the loaded pair's key,
 *      combine d_fixup_in[key] with the scanned value via reduction_op and store
 *      the result to d_aggregates_out[key].
 *   4. In the last tile, if it is a whole tile, the last thread additionally
 *      writes the entry for the key of the very last pair.
 *
 * Template parameter IS_LAST_TILE selects the guarded load (bounded by
 * num_remaining) and enables step 4; when false the tile is loaded unguarded.
 *
 * The use_atomic_fixup argument is a dispatch tag only; it is never read in the body.
 */
template <bool IS_LAST_TILE>
__device__ __forceinline__ void ConsumeTile(
OffsetT num_remaining, ///< Number of global input items remaining (including this tile)
int tile_idx, ///< Tile index
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state, ///< Global tile state descriptor
Int2Type<false> use_atomic_fixup) ///< Marker whether to use atomicAdd (instead of reduce-by-key)
{
// pairs: the items as loaded; scatter_pairs: the exclusive-scan output for the same slots
KeyValuePairT pairs[ITEMS_PER_THREAD];
KeyValuePairT scatter_pairs[ITEMS_PER_THREAD];
// Load pairs
// Filler pair handed to the guarded load below. Only .key is assigned here (-1);
// .value is not assigned in this function.
// NOTE(review): confirm that an unset .value cannot reach d_aggregates_out through the scan.
KeyValuePairT oob_pair;
oob_pair.key = -1;
// Last tile: guarded load, with num_remaining as the valid-item count and oob_pair as the
// out-of-bounds default (per the BlockLoad guarded-load overload). Other tiles: unguarded load.
if (IS_LAST_TILE)
BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair);
else
BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs);
// Barrier between the load and the scan.
// NOTE(review): presumably required because temp_storage.load_pairs and temp_storage.scan
// may share storage -- confirm against the TempStorage definition.
CTA_SYNC();
KeyValuePairT tile_aggregate;
if (tile_idx == 0)
{
// First tile: there is no predecessor, so scan without a prefix callback.
// Exclusive scan of values and segment_flags
BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, tile_aggregate);
// Update tile status if this is not the last tile
if (threadIdx.x == 0)
{
// Set first segment id to not trigger a flush (invalid from exclusive scan)
// Making the keys equal means the "key changed" test in the scatter loop is false for item 0.
scatter_pairs[0].key = pairs[0].key;
// Publish this tile's aggregate as the inclusive prefix for tile 0.
if (!IS_LAST_TILE)
tile_state.SetInclusive(0, tile_aggregate);
}
}
else
{
// Later tiles: the prefix callback supplies the running prefix via tile_state.
// Exclusive scan of values and segment_flags
TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, prefix_op);
// The tile aggregate is read back from the callback object after the scan.
tile_aggregate = prefix_op.GetBlockAggregate();
}
// Scatter updated values
// A key mismatch between the scanned pair and the loaded pair marks a slot whose
// scanned key/value must be written out.
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
if (scatter_pairs[ITEM].key != pairs[ITEM].key)
{
// Update the value at the key location
// Read from d_fixup_in, combine with the scanned value, write to d_aggregates_out.
ValueT value = d_fixup_in[scatter_pairs[ITEM].key];
value = reduction_op(value, scatter_pairs[ITEM].value);
d_aggregates_out[scatter_pairs[ITEM].key] = value;
}
}
// Finalize the last item
if (IS_LAST_TILE)
{
// Last thread will output final count and last item, if necessary
if (threadIdx.x == BLOCK_THREADS - 1)
{
// If the last tile is a whole tile, the inclusive prefix contains accumulated value reduction for the last segment
// NOTE(review): presumably a partially-full last tile needs no extra step because the first
// out-of-bounds pair (key -1) already triggers the write in the scatter loop -- confirm against scan_op.
if (num_remaining == TILE_ITEMS)
{
// Update the value at the key location
// NOTE(review): the operand order here (aggregate first, d_fixup_in second) is the reverse of the
// scatter loop above; the two agree only if reduction_op is commutative -- confirm intent.
OffsetT last_key = pairs[ITEMS_PER_THREAD - 1].key;
d_aggregates_out[last_key] = reduction_op(tile_aggregate.value, d_fixup_in[last_key]);
}
}
}
}
/**
 * Entry point for a thread block: pick the single tile this block is
 * responsible for and hand it to ConsumeTile as part of a dynamic chained scan.
 *
 * The tile index is the block's linearized 2D grid coordinate
 * (blockIdx.x * gridDim.y + blockIdx.y). A block whose tile begins at or past
 * num_items does no work. num_tiles is accepted but not referenced here.
 */
__device__ __forceinline__ void ConsumeRange(
    int                 num_items,      ///< Total number of input items
    int                 num_tiles,      ///< Total number of input tiles
    ScanTileStateT&     tile_state)     ///< Global tile state descriptor
{
    // Blocks are launched in increasing order, so each block takes exactly one tile
    const int     block_tile = (blockIdx.x * gridDim.y) + blockIdx.y;   // This block's tile index
    const OffsetT first_item = block_tile * TILE_ITEMS;                 // Global offset of the tile's first item
    const OffsetT items_left = num_items - first_item;                  // Items remaining, counting this tile

    // This block's tile lies entirely beyond the input: nothing to consume
    if (items_left <= 0)
        return;

    if (items_left > TILE_ITEMS)
    {
        // More than one tile's worth remains: this is a full, non-final tile
        ConsumeTile<false>(items_left, block_tile, first_item, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
    }
    else
    {
        // Final tile (possibly only partially full)
        ConsumeTile<true>(items_left, block_tile, first_item, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
    }
}
};
} // CUB namespace
( run in 0.941 second using v1.01-cache-2.11-cpan-39bf76dae61 )