Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/cub/cub/agent/single_pass_scan_operators.cuh  view on Meta::CPAN

template <
    typename    T,
    typename    ScanOpT,
    typename    ScanTileStateT,
    int         PTX_ARCH = CUB_PTX_ARCH>
struct TilePrefixCallbackOp
{
    // Parameterized warp reduce: WarpReduce instantiated on T, CUB_PTX_WARP_THREADS and PTX_ARCH.
    // ProcessWindow() constructs one over temp_storage.warp_reduce and calls TailSegmentedReduce() on it.
    typedef WarpReduce<T, CUB_PTX_WARP_THREADS, PTX_ARCH> WarpReduceT;

    // Temporary storage type: scratch space for the warp reduction plus the per-tile
    // values that thread 0 records inside operator() so they can be read back afterwards.
    // NOTE(review): presumably placed in shared memory by the caller -- not established here, confirm.
    struct _TempStorage
    {
        typename WarpReduceT::TempStorage   warp_reduce;        // Storage handed to WarpReduceT in ProcessWindow()
        T                                   exclusive_prefix;   // Written by thread 0 in operator(); returned by GetExclusivePrefix()
        T                                   inclusive_prefix;   // Written by thread 0 in operator()
        T                                   block_aggregate;    // Tile aggregate passed to operator(); written by thread 0
    };

    // Alias wrapper allowing temporary storage to be unioned.
    // The constructor receives a TempStorage and calls Alias() on it to obtain the _TempStorage it works with.
    struct TempStorage : Uninitialized<_TempStorage> {};

    // Type of status word, as declared by ScanTileStateT.
    // Within this struct a status is only ever compared against StatusWord(SCAN_TILE_INCLUSIVE),
    // which marks a tile whose inclusive prefix is known.
    typedef typename ScanTileStateT::StatusWord StatusWord;

    // Fields
    // Note: exclusive_prefix and inclusive_prefix are not set by the constructor; they only
    // hold meaningful values after operator() has run.
    _TempStorage&               temp_storage;       ///< Reference to the aliased temporary storage (warp-reduce scratch plus recorded prefixes and aggregate)
    ScanTileStateT&             tile_status;        ///< Interface to tile status
    ScanOpT                     scan_op;            ///< Binary scan operator
    int                         tile_idx;           ///< The current tile index
    T                           exclusive_prefix;   ///< Exclusive prefix for the tile (assigned by every calling thread in operator())
    T                           inclusive_prefix;   ///< Inclusive prefix for the tile (assigned only by thread 0 in operator())

    /// Constructor: binds the callback to the tile-status interface, the aliased temporary
    /// storage, the scan operator and the index of the tile being processed.
    /// The exclusive_prefix / inclusive_prefix members are left for operator() to fill in.
    __device__ __forceinline__ TilePrefixCallbackOp(
        ScanTileStateT       &tile_status,
        TempStorage         &temp_storage,
        ScanOpT              scan_op,
        int                 tile_idx)
        : temp_storage(temp_storage.Alias())
        , tile_status(tile_status)
        , scan_op(scan_op)
        , tile_idx(tile_idx)
    {
    }


    /// Waits until the status of tile \p predecessor_idx is valid, then reduces the values
    /// observed across the warp-wide window of predecessors into \p window_aggregate.
    __device__ __forceinline__ void ProcessWindow(
        int         predecessor_idx,        ///< Preceding tile index to inspect
        StatusWord  &predecessor_status,    ///< [out] Preceding tile status
        T           &window_aggregate)      ///< [out] Relevant partial reduction from this window of preceding tiles
    {
        // Block on the tile-status interface until this predecessor has a non-invalid status
        T predecessor_value;
        tile_status.WaitForValid(predecessor_idx, predecessor_status, predecessor_value);

        // A predecessor whose inclusive prefix is known closes its segment of the reduction
        int ends_segment = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));

        // Segmented reduction yielding the prefix for the current window.  The scan runs
        // *down* towards thread0, which is why the swizzled scan operator is used.
        WarpReduceT warp_reducer(temp_storage.warp_reduce);
        window_aggregate = warp_reducer.TailSegmentedReduce(
            predecessor_value,
            ends_segment,
            SwizzleScanOp<ScanOpT>(scan_op));
    }


    // BlockScan prefix callback functor (called by the first warp)
    //
    // Publishes this tile's aggregate, then inspects preceding tiles one warp-wide window at a
    // time, folding each window's aggregate into exclusive_prefix, until a window contains a
    // tile whose inclusive prefix is known.  Thread 0 then publishes this tile's inclusive
    // prefix and records both prefixes in temp_storage.
    //
    // block_aggregate: aggregate of the current tile (only thread 0's value is used).
    // Returns the calling thread's exclusive_prefix member; thread 0's copy is the one that is
    // recorded in temp_storage and combined into the published inclusive prefix.
    __device__ __forceinline__
    T operator()(T block_aggregate)
    {

        // Update our status with our tile-aggregate
        // (done before looking back, so the partial is published prior to any waiting below)
        if (threadIdx.x == 0)
        {
            temp_storage.block_aggregate = block_aggregate;
            tile_status.SetPartial(tile_idx, block_aggregate);
        }

        // Thread k starts at the tile (k + 1) positions before the current one
        int         predecessor_idx = tile_idx - threadIdx.x - 1;
        StatusWord  predecessor_status;
        T           window_aggregate;

        // Wait for the warp-wide window of predecessor tiles to become valid
        ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);

        // The exclusive tile prefix starts out as the current window aggregate
        exclusive_prefix = window_aggregate;

        // Keep sliding the window back until we come across a tile whose inclusive prefix is known
        // NOTE(review): WARP_ALL is presumably a warp-wide "all" vote with 0xffffffff as the
        // participant mask -- its definition is not visible here, confirm.
        while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
        {
            // Step every thread back by one full window
            predecessor_idx -= CUB_PTX_WARP_THREADS;

            // Update exclusive tile prefix with the window prefix
            // (the earlier window's aggregate is the left operand of scan_op)
            ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);
            exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
        }

        // Compute the inclusive tile prefix and update the status for this tile
        if (threadIdx.x == 0)
        {
            inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
            tile_status.SetInclusive(tile_idx, inclusive_prefix);

            // Record both prefixes so they can be read back from temporary storage
            temp_storage.exclusive_prefix = exclusive_prefix;
            temp_storage.inclusive_prefix = inclusive_prefix;
        }

        // Return exclusive_prefix
        return exclusive_prefix;
    }

    /// Returns the exclusive prefix that thread 0 recorded in temporary storage during operator().
    __device__ __forceinline__ T GetExclusivePrefix()
    {
        return temp_storage.exclusive_prefix;
    }

    // Get the inclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetInclusivePrefix()



( run in 1.957 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )