Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/cub/cub/device/dispatch/dispatch_radix_sort.cuh  view on Meta::CPAN


    // BlockRadixSort type
    typedef BlockRadixSort<
            KeyT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            ValueT,
            ChainedPolicyT::ActivePolicy::SingleTilePolicy::RADIX_BITS,
            ChainedPolicyT::ActivePolicy::SingleTilePolicy::MEMOIZE_OUTER_SCAN,
            ChainedPolicyT::ActivePolicy::SingleTilePolicy::INNER_SCAN_ALGORITHM>
        BlockRadixSortT;

    // BlockLoad type (keys)
    typedef BlockLoad<
        KeyT,
        BLOCK_THREADS,
        ITEMS_PER_THREAD,
        ChainedPolicyT::ActivePolicy::SingleTilePolicy::LOAD_ALGORITHM> BlockLoadKeys;

    // BlockLoad type (values)
    typedef BlockLoad<
        ValueT,
        BLOCK_THREADS,
        ITEMS_PER_THREAD,
        ChainedPolicyT::ActivePolicy::SingleTilePolicy::LOAD_ALGORITHM> BlockLoadValues;

    // Unsigned word for key bits
    typedef typename Traits<KeyT>::UnsignedBits UnsignedBitsT;

    // Shared memory storage
    __shared__ union
    {
        typename BlockRadixSortT::TempStorage       sort;
        typename BlockLoadKeys::TempStorage         load_keys;
        typename BlockLoadValues::TempStorage       load_values;

    } temp_storage;

    // Keys and values for the block
    KeyT            keys[ITEMS_PER_THREAD];
    ValueT          values[ITEMS_PER_THREAD];

    // Get default (min/max) value for out-of-bounds keys
    UnsignedBitsT   default_key_bits = (IS_DESCENDING) ? Traits<KeyT>::LOWEST_KEY : Traits<KeyT>::MAX_KEY;
    KeyT            default_key = reinterpret_cast<KeyT&>(default_key_bits);

    // Load keys
    BlockLoadKeys(temp_storage.load_keys).Load(d_keys_in, keys, num_items, default_key);

    CTA_SYNC();

    // Load values
    if (!KEYS_ONLY)
    {
        BlockLoadValues(temp_storage.load_values).Load(d_values_in, values, num_items);

        CTA_SYNC();
    }

    // Sort tile
    BlockRadixSortT(temp_storage.sort).SortBlockedToStriped(
        keys,
        values,
        current_bit,
        end_bit,
        Int2Type<IS_DESCENDING>(),
        Int2Type<KEYS_ONLY>());

    // Store keys and values
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
        int item_offset = ITEM * BLOCK_THREADS + threadIdx.x;
        if (item_offset < num_items)
        {
            d_keys_out[item_offset] = keys[ITEM];
            if (!KEYS_ONLY)
                d_values_out[item_offset] = values[ITEM];
        }
    }
}


/**
 * Segmented radix sorting pass (one block per segment)
 */
template <
    typename                ChainedPolicyT,                 ///< Chained tuning policy
    bool                    ALT_DIGIT_BITS,                 ///< Whether or not to use the alternate (lower-bits) policy
    bool                    IS_DESCENDING,                  ///< Whether or not the sorted-order is high-to-low
    typename                KeyT,                           ///< Key type
    typename                ValueT,                         ///< Value type
    typename                OffsetT>                        ///< Signed integer type for global offsets
__launch_bounds__ (int((ALT_DIGIT_BITS) ?
    ChainedPolicyT::ActivePolicy::AltSegmentedPolicy::BLOCK_THREADS :
    ChainedPolicyT::ActivePolicy::SegmentedPolicy::BLOCK_THREADS))
__global__ void DeviceSegmentedRadixSortKernel(
    const KeyT              *d_keys_in,                     ///< [in] Input keys buffer
    KeyT                    *d_keys_out,                    ///< [out] Output keys buffer
    const ValueT            *d_values_in,                   ///< [in] Input values buffer
    ValueT                  *d_values_out,                  ///< [out] Output values buffer
    const int               *d_begin_offsets,               ///< [in] %Device-accessible pointer to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup>...
    const int               *d_end_offsets,                 ///< [in] %Device-accessible pointer to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> dat...
    int                     /*num_segments*/,               ///< [in] The number of segments that comprise the sorting data
    int                     current_bit,                    ///< [in] Bit position of current radix digit
    int                     pass_bits)                      ///< [in] Number of bits of current radix digit
{
    //
    // Constants
    //

    typedef typename If<(ALT_DIGIT_BITS),
        typename ChainedPolicyT::ActivePolicy::AltSegmentedPolicy,
        typename ChainedPolicyT::ActivePolicy::SegmentedPolicy>::Type SegmentedPolicyT;

    enum
    {
        BLOCK_THREADS       = SegmentedPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD    = SegmentedPolicyT::ITEMS_PER_THREAD,
        RADIX_BITS          = SegmentedPolicyT::RADIX_BITS,
        TILE_ITEMS          = BLOCK_THREADS * ITEMS_PER_THREAD,



( run in 2.515 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )