Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/cub/cub/agent/agent_rle.cuh  view on Meta::CPAN


        // Discontinuity predicate for a pair of neighbouring items.
        //
        // first, second : the two items being compared
        // idx           : position that is tested against num_remaining
        //
        // Returns true when the pair is NOT equal under equality_op.  When
        // LAST_TILE is set, any position that is not below num_remaining is
        // reported as a discontinuity without consulting equality_op.
        template <typename Index>
        __device__ __forceinline__ bool operator()(T first, T second, Index idx)
        {
            // Guard: past-the-end positions of the last tile are always flagged
            if (LAST_TILE && !(idx < num_remaining))
                return true;

            // In-range positions: flag wherever the neighbours differ
            return !equality_op(first, second);
        }
    };


    // Input iterator type used for reading data items.  When InputIteratorT is a
    // native pointer it is wrapped so that the policy's LOAD_MODIFIER is applied;
    // any other iterator type is used exactly as supplied.
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,      // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                       // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Parameterized BlockLoad type for data items of type T.  Block size, items
    // per thread and load algorithm all come from the tuning policy.
    typedef BlockLoad<
            T,
            AgentRlePolicyT::BLOCK_THREADS,
            AgentRlePolicyT::ITEMS_PER_THREAD,
            AgentRlePolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockDiscontinuity type for data items of type T
    typedef BlockDiscontinuity<T, BLOCK_THREADS> BlockDiscontinuityT;

    // Parameterized WarpScan type operating on (length, offset) pairs
    typedef WarpScan<LengthOffsetPair> WarpScanPairs;

    // Reduce-length-by-run scan operator (segmented reduction built on cub::Sum)
    typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;

    // Callback type for obtaining tile prefix during block scan; it combines
    // LengthOffsetPair values with ReduceBySegmentOpT and reads ScanTileStateT
    typedef TilePrefixCallbackOp<
            LengthOffsetPair,
            ReduceBySegmentOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Warp exchange types

    // Exchange of whole (length, offset) pairs
    typedef WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>        WarpExchangePairs;

    // Storage for the pair exchange: the real TempStorage only when
    // STORE_WARP_TIME_SLICING is set, otherwise the placeholder NullType
    typedef typename If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>::Type WarpExchangePairsStorage;

    // Exchange of offsets and of lengths as separate values
    typedef WarpExchange<OffsetT, ITEMS_PER_THREAD>                 WarpExchangeOffsets;
    typedef WarpExchange<LengthT, ITEMS_PER_THREAD>                 WarpExchangeLengths;

    // Array holding one (length, offset) pair per warp
    typedef LengthOffsetPair WarpAggregates[WARPS];

    // Shared memory type for this threadblock.
    //
    // The outer anonymous union overlays three groups of storage on the same
    // bytes: (1) the discontinuity / warp-scan / aggregate / prefix group,
    // (2) the input-loading storage, and (3) the scatter exchange storage.
    // Because they alias one another, only one group holds valid data at a time.
    // The three tile_* members are declared outside the union and therefore do
    // not alias any of it.
    struct _TempStorage
    {
        union
        {
            // Group 1: members here coexist with each other (plain struct)
            struct
            {
                typename BlockDiscontinuityT::TempStorage       discontinuity;              // Smem needed for discontinuity detection
                typename WarpScanPairs::TempStorage             warp_scan[WARPS];           // Smem needed for warp-synchronous scans
                Uninitialized<LengthOffsetPair[WARPS]>          warp_aggregates;            // Smem needed for sharing warp-wide aggregates
                typename TilePrefixCallbackOpT::TempStorage     prefix;                     // Smem needed for cooperative prefix callback
            };

            // Group 2: Smem needed for input loading
            typename BlockLoadT::TempStorage                    load;

            // Group 3: Smem needed for two-phase scatter; the alternatives below
            // alias each other as well, with one slot per active exchange warp
            union
            {
                unsigned long long                              align;                      // NOTE(review): presumably present only to force 8-byte alignment of this union -- confirm
                WarpExchangePairsStorage                        exchange_pairs[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeOffsets::TempStorage       exchange_offsets[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeLengths::TempStorage       exchange_lengths[ACTIVE_EXCHANGE_WARPS];
            };
        };

        OffsetT             tile_idx;                   // Shared tile index
        LengthOffsetPair    tile_inclusive;             // Inclusive tile prefix
        LengthOffsetPair    tile_exclusive;             // Exclusive tile prefix
    };

    // Alias wrapper allowing storage to be unioned.  Callers hold this opaque
    // type; the agent obtains the underlying _TempStorage through Alias().
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to the aliased temp storage supplied at construction

    WrappedInputIteratorT           d_in;               ///< Input sequence of data items (wrapped iterator type)
    OffsetsOutputIteratorT          d_offsets_out;      ///< Output run offsets
    LengthsOutputIteratorT          d_lengths_out;      ///< Output run lengths

    EqualityOpT                     equality_op;        ///< T equality operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-length-by-flag scan operator
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    // Constructor
    __device__ __forceinline__
    AgentRle(
        TempStorage                 &temp_storage,      ///< [in] Reference to temp_storage
        InputIteratorT              d_in,               ///< [in] Pointer to input sequence of data items
        OffsetsOutputIteratorT      d_offsets_out,      ///< [out] Pointer to output sequence of run offsets
        LengthsOutputIteratorT      d_lengths_out,      ///< [out] Pointer to output sequence of run lengths
        EqualityOpT                 equality_op,        ///< [in] T equality operator
        OffsetT                     num_items)          ///< [in] Total number of input items
    :
        temp_storage(temp_storage.Alias()),
        d_in(d_in),
        d_offsets_out(d_offsets_out),
        d_lengths_out(d_lengths_out),
        equality_op(equality_op),



( run in 0.698 second using v1.01-cache-2.11-cpan-39bf76dae61 )