Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/cub/cub/agent/agent_rle.cuh  view on Meta::CPAN


            // Smem needed for two-phase scatter
            union
            {
                unsigned long long                              align;
                WarpExchangePairsStorage                        exchange_pairs[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeOffsets::TempStorage       exchange_offsets[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeLengths::TempStorage       exchange_lengths[ACTIVE_EXCHANGE_WARPS];
            };
        };

        OffsetT             tile_idx;                   // Shared tile index
        LengthOffsetPair    tile_inclusive;             // Inclusive tile prefix
        LengthOffsetPair    tile_exclusive;             // Exclusive tile prefix
    };

    // Alias wrapper allowing storage to be unioned.  Callers hand this opaque
    // type to the AgentRle constructor, which recovers the underlying
    // _TempStorage reference through Alias().
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to the unwrapped (aliased) temp storage

    WrappedInputIteratorT           d_in;               ///< Input sequence of data items (held as the wrapped iterator type)
    OffsetsOutputIteratorT          d_offsets_out;      ///< Output run offsets
    LengthsOutputIteratorT          d_lengths_out;      ///< Output run lengths

    EqualityOpT                     equality_op;        ///< T equality operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-length-by-flag scan operator (constructed from cub::Sum)
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    /**
     * Constructor.  Binds the agent to its temp storage, the input sequence,
     * the two output sequences, and the equality functor.
     *
     * Nothing is read from or written to device memory here; the constructor
     * only initializes the per-thread fields.
     */
    __device__ __forceinline__
    AgentRle(
        TempStorage                 &temp_storage,      ///< [in] Reference to temp_storage
        InputIteratorT              d_in,               ///< [in] Pointer to input sequence of data items
        OffsetsOutputIteratorT      d_offsets_out,      ///< [out] Pointer to output sequence of run offsets
        LengthsOutputIteratorT      d_lengths_out,      ///< [out] Pointer to output sequence of run lengths
        EqualityOpT                 equality_op,        ///< [in] T equality operator
        OffsetT                     num_items)          ///< [in] Total number of input items
    :
        temp_storage(temp_storage.Alias()),             // parameter shadows the member: unwrap the caller's TempStorage
        d_in(d_in),                                     // WrappedInputIteratorT member initialized from the InputIteratorT argument
        d_offsets_out(d_offsets_out),
        d_lengths_out(d_lengths_out),
        equality_op(equality_op),
        scan_op(cub::Sum()),                            // fixed operator; not supplied by the caller
        num_items(num_items)
    {}


    //---------------------------------------------------------------------
    // Utility methods for initializing the selections
    //---------------------------------------------------------------------

    /**
     * Flags the head and tail of every run in this tile and converts the
     * flags into (key, value) pairs for the subsequent scan.
     *
     * FIRST_TILE / LAST_TILE pick, at compile time, which neighbouring items
     * are fetched from d_in:
     *   - a tile that is not first reads d_in[tile_offset - 1] in thread 0;
     *   - a tile that is not last reads d_in[tile_offset + TILE_ITEMS] in the
     *     last thread of the block.
     * Only those threads load the neighbour; in every other thread the
     * variable is left uninitialized (presumably BlockDiscontinuity only
     * consumes it from the thread that loaded it -- confirm).
     *
     * On return, for each item:
     *   - key   is 1 when the item is a head but not also a tail, else 0;
     *   - value is 0 when the item is both a head and a tail, else 1.
     *
     * Every thread of the block must call this: FlagHeadsAndTails is invoked
     * on shared temp storage in all four branches.
     */
    template <bool FIRST_TILE, bool LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT             tile_offset,
        OffsetT             num_remaining,
        T                   (&items)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
    {
        bool is_head[ITEMS_PER_THREAD];
        bool is_tail[ITEMS_PER_THREAD];

        // NOTE(review): presumably treats positions past num_remaining as
        // discontinuities when LAST_TILE is set -- confirm against OobInequalityOp.
        OobInequalityOp<LAST_TILE> flag_op(num_remaining, equality_op);

        if (FIRST_TILE)
        {
            if (LAST_TILE)
            {
                // Single-tile input: no neighbouring tiles to consult
                BlockDiscontinuityT(temp_storage.discontinuity).FlagHeadsAndTails(
                    is_head, is_tail, items, flag_op);
            }
            else
            {
                // First of several tiles: only the successor item is needed
                T next_tile_first;
                if (threadIdx.x == BLOCK_THREADS - 1)
                    next_tile_first = d_in[tile_offset + TILE_ITEMS];

                BlockDiscontinuityT(temp_storage.discontinuity).FlagHeadsAndTails(
                    is_head, is_tail, next_tile_first, items, flag_op);
            }
        }
        else
        {
            if (LAST_TILE)
            {
                // Final tile: only the predecessor item is needed
                T prev_tile_last;
                if (threadIdx.x == 0)
                    prev_tile_last = d_in[tile_offset - 1];

                BlockDiscontinuityT(temp_storage.discontinuity).FlagHeadsAndTails(
                    is_head, prev_tile_last, is_tail, items, flag_op);
            }
            else
            {
                // Interior tile: both neighbours are needed

                // First item of the following tile
                T next_tile_first;
                if (threadIdx.x == BLOCK_THREADS - 1)
                    next_tile_first = d_in[tile_offset + TILE_ITEMS];

                // Last item of the preceding tile
                T prev_tile_last;
                if (threadIdx.x == 0)
                    prev_tile_last = d_in[tile_offset - 1];

                BlockDiscontinuityT(temp_storage.discontinuity).FlagHeadsAndTails(
                    is_head, prev_tile_last, is_tail, next_tile_first, items, flag_op);
            }
        }

        // Turn the flag pair of each item into its scan input
        #pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; ++i)
        {
            const bool head = is_head[i];
            const bool tail = is_tail[i];

            lengths_and_num_runs[i].key     = head && !tail;        // head that is not also a tail
            lengths_and_num_runs[i].value   = !(head && tail);      // everything except a head-and-tail item
        }
    }

    //---------------------------------------------------------------------
    // Scan utility methods
    //---------------------------------------------------------------------

    /**
     * Scan of allocations
     *
     * Reduces each thread's pairs with scan_op, scans those partials within
     * the warp, then combines the per-warp totals through shared memory.
     *
     * Outputs:
     *   - thread_exclusive_in_warp : exclusive result of the warp scan for
     *                                this thread
     *   - warp_aggregate           : total published by this thread's warp
     *   - warp_exclusive_in_tile   : scan_op-combination of the totals of all
     *                                lower-numbered warps (identity for warp 0)
     *   - tile_aggregate           : scan_op-combination of all warp totals
     *
     * Contains a block-wide barrier (CTA_SYNC), so every thread of the block
     * must call it.
     */
    __device__ __forceinline__ void WarpScanAllocations(
        LengthOffsetPair    &tile_aggregate,
        LengthOffsetPair    &warp_aggregate,
        LengthOffsetPair    &warp_exclusive_in_tile,
        LengthOffsetPair    &thread_exclusive_in_warp,
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
    {
        // Locate this thread inside the block
        int lane = LaneId();
        unsigned int my_warp = 0;
        if (WARPS != 1)
            my_warp = threadIdx.x / WARP_THREADS;

        // All-zero pair used as the scan identity
        LengthOffsetPair zero_pair;
        zero_pair.value = 0;
        zero_pair.key   = 0;

        // Collapse this thread's items to a single pair, then scan across the warp
        LengthOffsetPair thread_partial = ThreadReduce(lengths_and_num_runs, scan_op);
        LengthOffsetPair thread_inclusive_in_warp;
        WarpScanPairs(temp_storage.warp_scan[my_warp]).Scan(
            thread_partial,
            thread_inclusive_in_warp,
            thread_exclusive_in_warp,
            zero_pair,
            scan_op);

        // The last lane's inclusive value is the warp total; publish it
        if (lane == WARP_THREADS - 1)
            temp_storage.warp_aggregates.Alias()[my_warp] = thread_inclusive_in_warp;

        // Make every warp's total visible to the whole block
        CTA_SYNC();

        // Seed the accumulators with warp 0's contribution
        warp_exclusive_in_tile  = zero_pair;
        warp_aggregate          = temp_storage.warp_aggregates.Alias()[my_warp];
        tile_aggregate          = temp_storage.warp_aggregates.Alias()[0];

        // Walk the remaining warps in order: the running total seen just
        // before adding warp w is exactly warp w's exclusive prefix
        #pragma unroll
        for (int w = 1; w < WARPS; ++w)
        {
            if (my_warp == w)
                warp_exclusive_in_tile = tile_aggregate;

            tile_aggregate = scan_op(tile_aggregate, temp_storage.warp_aggregates.Alias()[w]);
        }
    }


    //---------------------------------------------------------------------
    // Utility methods for scattering selections
    //---------------------------------------------------------------------

    /**
     * Two-phase scatter, specialized for warp time-slicing
     *
     * Phase 1: each warp in turn compacts its pairs with ScatterToStriped,
     *          reusing the single exchange buffer exchange_pairs[0]; a
     *          block-wide barrier separates consecutive turns.
     * Phase 2: each lane writes its striped pairs to d_offsets_out and
     *          d_lengths_out.
     *
     * The is_warp_time_slice argument is not read; it only selects this
     * overload.  Contains block-wide barriers when WARPS > 1, so every thread
     * of the block must call it.
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OffsetT             tile_num_runs_exclusive_in_global,
        OffsetT             warp_num_runs_aggregate,
        OffsetT             warp_num_runs_exclusive_in_tile,
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD],
        Int2Type<true>      is_warp_time_slice)
    {
        int lane = LaneId();
        unsigned int my_warp = 0;
        if (WARPS != 1)
            my_warp = threadIdx.x / WARP_THREADS;

        // Phase 1: warps take turns on the shared exchange buffer.  Turn 0
        // needs no barrier; every later turn waits for the previous warp.
        #pragma unroll
        for (int turn = 0; turn < WARPS; ++turn)
        {
            if (turn > 0)
                CTA_SYNC();

            if (my_warp == turn)
            {
                WarpExchangePairs(temp_storage.exchange_pairs[0]).ScatterToStriped(
                    lengths_and_offsets,
                    thread_num_runs_exclusive_in_warp);
            }
        }

        // Phase 2: write the compacted pairs to global memory
        #pragma unroll
        for (int slot = 0; slot < ITEMS_PER_THREAD; slot++)
        {
            // Only slots whose striped rank falls inside this warp's run count are live
            if ((slot * WARP_THREADS) < warp_num_runs_aggregate - lane)
            {
                const OffsetT warp_base =
                    tile_num_runs_exclusive_in_global +
                    warp_num_runs_exclusive_in_tile;

                OffsetT dst = warp_base + (slot * WARP_THREADS) + lane;

                // Run offset goes to the run's own output position
                d_offsets_out[dst] = lengths_and_offsets[slot].key;

                // The paired length goes one position earlier, so the very
                // first pair of the first tile has no destination and is skipped
                const bool is_first_global_pair =
                    FIRST_TILE && (slot == 0) && (threadIdx.x == 0);

                if (!is_first_global_pair)
                {
                    d_lengths_out[dst - 1] = lengths_and_offsets[slot].value;
                }
            }
        }
    }


    /**

xgboost/cub/cub/agent/agent_rle.cuh  view on Meta::CPAN

            // First warp computes tile prefix in lane 0
            TilePrefixCallbackOpT prefix_op(tile_status, temp_storage.prefix, Sum(), tile_idx);
            unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
            if (warp_id == 0)
            {
                prefix_op(tile_aggregate);
                if (threadIdx.x == 0)
                    temp_storage.tile_exclusive = prefix_op.exclusive_prefix;
            }

            CTA_SYNC();

            LengthOffsetPair tile_exclusive_in_global = temp_storage.tile_exclusive;

            // Update thread_exclusive_in_warp to fold in warp and tile run-lengths
            LengthOffsetPair thread_exclusive = scan_op(tile_exclusive_in_global, warp_exclusive_in_tile);
            if (thread_exclusive_in_warp.key == 0)
                thread_exclusive_in_warp.value += thread_exclusive.value;

            // Downsweep scan through lengths_and_num_runs
            LengthOffsetPair    lengths_and_num_runs2[ITEMS_PER_THREAD];
            LengthOffsetPair    lengths_and_offsets[ITEMS_PER_THREAD];
            OffsetT             thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];

            ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);

            // Zip
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
            {
                lengths_and_offsets[ITEM].value         = lengths_and_num_runs2[ITEM].value;
                lengths_and_offsets[ITEM].key        = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
                thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
                                                                lengths_and_num_runs2[ITEM].key :         // keep
                                                                WARP_THREADS * ITEMS_PER_THREAD;            // discard
            }

            OffsetT tile_num_runs_aggregate              = tile_aggregate.key;
            OffsetT tile_num_runs_exclusive_in_global    = tile_exclusive_in_global.key;
            OffsetT warp_num_runs_aggregate              = warp_aggregate.key;
            OffsetT warp_num_runs_exclusive_in_tile      = warp_exclusive_in_tile.key;

            // Scatter
            Scatter<false>(
                tile_num_runs_aggregate,
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets);

            // Return running total (inclusive of this tile)
            return prefix_op.inclusive_prefix;
        }
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     *
     * Each thread block processes exactly one tile, chosen from its block
     * coordinates.  The block that owns the last tile additionally writes the
     * total number of runs and, when at least one run exists, the length of
     * the final run.
     *
     * NOTE(review): a block whose tile starts at or beyond num_items (which
     * includes the num_items == 0 case) writes nothing here, not even
     * *d_num_runs_out -- presumably that output is initialized elsewhere;
     * confirm against the launching code.
     */
    template <typename NumRunsIteratorT>            ///< Output iterator type for recording number of runs identified
    __device__ __forceinline__ void ConsumeRange(
        int                 num_tiles,              ///< Total number of input tiles
        ScanTileStateT&     tile_status,            ///< Global list of tile status
        NumRunsIteratorT    d_num_runs_out)         ///< Output pointer for total number of runs identified
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index

        // Widen to OffsetT *before* multiplying: tile_idx and TILE_ITEMS are
        // both int-ranked, so an un-widened product would overflow int for
        // inputs beyond INT_MAX items even when OffsetT is a 64-bit type.
        OffsetT tile_offset     = static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS);   // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;                  // Remaining items (including this tile)

        if (tile_idx < num_tiles - 1)
        {
            // Not the last tile (full)
            ConsumeTile<false>(num_items, num_remaining, tile_idx, tile_offset, tile_status);
        }
        else if (num_remaining > 0)
        {
            // The last tile (possibly partially-full)
            LengthOffsetPair running_total = ConsumeTile<true>(num_items, num_remaining, tile_idx, tile_offset, tile_status);

            // A single thread publishes the grid-wide results
            if (threadIdx.x == 0)
            {
                // Output the total number of runs identified
                *d_num_runs_out = running_total.key;

                // The inclusive prefix contains accumulated length reduction for the last run
                if (running_total.key > 0)
                    d_lengths_out[running_total.key - 1] = running_total.value;
            }
        }
    }
};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)



( run in 1.435 second using v1.01-cache-2.11-cpan-13bb782fe5a )