// Alien-XGBoost
// (View the release on MetaCPAN, or search MetaCPAN.)
// File: xgboost/cub/cub/device/dispatch/dispatch_radix_sort.cuh (as viewed on MetaCPAN)
// BlockRadixSort type
typedef BlockRadixSort<
KeyT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
ValueT,
ChainedPolicyT::ActivePolicy::SingleTilePolicy::RADIX_BITS,
ChainedPolicyT::ActivePolicy::SingleTilePolicy::MEMOIZE_OUTER_SCAN,
ChainedPolicyT::ActivePolicy::SingleTilePolicy::INNER_SCAN_ALGORITHM>
BlockRadixSortT;
// BlockLoad type (keys)
typedef BlockLoad<
KeyT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
ChainedPolicyT::ActivePolicy::SingleTilePolicy::LOAD_ALGORITHM> BlockLoadKeys;
// BlockLoad type (values)
typedef BlockLoad<
ValueT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
ChainedPolicyT::ActivePolicy::SingleTilePolicy::LOAD_ALGORITHM> BlockLoadValues;
// Unsigned word for key bits
typedef typename Traits<KeyT>::UnsignedBits UnsignedBitsT;
// Shared memory storage
__shared__ union
{
typename BlockRadixSortT::TempStorage sort;
typename BlockLoadKeys::TempStorage load_keys;
typename BlockLoadValues::TempStorage load_values;
} temp_storage;
// Keys and values for the block
KeyT keys[ITEMS_PER_THREAD];
ValueT values[ITEMS_PER_THREAD];
// Get default (min/max) value for out-of-bounds keys
UnsignedBitsT default_key_bits = (IS_DESCENDING) ? Traits<KeyT>::LOWEST_KEY : Traits<KeyT>::MAX_KEY;
KeyT default_key = reinterpret_cast<KeyT&>(default_key_bits);
// Load keys
BlockLoadKeys(temp_storage.load_keys).Load(d_keys_in, keys, num_items, default_key);
CTA_SYNC();
// Load values
if (!KEYS_ONLY)
{
BlockLoadValues(temp_storage.load_values).Load(d_values_in, values, num_items);
CTA_SYNC();
}
// Sort tile
BlockRadixSortT(temp_storage.sort).SortBlockedToStriped(
keys,
values,
current_bit,
end_bit,
Int2Type<IS_DESCENDING>(),
Int2Type<KEYS_ONLY>());
// Store keys and values
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
int item_offset = ITEM * BLOCK_THREADS + threadIdx.x;
if (item_offset < num_items)
{
d_keys_out[item_offset] = keys[ITEM];
if (!KEYS_ONLY)
d_values_out[item_offset] = values[ITEM];
}
}
}
/**
* Segmented radix sorting pass (one block per segment)
*/
template <
typename ChainedPolicyT, ///< Chained tuning policy
bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
typename KeyT, ///< Key type
typename ValueT, ///< Value type
typename OffsetT> ///< Signed integer type for global offsets
__launch_bounds__ (int((ALT_DIGIT_BITS) ?
ChainedPolicyT::ActivePolicy::AltSegmentedPolicy::BLOCK_THREADS :
ChainedPolicyT::ActivePolicy::SegmentedPolicy::BLOCK_THREADS))
__global__ void DeviceSegmentedRadixSortKernel(
const KeyT *d_keys_in, ///< [in] Input keys buffer
KeyT *d_keys_out, ///< [in] Output keys buffer
const ValueT *d_values_in, ///< [in] Input values buffer
ValueT *d_values_out, ///< [in] Output values buffer
const int *d_begin_offsets, ///< [in] %Device-accessible pointer to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup>...
const int *d_end_offsets, ///< [in] %Device-accessible pointer to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> dat...
int /*num_segments*/, ///< [in] The number of segments that comprise the sorting data
int current_bit, ///< [in] Bit position of current radix digit
int pass_bits) ///< [in] Number of bits of current radix digit
{
//
// Constants
//
typedef typename If<(ALT_DIGIT_BITS),
typename ChainedPolicyT::ActivePolicy::AltSegmentedPolicy,
typename ChainedPolicyT::ActivePolicy::SegmentedPolicy>::Type SegmentedPolicyT;
enum
{
BLOCK_THREADS = SegmentedPolicyT::BLOCK_THREADS,
ITEMS_PER_THREAD = SegmentedPolicyT::ITEMS_PER_THREAD,
RADIX_BITS = SegmentedPolicyT::RADIX_BITS,
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
// (Page rendered in 2.515 seconds using v1.01-cache-2.11-cpan-39bf76dae61)