Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/block/block_radix_sort.cuh view on Meta::CPAN
/// Keys-only variant of ExchangeValues.  It is selected when the third
/// argument is the Int2Type<true> tag; none of the arguments are named or
/// read, and the body is intentionally empty.
template <int IS_BLOCKED>
__device__ __forceinline__ void ExchangeValues(
    ValueT                  (&/*values*/)[ITEMS_PER_THREAD],    ///< Unused in this variant
    int                     (&/*ranks*/)[ITEMS_PER_THREAD],     ///< Unused in this variant
    Int2Type<true>          /*is_keys_only*/,                   ///< Tag selecting this keys-only variant
    Int2Type<IS_BLOCKED>    /*is_blocked*/)                     ///< Arrangement tag (unused here)
{
    // Nothing to do.
}
/// Sort blocked arrangement.
///
/// Views the keys as UnsignedBits, applies KeyTraits::TwiddleIn to each one,
/// then repeatedly ranks the keys on a group of up to RADIX_BITS bits and
/// scatters keys (and values, via ExchangeValues) back into a blocked
/// arrangement until begin_bit has reached end_bit.  KeyTraits::TwiddleOut is
/// applied to every key before returning.
template <int DESCENDING, int KEYS_ONLY>
__device__ __forceinline__ void SortBlocked(
    KeyT                    (&keys)[ITEMS_PER_THREAD],      ///< Keys to sort
    ValueT                  (&values)[ITEMS_PER_THREAD],    ///< Values to sort
    int                     begin_bit,                      ///< The beginning (least-significant) bit index needed for key comparison
    int                     end_bit,                        ///< The past-the-end (most-significant) bit index needed for key comparison
    Int2Type<DESCENDING>    is_descending,                  ///< Tag whether is a descending-order sort
    Int2Type<KEYS_ONLY>     is_keys_only)                   ///< Tag whether is keys-only sort
{
    // Alias the caller's key array as raw unsigned bit patterns (same storage).
    UnsignedBits (&key_bits)[ITEMS_PER_THREAD] =
        reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]>(keys);

    // Convert every key with TwiddleIn before ranking
    #pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
        key_bits[item] = KeyTraits::TwiddleIn(key_bits[item]);
    }

    // One iteration per radix pass; the loop exits from the middle once the
    // requested bit range has been consumed.
    for (;;)
    {
        // The final pass may cover fewer than RADIX_BITS bits
        int num_pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);

        // Compute the destination rank of each of this thread's keys
        int scatter_ranks[ITEMS_PER_THREAD];
        RankKeys(key_bits, scatter_ranks, begin_bit, num_pass_bits, is_descending);

        // Advance to the next digit position
        begin_bit += RADIX_BITS;

        // Barrier between ranking and the exchanges below
        // (NOTE(review): presumably because both stages share temp_storage -- confirm)
        CTA_SYNC();

        // Scatter keys by rank, leaving them in blocked arrangement
        BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, scatter_ranks);

        // Move the values the same way (blocked arrangement)
        ExchangeValues(values, scatter_ranks, is_keys_only, Int2Type<true>());

        // Stop once every requested bit has been processed
        if (begin_bit >= end_bit)
        {
            break;
        }

        // Barrier before the next pass starts ranking again
        CTA_SYNC();
    }

    // Convert every key back with TwiddleOut
    #pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
        key_bits[item] = KeyTraits::TwiddleOut(key_bits[item]);
    }
}
public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
/// Sort blocked -> striped arrangement
template <int DESCENDING, int KEYS_ONLY>
__device__ __forceinline__ void SortBlockedToStriped(
KeyT (&keys)[ITEMS_PER_THREAD], ///< Keys to sort
ValueT (&values)[ITEMS_PER_THREAD], ///< Values to sort
int begin_bit, ///< The beginning (least-significant) bit index needed for key comparison
int end_bit, ///< The past-the-end (most-significant) bit index needed for key comparison
Int2Type<DESCENDING> is_descending, ///< Tag whether is a descending-order sort
Int2Type<KEYS_ONLY> is_keys_only) ///< Tag whether is keys-only sort
{
UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] =
reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]>(keys);
// Twiddle bits if necessary
#pragma unroll
for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
{
unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
}
// Radix sorting passes
while (true)
{
int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);
// Rank the blocked keys
int ranks[ITEMS_PER_THREAD];
RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending);
begin_bit += RADIX_BITS;
CTA_SYNC();
// Check if this is the last pass
if (begin_bit >= end_bit)
{
// Last pass exchanges keys through shared memory in striped arrangement
BlockExchangeKeys(temp_storage.exchange_keys).ScatterToStriped(keys, ranks);
// Last pass exchanges through shared memory in striped arrangement
ExchangeValues(values, ranks, is_keys_only, Int2Type<false>());
// Quit
break;
}
// Exchange keys through shared memory in blocked arrangement
BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
// Exchange values through shared memory in blocked arrangement
ExchangeValues(values, ranks, is_keys_only, Int2Type<true>());
CTA_SYNC();
}
// Untwiddle bits if necessary
#pragma unroll
for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
( run in 0.538 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )