// Excerpt from the Alien-XGBoost release on MetaCPAN:
// xgboost/cub/cub/block/block_exchange.cuh
InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items from exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
OffsetT ranks[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks
{
ScatterToStriped(input_items, output_items, ranks, Int2Type<WARP_TIME_SLICING>());
}
/**
 * \brief Exchanges data items annotated by rank into <em>striped</em> arrangement.  Any item whose rank is negative (for example, -1) is skipped and never stored.
 *
 * \par
 * - Every thread reads back all ITEMS_PER_THREAD of its striped slots.  A slot that no
 *   thread targeted with a non-negative rank holds whatever \p temp_storage contained before
 *   this call.
 * - Presumably non-negative ranks are expected to be unique and below
 *   ITEMS_PER_THREAD * BLOCK_THREADS (TODO: confirm against callers).
 * - \smemreuse
 *
 * \tparam OffsetT <b>[inferred]</b> Signed integer type for local offsets
 */
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToStripedGuarded(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OutputT     output_items[ITEMS_PER_THREAD],     ///< [out] Items from exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OffsetT     ranks[ITEMS_PER_THREAD])            ///< [in] Corresponding scatter ranks
{
    // Scatter phase: store each item with a non-negative rank at shared-memory slot `rank`,
    // remapped through SHR_ADD when INSERT_PADDING is enabled.
    #pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
        if (ranks[item] < 0)
            continue;   // guarded-out item: nothing is written

        int slot = ranks[item];
        if (INSERT_PADDING)
            slot = SHR_ADD(slot, LOG_SMEM_BANKS, slot);

        temp_storage.buff[slot] = input_items[item];
    }

    // Barrier between every thread's scatter writes and any thread's gather reads.
    CTA_SYNC();

    // Gather phase: this thread reads slots linear_tid, linear_tid + BLOCK_THREADS, ...
    #pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
        int slot = linear_tid + int(item * BLOCK_THREADS);
        if (INSERT_PADDING)
            slot = SHR_ADD(slot, LOG_SMEM_BANKS, slot);

        output_items[item] = temp_storage.buff[slot];
    }
}
/**
 * \brief Exchanges valid data items annotated by rank into <em>striped</em> arrangement.  Only items whose \p is_valid entry tests true are stored; the rest are skipped.
 *
 * \par
 * - Every thread reads back all ITEMS_PER_THREAD of its striped slots.  A slot that no valid
 *   item targeted holds whatever \p temp_storage contained before this call.
 * - \smemreuse
 *
 * \tparam OffsetT <b>[inferred]</b> Signed integer type for local offsets
 * \tparam ValidFlag <b>[inferred]</b> FlagT type denoting which items are valid
 */
template <typename OutputT, typename OffsetT, typename ValidFlag>
__device__ __forceinline__ void ScatterToStripedFlagged(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OutputT     output_items[ITEMS_PER_THREAD],     ///< [out] Items from exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OffsetT     ranks[ITEMS_PER_THREAD],            ///< [in] Corresponding scatter ranks
    ValidFlag   is_valid[ITEMS_PER_THREAD])         ///< [in] Corresponding flag denoting item validity
{
    // Scatter phase: only flagged items are written, each to its (optionally padded) rank slot.
    #pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
        if (is_valid[item])
        {
            int slot = ranks[item];
            if (INSERT_PADDING)
                slot = SHR_ADD(slot, LOG_SMEM_BANKS, slot);

            temp_storage.buff[slot] = input_items[item];
        }
    }

    // Barrier between every thread's scatter writes and any thread's gather reads.
    CTA_SYNC();

    // Gather phase: this thread reads slots linear_tid, linear_tid + BLOCK_THREADS, ...
    #pragma unroll
    for (int item = 0; item < ITEMS_PER_THREAD; ++item)
    {
        int slot = linear_tid + int(item * BLOCK_THREADS);
        if (INSERT_PADDING)
            slot = SHR_ADD(slot, LOG_SMEM_BANKS, slot);

        output_items[item] = temp_storage.buff[slot];
    }
}
//@} end member group
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
/// In-place form: forwards to the two-argument StripedToBlocked with \p items as both input and output.
__device__ __forceinline__ void StripedToBlocked(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
    this->StripedToBlocked(items, items);
}
/// In-place form: forwards to the two-argument BlockedToStriped with \p items as both input and output.
__device__ __forceinline__ void BlockedToStriped(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
    this->BlockedToStriped(items, items);
}
/// In-place form: forwards to the two-argument WarpStripedToBlocked with \p items as both input and output.
__device__ __forceinline__ void WarpStripedToBlocked(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
    this->WarpStripedToBlocked(items, items);
}
/// In-place form: forwards to the two-argument BlockedToWarpStriped with \p items as both input and output.
__device__ __forceinline__ void BlockedToWarpStriped(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
{
    this->BlockedToWarpStriped(items, items);
}
/// In-place form: forwards to the three-argument ScatterToBlocked with \p items as both input and output.
template <typename OffsetT>
__device__ __forceinline__ void ScatterToBlocked(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Corresponding scatter ranks
{
    this->ScatterToBlocked(items, items, ranks);
}
/// In-place form: forwards to the three-argument ScatterToStriped with \p items as both input and output.
template <typename OffsetT>
__device__ __forceinline__ void ScatterToStriped(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Corresponding scatter ranks
{
    this->ScatterToStriped(items, items, ranks);
}
/// In-place form: forwards to the three-argument ScatterToStripedGuarded with \p items as both input and output.
template <typename OffsetT>
__device__ __forceinline__ void ScatterToStripedGuarded(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Corresponding scatter ranks
{
    this->ScatterToStripedGuarded(items, items, ranks);
}
/**
 * In-place form of ScatterToStripedFlagged: \p items serves as both input and output.
 *
 * Forwards to the four-argument ScatterToStripedFlagged overload.  Aliasing input and output
 * is safe there because its scatter loop finishes reading input_items before the barrier,
 * and output_items is only written in the gather loop after it.
 *
 * Why: this previously called ScatterToStriped(items, items, ranks, is_valid).  The only
 * visible four-argument ScatterToStriped takes an Int2Type<WARP_TIME_SLICING> tag, not a
 * ValidFlag array, so the call could not resolve once this template was instantiated.
 */
template <typename OffsetT, typename ValidFlag>
__device__ __forceinline__ void ScatterToStripedFlagged(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange, converting between <em>striped</em> and <em>blocked</em> arrangements.
    OffsetT     ranks[ITEMS_PER_THREAD],    ///< [in] Corresponding scatter ranks
    ValidFlag   is_valid[ITEMS_PER_THREAD]) ///< [in] Corresponding flag denoting item validity
{
    ScatterToStripedFlagged(items, items, ranks, is_valid);
}
#endif // DOXYGEN_SHOULD_SKIP_THIS
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
template <
typename T,
int ITEMS_PER_THREAD,
int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS,
int PTX_ARCH = CUB_PTX_ARCH>
class WarpExchange
{
private:
/******************************************************************************
* Constants
******************************************************************************/
/// Constants
enum
{
// Whether the logical warp size and the PTX warp size coincide
IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),
WARP_ITEMS = (ITEMS_PER_THREAD * LOGICAL_WARP_THREADS) + 1,
LOG_SMEM_BANKS = CUB_LOG_SMEM_BANKS(PTX_ARCH),
SMEM_BANKS = 1 << LOG_SMEM_BANKS,
// Insert padding if the number of items per thread is a power of two and > 4 (otherwise we can typically use 128b loads)
INSERT_PADDING = (ITEMS_PER_THREAD > 4) && (PowerOfTwo<ITEMS_PER_THREAD>::VALUE),
PADDING_ITEMS = (INSERT_PADDING) ? (WARP_ITEMS >> LOG_SMEM_BANKS) : 0,
};
/******************************************************************************
* Type definitions
******************************************************************************/
/// Shared memory storage layout type
struct _TempStorage
{
T buff[WARP_ITEMS + PADDING_ITEMS];
};
public:
/// \smemstorage{WarpExchange}
struct TempStorage : Uninitialized<_TempStorage> {};
private:
// (MetaCPAN page footer: run in 0.679 second using v1.01-cache-2.11-cpan-cdf2f3d4e48)