Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/cub/cub/agent/agent_rle.cuh  view on Meta::CPAN






/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode 
 */
template <
    typename    AgentRlePolicyT,        ///< Parameterized AgentRlePolicyT tuning policy type
    typename    InputIteratorT,         ///< Random-access input iterator type for data
    typename    OffsetsOutputIteratorT, ///< Random-access output iterator type for offset values
    typename    LengthsOutputIteratorT, ///< Random-access output iterator type for length values
    typename    EqualityOpT,            ///< T equality operator type
    typename    OffsetT>                ///< Signed integer type for global offsets
struct AgentRle
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// The input value type (value_type of InputIteratorT)
    typedef typename std::iterator_traits<InputIteratorT>::value_type T;

    /// The lengths output value type: falls back to OffsetT when the lengths
    /// output iterator reports a void value_type (e.g. output-only iterators)
    typedef typename If<(Equals<typename std::iterator_traits<LengthsOutputIteratorT>::value_type, void>::VALUE),   // LengthT = (if output iterator's value type is void) ?
        OffsetT,                                                                                                    // ... then the OffsetT type,
        typename std::iterator_traits<LengthsOutputIteratorT>::value_type>::Type LengthT;                           // ... else the output iterator's value type

    /// Tuple type for scanning. In the run scan, key (OffsetT) counts runs and
    /// value (LengthT) accumulates run length; the key is later reused to carry
    /// an item offset when pairs are zipped for scattering.
    typedef KeyValuePair<OffsetT, LengthT> LengthOffsetPair;

    /// Tile status descriptor interface type (per-tile status used to chain
    /// scan prefixes from one tile to the next)
    typedef ReduceByKeyScanTileState<LengthT, OffsetT> ScanTileStateT;

    // Compile-time constants derived from the tuning policy and the target warp width
    enum
    {
        WARP_THREADS            = CUB_WARP_THREADS(PTX_ARCH),                           // Threads per warp for the targeted PTX architecture
        BLOCK_THREADS           = AgentRlePolicyT::BLOCK_THREADS,                       // Threads per block
        ITEMS_PER_THREAD        = AgentRlePolicyT::ITEMS_PER_THREAD,                    // Consecutive items owned by each thread
        WARP_ITEMS              = WARP_THREADS * ITEMS_PER_THREAD,                      // Items covered by one warp
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,                     // Items covered by one tile
        WARPS                   = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,    // Ceiling division: a partial trailing warp still gets a slot

        /// Whether or not to sync after loading data. Non-direct load algorithms
        /// may stage through temp_storage.load, which is unioned with the
        /// discontinuity/scan storage, so a barrier is needed before reuse.
        SYNC_AFTER_LOAD         = (AgentRlePolicyT::LOAD_ALGORITHM != BLOCK_LOAD_DIRECT),

        /// Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
        STORE_WARP_TIME_SLICING = AgentRlePolicyT::STORE_WARP_TIME_SLICING,
        ACTIVE_EXCHANGE_WARPS   = (STORE_WARP_TIME_SLICING) ? 1 : WARPS,                // Number of warp-exchange storage slots to allocate
    };


    /**
     * Inequality functor handed to BlockDiscontinuity when flagging run boundaries.
     *
     * Whenever the comparison is in bounds it is simply the negation of
     * equality_op. In the last tile (LAST_TILE == true), any comparison whose
     * idx argument is not below num_remaining reports "not equal" without
     * consulting equality_op at all. This forces (1) the last valid item to be
     * tail-flagged and (2) every out-of-bounds item to be marked trivial.
     */
    template <bool LAST_TILE>
    struct OobInequalityOp
    {
        OffsetT         num_remaining;      ///< Number of valid items remaining, counted from this tile's start
        EqualityOpT     equality_op;        ///< Wrapped T equality functor

        __device__ __forceinline__ OobInequalityOp(
            OffsetT         num_remaining,
            EqualityOpT     equality_op)
        :
            num_remaining(num_remaining),
            equality_op(equality_op)
        {}

        template <typename Index>
        __device__ __forceinline__ bool operator()(T first, T second, Index idx)
        {
            // Guard: beyond the valid input of the last tile nothing compares equal
            if (LAST_TILE && !(idx < num_remaining))
                return true;

            return !equality_op(first, second);
        }
    };


    // Cache-modified input iterator wrapper type for data: a raw input pointer is
    // wrapped so that AgentRlePolicyT::LOAD_MODIFIER applies to its loads; any
    // other iterator type is used unchanged
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,      // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                       // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Parameterized BlockLoad type for data
    typedef BlockLoad<
            T,
            AgentRlePolicyT::BLOCK_THREADS,
            AgentRlePolicyT::ITEMS_PER_THREAD,
            AgentRlePolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockDiscontinuity type for data (head/tail-flags run boundaries)
    typedef BlockDiscontinuity<T, BLOCK_THREADS> BlockDiscontinuityT;

    // Parameterized WarpScan type over (run-count, run-length) pairs
    typedef WarpScan<LengthOffsetPair> WarpScanPairs;

    // Reduce-length-by-run scan operator (segmented reduction wrapping cub::Sum)
    typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;

    // Callback type for obtaining this tile's exclusive prefix from the status
    // of preceding tiles during the block scan
    typedef TilePrefixCallbackOp<
            LengthOffsetPair,
            ReduceBySegmentOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Warp exchange types used by the two-phase scatter
    typedef WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>        WarpExchangePairs;

    // Pair-exchange storage is only materialized when warp time-slicing is enabled (NullType otherwise)
    typedef typename If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>::Type WarpExchangePairsStorage;

    typedef WarpExchange<OffsetT, ITEMS_PER_THREAD>                 WarpExchangeOffsets;
    typedef WarpExchange<LengthT, ITEMS_PER_THREAD>                 WarpExchangeLengths;

    // Array type holding one aggregate per warp in the block
    typedef LengthOffsetPair WarpAggregates[WARPS];

    // Shared memory type for this threadblock.
    //
    // The outer anonymous union overlays storage used in different phases of
    // tile processing: (a) discontinuity flagging, warp scans, warp aggregates
    // and the tile-prefix callback; (b) input loading; (c) two-phase scatter.
    // Only the largest phase's footprint is allocated, so uses of different
    // union members must be separated by a barrier (e.g. the CTA_SYNC taken
    // after a non-direct block load).
    //
    // NOTE(review): the unnamed struct nested in the union is a compiler
    // extension rather than standard C++. Its members are accessed directly
    // (e.g. temp_storage.discontinuity), so naming it would require updating
    // every use site.
    struct _TempStorage
    {
        union
        {
            struct
            {
                typename BlockDiscontinuityT::TempStorage       discontinuity;              // Smem needed for discontinuity detection
                typename WarpScanPairs::TempStorage             warp_scan[WARPS];           // Smem needed for warp-synchronous scans
                Uninitialized<LengthOffsetPair[WARPS]>          warp_aggregates;            // Smem needed for sharing warp-wide aggregates
                typename TilePrefixCallbackOpT::TempStorage     prefix;                     // Smem needed for cooperative prefix callback
            };

            // Smem needed for input loading
            typename BlockLoadT::TempStorage                    load;

            // Smem needed for two-phase scatter. Only the members matching the
            // selected STORE_WARP_TIME_SLICING mode are used; `align` forces
            // 8-byte alignment of the overlay.
            union
            {
                unsigned long long                              align;
                WarpExchangePairsStorage                        exchange_pairs[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeOffsets::TempStorage       exchange_offsets[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeLengths::TempStorage       exchange_lengths[ACTIVE_EXCHANGE_WARPS];
            };
        };

        // Fields outside the union persist across all phases
        OffsetT             tile_idx;                   // Shared tile index
        LengthOffsetPair    tile_inclusive;             // Inclusive tile prefix
        LengthOffsetPair    tile_exclusive;             // Exclusive tile prefix
    };

    // Alias wrapper allowing storage to be unioned: callers allocate this opaque
    // type, and the agent's constructor obtains the typed _TempStorage view via Alias()
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage (typed view of the caller's TempStorage)

    WrappedInputIteratorT           d_in;               ///< Input sequence of data items (cache-modifier-wrapped when given a raw pointer)
    OffsetsOutputIteratorT          d_offsets_out;      ///< Output run offsets
    LengthsOutputIteratorT          d_lengths_out;      ///< Output run lengths

    EqualityOpT                     equality_op;        ///< T equality operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-length-by-run scan operator (wraps cub::Sum)
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    /**
     * Constructor. Captures this block's shared workspace, the input and output
     * iterators, and the equality functor. scan_op is built around cub::Sum.
     */
    __device__ __forceinline__ AgentRle(
        TempStorage             &temp_storage,      ///< [in] Reference to temp_storage
        InputIteratorT          d_in,               ///< [in] Input sequence of data items
        OffsetsOutputIteratorT  d_offsets_out,      ///< [out] Output sequence of run offsets
        LengthsOutputIteratorT  d_lengths_out,      ///< [out] Output sequence of run lengths
        EqualityOpT             equality_op,        ///< [in] T equality operator
        OffsetT                 num_items)          ///< [in] Total number of input items
    :   temp_storage(temp_storage.Alias()),         // typed view over the opaque wrapper
        d_in(d_in),                                 // converted to WrappedInputIteratorT here
        d_offsets_out(d_offsets_out),
        d_lengths_out(d_lengths_out),
        equality_op(equality_op),
        scan_op(cub::Sum()),
        num_items(num_items)
    {}


    //---------------------------------------------------------------------
    // Utility methods for initializing the selections
    //---------------------------------------------------------------------

    /**
     * Builds, for each of this thread's items, the (key, value) pair consumed by
     * the run scan:
     *   - key   = 1 when the item is head-flagged but not tail-flagged, i.e. it
     *             opens a run of two or more items, else 0;
     *   - value = 1 unless the item is both head- and tail-flagged (a run of
     *             length one), else 0.
     *
     * Items on the tile edges are compared with the neighbouring tiles' items,
     * read from d_in only by the thread that owns the corresponding edge. In
     * the last tile, OobInequalityOp makes comparisons at or past num_remaining
     * report "not equal".
     */
    template <bool FIRST_TILE, bool LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT             tile_offset,
        OffsetT             num_remaining,
        T                   (&items)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
    {
        OobInequalityOp<LAST_TILE> not_equal(num_remaining, equality_op);

        bool is_head[ITEMS_PER_THREAD];
        bool is_tail[ITEMS_PER_THREAD];

        // Items just outside this tile; only meaningful in the edge thread that reads them
        T tile_predecessor_item;
        T tile_successor_item;

        // A successor exists unless this is the last tile
        if (!LAST_TILE && (threadIdx.x == BLOCK_THREADS - 1))
            tile_successor_item = d_in[tile_offset + TILE_ITEMS];

        // A predecessor exists unless this is the first tile
        if (!FIRST_TILE && (threadIdx.x == 0))
            tile_predecessor_item = d_in[tile_offset - 1];

        BlockDiscontinuityT discontinuity(temp_storage.discontinuity);

        if (FIRST_TILE)
        {
            if (LAST_TILE)
            {
                // Single tile: first item is always head-flagged, last item always tail-flagged
                discontinuity.FlagHeadsAndTails(
                    is_head, is_tail, items, not_equal);
            }
            else
            {
                // First item always head-flagged; last item checked against the next tile
                discontinuity.FlagHeadsAndTails(
                    is_head, is_tail, tile_successor_item, items, not_equal);
            }
        }
        else
        {
            if (LAST_TILE)
            {
                // First item checked against the previous tile; last item always tail-flagged
                discontinuity.FlagHeadsAndTails(
                    is_head, tile_predecessor_item, is_tail, items, not_equal);
            }
            else
            {
                // Interior tile: both edges checked against neighbouring tiles
                discontinuity.FlagHeadsAndTails(
                    is_head, tile_predecessor_item, is_tail, tile_successor_item, items, not_equal);
            }
        }

        // Zip the flags into (key, value) pairs
        #pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; ++i)
        {
            bool single_item_run = is_head[i] && is_tail[i];
            lengths_and_num_runs[i].key   = is_head[i] && !is_tail[i];  // opens a non-trivial run
            lengths_and_num_runs[i].value = !single_item_run;           // belongs to a non-trivial run
        }
    }

    //---------------------------------------------------------------------
    // Scan utility methods
    //---------------------------------------------------------------------

    /**
     * Scan of allocations.
     *
     * Reduces this thread's pairs with scan_op, scans those per-thread totals
     * within each warp (starting from a zero key/value identity), then combines
     * per-warp totals through shared memory. On return, in every thread:
     *   - thread_exclusive_in_warp: this thread's exclusive prefix within its warp
     *   - warp_exclusive_in_tile:   combined totals of the warps before this one
     *   - warp_aggregate:           total of this thread's warp
     *   - tile_aggregate:           total of the whole tile
     * Contains a CTA_SYNC, so every thread of the block must call it.
     *
     * NOTE(review): the warp total is published by lane WARP_THREADS - 1 only;
     * if BLOCK_THREADS were not a multiple of WARP_THREADS, the trailing partial
     * warp's slot would never be written.
     */
    __device__ __forceinline__ void WarpScanAllocations(
        LengthOffsetPair    &tile_aggregate,
        LengthOffsetPair    &warp_aggregate,
        LengthOffsetPair    &warp_exclusive_in_tile,
        LengthOffsetPair    &thread_exclusive_in_warp,
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
    {
        LengthOffsetPair zero;
        zero.key   = 0;
        zero.value = 0;

        const unsigned int  warp_id = (WARPS == 1) ? 0 : (threadIdx.x / WARP_THREADS);
        const int           lane_id = LaneId();

        // Intra-warp scan of each thread's reduced pairs
        LengthOffsetPair thread_total = ThreadReduce(lengths_and_num_runs, scan_op);
        LengthOffsetPair thread_inclusive;
        WarpScanPairs(temp_storage.warp_scan[warp_id]).Scan(
            thread_total, thread_inclusive, thread_exclusive_in_warp, zero, scan_op);

        // The last lane's inclusive prefix is its warp's total; publish it
        LengthOffsetPair (&warp_totals)[WARPS] = temp_storage.warp_aggregates.Alias();
        if (lane_id == WARP_THREADS - 1)
            warp_totals[warp_id] = thread_inclusive;

        CTA_SYNC();

        // Fold warp totals in warp order; the running value seen just before
        // this thread's own warp becomes its exclusive in-tile prefix
        warp_exclusive_in_tile  = zero;
        warp_aggregate          = warp_totals[warp_id];
        tile_aggregate          = warp_totals[0];

        #pragma unroll
        for (int w = 1; w < WARPS; ++w)
        {
            if (warp_id == w)
                warp_exclusive_in_tile = tile_aggregate;
            tile_aggregate = scan_op(tile_aggregate, warp_totals[w]);
        }
    }


    //---------------------------------------------------------------------
    // Utility methods for scattering selections
    //---------------------------------------------------------------------

xgboost/cub/cub/agent/agent_rle.cuh  view on Meta::CPAN

        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD])
    {
        if ((ITEMS_PER_THREAD == 1) || (tile_num_runs_aggregate < BLOCK_THREADS))
        {
            // Direct scatter if the warp has any items
            if (warp_num_runs_aggregate)
            {
                ScatterDirect<FIRST_TILE>(
                    tile_num_runs_exclusive_in_global,
                    warp_num_runs_aggregate,
                    warp_num_runs_exclusive_in_tile,
                    thread_num_runs_exclusive_in_warp,
                    lengths_and_offsets);
            }
        }
        else
        {
            // Scatter two phase
            ScatterTwoPhase<FIRST_TILE>(
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets,
                Int2Type<STORE_WARP_TIME_SLICING>());
        }
    }



    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    /**
     * Process a tile of input (dynamic chained scan)
     */
    template <
        bool                LAST_TILE>
    __device__ __forceinline__ LengthOffsetPair ConsumeTile(
        OffsetT             num_items,          ///< Total number of global input items
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,       ///< Tile offset
        ScanTileStateT       &tile_status)       ///< Global list of tile status
    {
        if (tile_idx == 0)
        {
            // First tile

            // Load items
            T items[ITEMS_PER_THREAD];
            if (LAST_TILE)
                BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, num_remaining, T());
            else
                BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);

            if (SYNC_AFTER_LOAD)
                CTA_SYNC();

            // Set flags
            LengthOffsetPair    lengths_and_num_runs[ITEMS_PER_THREAD];

            InitializeSelections<true, LAST_TILE>(
                tile_offset,
                num_remaining,
                items,
                lengths_and_num_runs);

            // Exclusive scan of lengths and runs
            LengthOffsetPair tile_aggregate;
            LengthOffsetPair warp_aggregate;
            LengthOffsetPair warp_exclusive_in_tile;
            LengthOffsetPair thread_exclusive_in_warp;

            WarpScanAllocations(
                tile_aggregate,
                warp_aggregate,
                warp_exclusive_in_tile,
                thread_exclusive_in_warp,
                lengths_and_num_runs);

            // Update tile status if this is not the last tile
            if (!LAST_TILE && (threadIdx.x == 0))
                tile_status.SetInclusive(0, tile_aggregate);

            // Update thread_exclusive_in_warp to fold in warp run-length
            if (thread_exclusive_in_warp.key == 0)
                thread_exclusive_in_warp.value += warp_exclusive_in_tile.value;

            LengthOffsetPair    lengths_and_offsets[ITEMS_PER_THREAD];
            OffsetT             thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];
            LengthOffsetPair    lengths_and_num_runs2[ITEMS_PER_THREAD];

            // Downsweep scan through lengths_and_num_runs
            ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);

            // Zip

            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
            {
                lengths_and_offsets[ITEM].value         = lengths_and_num_runs2[ITEM].value;
                lengths_and_offsets[ITEM].key        = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
                thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
                                                                lengths_and_num_runs2[ITEM].key :         // keep
                                                                WARP_THREADS * ITEMS_PER_THREAD;            // discard
            }

            OffsetT tile_num_runs_aggregate              = tile_aggregate.key;
            OffsetT tile_num_runs_exclusive_in_global    = 0;
            OffsetT warp_num_runs_aggregate              = warp_aggregate.key;
            OffsetT warp_num_runs_exclusive_in_tile      = warp_exclusive_in_tile.key;

            // Scatter
            Scatter<true>(
                tile_num_runs_aggregate,
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets);

            // Return running total (inclusive of this tile)
            return tile_aggregate;
        }
        else
        {
            // Not first tile

            // Load items
            T items[ITEMS_PER_THREAD];
            if (LAST_TILE)
                BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, num_remaining, T());
            else
                BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);

            if (SYNC_AFTER_LOAD)
                CTA_SYNC();

            // Set flags
            LengthOffsetPair    lengths_and_num_runs[ITEMS_PER_THREAD];

            InitializeSelections<false, LAST_TILE>(
                tile_offset,
                num_remaining,
                items,
                lengths_and_num_runs);

            // Exclusive scan of lengths and runs
            LengthOffsetPair tile_aggregate;
            LengthOffsetPair warp_aggregate;
            LengthOffsetPair warp_exclusive_in_tile;
            LengthOffsetPair thread_exclusive_in_warp;

            WarpScanAllocations(
                tile_aggregate,
                warp_aggregate,
                warp_exclusive_in_tile,
                thread_exclusive_in_warp,
                lengths_and_num_runs);

            // First warp computes tile prefix in lane 0
            TilePrefixCallbackOpT prefix_op(tile_status, temp_storage.prefix, Sum(), tile_idx);
            unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
            if (warp_id == 0)
            {
                prefix_op(tile_aggregate);
                if (threadIdx.x == 0)
                    temp_storage.tile_exclusive = prefix_op.exclusive_prefix;
            }

            CTA_SYNC();

            LengthOffsetPair tile_exclusive_in_global = temp_storage.tile_exclusive;

            // Update thread_exclusive_in_warp to fold in warp and tile run-lengths
            LengthOffsetPair thread_exclusive = scan_op(tile_exclusive_in_global, warp_exclusive_in_tile);
            if (thread_exclusive_in_warp.key == 0)
                thread_exclusive_in_warp.value += thread_exclusive.value;

            // Downsweep scan through lengths_and_num_runs
            LengthOffsetPair    lengths_and_num_runs2[ITEMS_PER_THREAD];
            LengthOffsetPair    lengths_and_offsets[ITEMS_PER_THREAD];
            OffsetT             thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];

            ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);

            // Zip
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
            {
                lengths_and_offsets[ITEM].value         = lengths_and_num_runs2[ITEM].value;
                lengths_and_offsets[ITEM].key        = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
                thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
                                                                lengths_and_num_runs2[ITEM].key :         // keep
                                                                WARP_THREADS * ITEMS_PER_THREAD;            // discard
            }

            OffsetT tile_num_runs_aggregate              = tile_aggregate.key;
            OffsetT tile_num_runs_exclusive_in_global    = tile_exclusive_in_global.key;



( run in 1.881 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )