Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/agent/agent_select_if.cuh view on Meta::CPAN
#include "../block/block_discontinuity.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../util_namespace.cuh"
/// Optional outer namespace(s)
CUB_NS_PREFIX
/// CUB namespace
namespace cub {
/******************************************************************************
* Tuning policy types
******************************************************************************/
/**
 * Compile-time tuning policy for AgentSelectIf.
 *
 * Each template argument is republished, unchanged, as a nested constant of
 * the same name (minus the trailing underscore) so that code holding only the
 * policy type can read the tuning values back.
 */
template <
    int                 BLOCK_THREADS_,     ///< Threads per thread block
    int                 ITEMS_PER_THREAD_,  ///< Items per thread (per tile of input)
    BlockLoadAlgorithm  LOAD_ALGORITHM_,    ///< The BlockLoad algorithm to use
    CacheLoadModifier   LOAD_MODIFIER_,     ///< Cache load modifier for reading input elements
    BlockScanAlgorithm  SCAN_ALGORITHM_>    ///< The BlockScan algorithm to use
struct AgentSelectIfPolicy
{
    /// The BlockLoad algorithm to use
    static const BlockLoadAlgorithm LOAD_ALGORITHM = LOAD_ALGORITHM_;

    /// Cache load modifier for reading input elements
    static const CacheLoadModifier LOAD_MODIFIER = LOAD_MODIFIER_;

    /// The BlockScan algorithm to use
    static const BlockScanAlgorithm SCAN_ALGORITHM = SCAN_ALGORITHM_;

    // Integral tuning values are exposed as enumerators.
    enum
    {
        /// Threads per thread block
        BLOCK_THREADS = BLOCK_THREADS_,

        /// Items per thread (per tile of input)
        ITEMS_PER_THREAD = ITEMS_PER_THREAD_
    };
};
/******************************************************************************
* Thread block abstractions
******************************************************************************/
/**
* \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection
*
* Performs functor-based selection if SelectOpT functor type != NullType
* Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType
* Otherwise performs discontinuity selection (keep unique)
*/
template <
typename AgentSelectIfPolicyT, ///< Parameterized AgentSelectIfPolicy tuning policy type
typename InputIteratorT, ///< Random-access input iterator type for selection items
typename FlagsInputIteratorT, ///< Random-access input iterator type for selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection)
typename SelectedOutputIteratorT, ///< Random-access output iterator type for selected items
typename SelectOpT, ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection)
typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selection flags are to be used for selection)
typename OffsetT, ///< Signed integer type for global offsets
bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output
struct AgentSelectIf
{
//---------------------------------------------------------------------
// Types and constants
//
// Everything in this section is derived at compile time from the
// template parameters of AgentSelectIf; nothing here has runtime state.
//---------------------------------------------------------------------
// The input value type (value_type of InputIteratorT)
typedef typename std::iterator_traits<InputIteratorT>::value_type InputT;
// The output value type. Falls back to the input value type when the output
// iterator reports `void` as its value_type.
typedef typename If<(Equals<typename std::iterator_traits<SelectedOutputIteratorT>::value_type, void>::VALUE), // OutputT = (output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... yes: the input iterator's value type,
typename std::iterator_traits<SelectedOutputIteratorT>::value_type>::Type OutputT; // ... no: the output iterator's own value type
// The flag value type (value_type of FlagsInputIteratorT); compared against
// NullType below to decide whether flag-based selection is in use
typedef typename std::iterator_traits<FlagsInputIteratorT>::value_type FlagT;
// Tile status descriptor interface type, instantiated on the global offset type
typedef ScanTileState<OffsetT> ScanTileStateT;
// Constants
enum
{
// Selection-method tags. They carry no initializers, so they take the
// values 0, 1 and 2 in declaration order.
USE_SELECT_OP, // select with the SelectOpT functor
USE_SELECT_FLAGS, // select with the flags read through FlagsInputIteratorT
USE_DISCONTINUITY, // discontinuity selection (keep unique)
// Tuning values forwarded from the policy type
BLOCK_THREADS = AgentSelectIfPolicyT::BLOCK_THREADS,
ITEMS_PER_THREAD = AgentSelectIfPolicyT::ITEMS_PER_THREAD,
// Number of items in one full tile
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
// Two-phase scattering is only enabled when each thread owns more than one item
TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1),
// Selection method, by priority: a non-NullType SelectOpT wins, then a
// non-NullType FlagT, otherwise discontinuity selection.
SELECT_METHOD = (!Equals<SelectOpT, NullType>::VALUE) ?
USE_SELECT_OP :
(!Equals<FlagT, NullType>::VALUE) ?
USE_SELECT_FLAGS :
USE_DISCONTINUITY
};
// Cache-modified input iterator wrapper type (for applying the policy's cache load modifier) for items
typedef typename If<IsPointer<InputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>, // Native pointer: wrap it in a CacheModifiedInputIterator
InputIteratorT>::Type // Otherwise: use the supplied input iterator type directly
WrappedInputIteratorT;
// Cache-modified input iterator wrapper type (for applying the policy's cache load modifier) for flags
typedef typename If<IsPointer<FlagsInputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>, // Native pointer: wrap it in a CacheModifiedInputIterator
FlagsInputIteratorT>::Type // Otherwise: use the supplied flags iterator type directly
WrappedFlagsInputIteratorT;
// Parameterized BlockLoad type for input data.
// NOTE(review): the element type is OutputT rather than InputT, so loaded
// items are held as OutputT — presumably this relies on InputT being
// convertible to OutputT; confirm against the load call sites.
typedef BlockLoad<
OutputT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentSelectIfPolicyT::LOAD_ALGORITHM>
BlockLoadT;
xgboost/cub/cub/agent/agent_select_if.cuh view on Meta::CPAN
OffsetT num_selections_prefix, ///< Total number of selections prior to this tile
OffsetT /*num_rejected_prefix*/, ///< Total number of rejections prior to this tile
Int2Type<false> /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition
{
CTA_SYNC();
// Compact and scatter items
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix;
if (selection_flags[ITEM])
{
temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
}
}
CTA_SYNC();
for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS)
{
d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item];
}
}
/**
 * Two-phase scatter for the partitioning case (overload picked by the
 * Int2Type<true> tag, i.e. rejected items are kept).
 *
 * Phase 1 stages the tile in the raw exchange storage with every rejected
 * item packed ahead of every selected item. Phase 2 reads the staged tile
 * back in block-wide stripes and writes it to d_selected_out: selected items
 * go forward from num_selections_prefix, rejected items go backward from
 * (num_items - num_rejected_prefix - 1).
 *
 * IS_LAST_TILE enables the bounds check against num_tile_items in phase 2.
 * IS_FIRST_TILE is not referenced by this overload.
 */
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
    OutputT         (&items)[ITEMS_PER_THREAD],
    OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
    OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
    int             num_tile_items,             ///< Number of valid items in this tile
    int             num_tile_selections,        ///< Number of selections in this tile
    OffsetT         num_selections_prefix,      ///< Total number of selections prior to this tile
    OffsetT         num_rejected_prefix,        ///< Total number of rejections prior to this tile
    Int2Type<true>  /*is_keep_rejects*/)        ///< Marker type indicating whether to keep rejected items in the second partition
{
    // Barrier before the exchange storage is written
    CTA_SYNC();

    // Rejected items occupy staging slots [0, rejected_in_tile)
    const int rejected_in_tile = num_tile_items - num_tile_selections;

    // Phase 1: stage items (rejections first, selections after them)
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        // Position of this item within the tile
        const int tile_pos = (threadIdx.x * ITEMS_PER_THREAD) + k;

        // Tile-relative selection index of this item
        const int selected_rank = selection_indices[k] - num_selections_prefix;

        int slot;
        if (selection_flags[k])
        {
            // Selected: placed after the block of rejected items
            slot = rejected_in_tile + selected_rank;
        }
        else
        {
            // Rejected: tile position minus the tile-relative selection index
            slot = tile_pos - selected_rank;
        }

        temp_storage.raw_exchange.Alias()[slot] = items[k];
    }

    // Every staged item must be visible before any thread reads it back
    CTA_SYNC();

    // Phase 2: read staged items back and write them to global output
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        const int  staged_pos  = (k * BLOCK_THREADS) + threadIdx.x;
        const bool is_rejected = (staged_pos < rejected_in_tile);

        // Rejected items run backward from the end; selected items run forward
        OffsetT destination = is_rejected ?
            num_items - num_rejected_prefix - staged_pos - 1 :
            num_selections_prefix + (staged_pos - rejected_in_tile);

        OutputT staged_item = temp_storage.raw_exchange.Alias()[staged_pos];

        // Only the last tile can be partial, so only it needs the bounds check
        const bool in_bounds = !IS_LAST_TILE || (staged_pos < num_tile_items);
        if (in_bounds)
        {
            d_selected_out[destination] = staged_item;
        }
    }
}
/**
 * Scatter flagged items, dispatching to the two-phase or the direct strategy.
 *
 * The two-phase path is taken when rejected items are kept (KEEP_REJECTS), or
 * when two-phase scattering is enabled and this tile holds more selections
 * than there are threads in the block. Otherwise the direct path is used; it
 * receives only the items, flags, indices and num_selections.
 */
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void Scatter(
    OutputT         (&items)[ITEMS_PER_THREAD],
    OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
    OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
    int             num_tile_items,             ///< Number of valid items in this tile
    int             num_tile_selections,        ///< Number of selections in this tile
    OffsetT         num_selections_prefix,      ///< Total number of selections prior to this tile
    OffsetT         num_rejected_prefix,        ///< Total number of rejections prior to this tile
    OffsetT         num_selections)             ///< Total number of selections including this tile
{
    // Two-phase if (a) keeping both partitions, or (b) two-phase is enabled and
    // the tile averages more than one selected item per thread
    const bool go_two_phase =
        KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS));

    if (!go_two_phase)
    {
        ScatterDirect<IS_LAST_TILE, IS_FIRST_TILE>(
            items,
            selection_flags,
            selection_indices,
            num_selections);
        return;
    }

    // The tag type selects the ScatterTwoPhase overload at compile time
    ScatterTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>(
        items,
        selection_flags,
        selection_indices,
        num_tile_items,
        num_tile_selections,
        num_selections_prefix,
        num_rejected_prefix,
        Int2Type<KEEP_REJECTS>());
}
//---------------------------------------------------------------------
// Cooperatively scan a device-wide sequence of tiles with other CTAs
//---------------------------------------------------------------------
( run in 0.536 second using v1.01-cache-2.11-cpan-39bf76dae61 )