Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/block/block_exchange.cuh view on Meta::CPAN
}
/**
 * \brief In-place convenience overload of ScatterToBlocked.
 *
 * Forwards \p items as both the first and the second array argument of the
 * three-array ScatterToBlocked overload, together with \p ranks.
 */
template <typename OffsetT>
__device__ __forceinline__ void ScatterToBlocked(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange (used as both array arguments of the forwarded call)
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Corresponding scatter ranks
{
    // Same array on both sides: the exchange happens in place.
    this->ScatterToBlocked(items, items, ranks);
}
/**
 * \brief In-place convenience overload of ScatterToStriped.
 *
 * Forwards \p items as both the first and the second array argument of the
 * three-array ScatterToStriped overload, together with \p ranks.
 */
template <typename OffsetT>
__device__ __forceinline__ void ScatterToStriped(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange (used as both array arguments of the forwarded call)
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Corresponding scatter ranks
{
    // Same array on both sides: the exchange happens in place.
    this->ScatterToStriped(items, items, ranks);
}
/**
 * \brief In-place convenience overload of ScatterToStripedGuarded.
 *
 * Forwards \p items as both the first and the second array argument of the
 * three-array ScatterToStripedGuarded overload, together with \p ranks.
 */
template <typename OffsetT>
__device__ __forceinline__ void ScatterToStripedGuarded(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange (used as both array arguments of the forwarded call)
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Corresponding scatter ranks
{
    // Same array on both sides: the exchange happens in place.
    this->ScatterToStripedGuarded(items, items, ranks);
}
/**
 * \brief In-place convenience overload of ScatterToStripedFlagged.
 *
 * Forwards \p items as both the first and the second array argument of the
 * four-array ScatterToStripedFlagged overload, together with \p ranks and
 * \p is_valid.
 *
 * NOTE(review): this previously forwarded to ScatterToStriped with four
 * arguments, unlike every sibling in-place wrapper, which forwards to the
 * overload of its own name. It assumes the out-of-place
 * ScatterToStripedFlagged(input, output, ranks, is_valid) overload is declared
 * earlier in this class -- confirm against the full header.
 */
template <typename OffsetT, typename ValidFlag>
__device__ __forceinline__ void ScatterToStripedFlagged(
    InputT      items[ITEMS_PER_THREAD],        ///< [in-out] Items to exchange (used as both array arguments of the forwarded call)
    OffsetT     ranks[ITEMS_PER_THREAD],        ///< [in] Corresponding scatter ranks
    ValidFlag   is_valid[ITEMS_PER_THREAD])     ///< [in] Corresponding flag denoting item validity
{
    // The validity flags must reach the flagged variant; the four-argument call
    // cannot bind to this three-parameter overload, so there is no recursion.
    ScatterToStripedFlagged(items, items, ranks, is_valid);
}
#endif // DOXYGEN_SHOULD_SKIP_THIS
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
template <
typename T,
int ITEMS_PER_THREAD,
int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS,
int PTX_ARCH = CUB_PTX_ARCH>
class WarpExchange
{
private:
/******************************************************************************
* Constants
******************************************************************************/
/// Compile-time constants derived from the template parameters
enum
{
    // Nonzero when the logical warp width matches the warp width reported for PTX_ARCH
    IS_ARCH_WARP    = (CUB_WARP_THREADS(PTX_ARCH) == LOGICAL_WARP_THREADS),

    // One slot per item held across the logical warp, plus one extra slot
    WARP_ITEMS      = 1 + (LOGICAL_WARP_THREADS * ITEMS_PER_THREAD),

    // Log2 of the bank count reported for PTX_ARCH, and the bank count itself
    LOG_SMEM_BANKS  = CUB_LOG_SMEM_BANKS(PTX_ARCH),
    SMEM_BANKS      = (1 << LOG_SMEM_BANKS),

    // Pad only when ITEMS_PER_THREAD is a power of two greater than 4
    // (otherwise we can typically use 128b loads)
    INSERT_PADDING  = (PowerOfTwo<ITEMS_PER_THREAD>::VALUE) && (ITEMS_PER_THREAD > 4),

    // Number of padding slots: WARP_ITEMS / SMEM_BANKS when padding is enabled, else none
    PADDING_ITEMS   = INSERT_PADDING ? (WARP_ITEMS >> LOG_SMEM_BANKS) : 0
};
/******************************************************************************
* Type definitions
******************************************************************************/
/// Layout of the shared memory storage used by the exchange
struct _TempStorage
{
    // Exchange buffer: PADDING_ITEMS padding slots on top of the WARP_ITEMS item slots
    T buff[PADDING_ITEMS + WARP_ITEMS];
};
public:
/// \smemstorage{WarpExchange}
///
/// Public storage type accepted by the WarpExchange constructor. It derives
/// from Uninitialized<_TempStorage>; the constructor reaches the underlying
/// _TempStorage through Alias().
struct TempStorage : Uninitialized<_TempStorage> {};
private:
/******************************************************************************
* Thread fields
******************************************************************************/
/// Reference to the storage obtained from TempStorage::Alias() in the constructor
_TempStorage &temp_storage;
/// LaneId() when IS_ARCH_WARP is set, otherwise LaneId() % LOGICAL_WARP_THREADS
int lane_id;
public:
/******************************************************************************
* Construction
******************************************************************************/
/**
 * \brief Constructor: binds this instance to the caller-provided storage and
 *        records the calling thread's lane index.
 */
__device__ __forceinline__ WarpExchange(
    TempStorage &temp_storage)      ///< [in] Storage whose Alias() backs the exchange buffer
:
    temp_storage(temp_storage.Alias()),
    // Use LaneId() directly when the logical warp matches the architectural one;
    // otherwise reduce it modulo the logical warp width.
    lane_id((IS_ARCH_WARP) ? LaneId() : (LaneId() % LOGICAL_WARP_THREADS))
{}
/******************************************************************************
* Interface
******************************************************************************/
/**
* \brief Exchanges valid data items annotated by rank into <em>striped</em> arrangement.
( run in 2.088 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )