Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/agent/single_pass_scan_operators.cuh view on Meta::CPAN
template <
typename T,
typename ScanOpT,
typename ScanTileStateT,
int PTX_ARCH = CUB_PTX_ARCH>
struct TilePrefixCallbackOp
{
// Parameterized warp reduce: one logical warp of CUB_PTX_WARP_THREADS lanes, used by
// ProcessWindow to reduce a window of predecessor-tile values
typedef WarpReduce<T, CUB_PTX_WARP_THREADS, PTX_ARCH> WarpReduceT;
// Temporary storage type: warp-reduce scratch plus the values operator() publishes
// (lane 0 writes block_aggregate, exclusive_prefix and inclusive_prefix)
struct _TempStorage
{
typename WarpReduceT::TempStorage warp_reduce;
T exclusive_prefix;
T inclusive_prefix;
T block_aggregate;
};
// Alias wrapper allowing temporary storage to be unioned (the constructor takes this
// type and converts it back to _TempStorage& via Alias())
struct TempStorage : Uninitialized<_TempStorage> {};
// Type of status word reported by ScanTileStateT for each tile
typedef typename ScanTileStateT::StatusWord StatusWord;
// Fields
_TempStorage& temp_storage; ///< Reference to this callback's temporary storage (warp-reduce scratch and published prefixes)
ScanTileStateT& tile_status; ///< Interface to tile status (partial/inclusive values of all tiles)
ScanOpT scan_op; ///< Binary scan operator
int tile_idx; ///< The current tile index
T exclusive_prefix; ///< Exclusive prefix for the tile (assigned by operator(), not by the constructor)
T inclusive_prefix; ///< Inclusive prefix for the tile (assigned by operator() on thread 0 only)
/// Constructor: binds the callback to the tile-status interface, its temporary
/// storage, the scan operator and the index of the tile being scanned.
/// exclusive_prefix / inclusive_prefix are deliberately not initialized here;
/// operator() assigns them (inclusive_prefix on thread 0 only).
__device__ __forceinline__
TilePrefixCallbackOp(
    ScanTileStateT  &tile_status,   ///< [in] Interface to tile status
    TempStorage     &temp_storage,  ///< [in] Temporary storage (unwrapped via Alias())
    ScanOpT         scan_op,        ///< [in] Binary scan operator
    int             tile_idx)       ///< [in] Index of the current tile
:
    temp_storage    (temp_storage.Alias()),
    tile_status     (tile_status),
    scan_op         (scan_op),
    tile_idx        (tile_idx)
{
}
/**
 * Block until all predecessors within the warp-wide window have non-invalid status,
 * then reduce the window's values.
 *
 * Each lane waits on its own predecessor tile, then the warp performs a
 * tail-segmented reduction. A lane whose predecessor already reports an inclusive
 * prefix ends a segment, so values from tiles older than it are not folded in.
 * The operator is swizzled because the reduction runs *down* towards thread0,
 * i.e. from older tiles (higher lanes, as laid out by operator()) to newer ones.
 */
__device__ __forceinline__
void ProcessWindow(
    int         predecessor_idx,     ///< Preceding tile index to inspect
    StatusWord  &predecessor_status, ///< [out] Preceding tile status
    T           &window_aggregate)   ///< [out] Relevant partial reduction from this window of preceding tiles
{
    // Spin until this lane's predecessor tile has published a valid status word.
    T tile_value;
    tile_status.WaitForValid(predecessor_idx, predecessor_status, tile_value);

    // An inclusive predecessor marks the tail of a reduction segment.
    const bool is_inclusive = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));
    int segment_tail = is_inclusive ? 1 : 0;

    // Reduce the window with operands swapped so older tiles stay on the left.
    SwizzleScanOp<ScanOpT> swizzled_op(scan_op);
    WarpReduceT            warp_reduce(temp_storage.warp_reduce);
    window_aggregate = warp_reduce.TailSegmentedReduce(tile_value, segment_tail, swizzled_op);
}
/**
 * BlockScan prefix callback functor (called by the first warp).
 *
 * Computes this tile's exclusive prefix with a decoupled look-back:
 *   1. thread 0 publishes this tile's aggregate as a partial status;
 *   2. the warp inspects a window of CUB_PTX_WARP_THREADS predecessor tiles
 *      (lane i looks at tile tile_idx - 1 - i) and reduces it via ProcessWindow;
 *   3. the window slides back one warp-width at a time until some lane sees a
 *      predecessor whose inclusive prefix is already known;
 *   4. thread 0 publishes this tile's inclusive prefix.
 *
 * Every lane of the warp must call this: WARP_ALL is evaluated with the full
 * 0xffffffff mask and ProcessWindow performs a warp-wide reduction.
 *
 * \param block_aggregate  Aggregate of this tile (thread 0's value is the one published).
 * \return The exclusive prefix accumulated by this lane (also stored in the
 *         exclusive_prefix member). Thread 0's value is the one it publishes;
 *         presumably other lanes' values are not meaningful, following WarpReduce's
 *         lane-0 result convention. Confirm before relying on them.
 *
 * Side effects (thread 0 only): calls tile_status.SetPartial and SetInclusive for
 * tile_idx, and writes temp_storage.block_aggregate, temp_storage.exclusive_prefix
 * and temp_storage.inclusive_prefix.
 */
__device__ __forceinline__
T operator()(T block_aggregate)
{
    // Update our status with our tile-aggregate so successors can use it
    if (threadIdx.x == 0)
    {
        temp_storage.block_aggregate = block_aggregate;
        tile_status.SetPartial(tile_idx, block_aggregate);
    }

    // Lane i inspects predecessor tile (tile_idx - 1 - i). threadIdx.x is cast to int
    // so the subtraction is done in signed arithmetic. For tile_idx < 32 the result is
    // negative; without the cast it would wrap in unsigned arithmetic and then be
    // converted back to int, which is implementation-defined before C++20.
    // Negative indices are passed to ScanTileStateT as-is (presumably it reserves
    // padding for them; confirm).
    int         predecessor_idx = tile_idx - static_cast<int>(threadIdx.x) - 1;
    StatusWord  predecessor_status;
    T           window_aggregate;

    // Wait for the warp-wide window of predecessor tiles to become valid
    ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);

    // The exclusive tile prefix starts out as the current window aggregate
    exclusive_prefix = window_aggregate;

    // Keep sliding the window back until we come across a tile whose inclusive prefix is known
    while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
    {
        predecessor_idx -= CUB_PTX_WARP_THREADS;

        // The new window covers older tiles, so it goes on the left of the running prefix
        ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);
        exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
    }

    // Compute the inclusive tile prefix and update the status for this tile
    if (threadIdx.x == 0)
    {
        inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
        tile_status.SetInclusive(tile_idx, inclusive_prefix);

        temp_storage.exclusive_prefix = exclusive_prefix;
        temp_storage.inclusive_prefix = inclusive_prefix;
    }

    // Return exclusive_prefix
    return exclusive_prefix;
}
/// Get the exclusive prefix stored in temporary storage.
/// The value is written there by thread 0 inside operator(). No barrier is issued
/// here, so any synchronization required before other threads read it is up to
/// the caller.
__device__ __forceinline__
T GetExclusivePrefix()
{
    T cached_prefix = temp_storage.exclusive_prefix;
    return cached_prefix;
}
// Get the inclusive prefix stored in temporary storage
__device__ __forceinline__
T GetInclusivePrefix()
( run in 1.957 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )