Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/agent/agent_rle.cuh view on Meta::CPAN
/**
 * Discontinuity predicate for a pair of adjacent items.
 *
 * When LAST_TILE is set and idx is not below num_remaining, the pair is
 * reported as a discontinuity (true) and equality_op is NOT invoked.
 * In every other case the result is the negation of equality_op(first, second).
 * NOTE(review): LAST_TILE / num_remaining / equality_op are members or template
 * parameters of the enclosing functor (declared outside this view).
 */
template <typename Index>
__device__ __forceinline__ bool operator()(T first, T second, Index idx)
{
    // Slot lies beyond the valid items of the last tile: force a boundary.
    // Written as !(idx < num_remaining) so only operator< is required of Index.
    const bool past_end = LAST_TILE && !(idx < num_remaining);

    // Short-circuit keeps equality_op from ever seeing out-of-range items.
    return past_end || !equality_op(first, second);
}
};
// Cache-modified input iterator wrapper type (for applying the policy's cache
// modifier) for data. If InputIteratorT is a native pointer it is wrapped in a
// CacheModifiedInputIterator parameterized by AgentRlePolicyT::LOAD_MODIFIER;
// any other iterator type is used unchanged.
typedef typename If<IsPointer<InputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>, // Wrap the native input pointer with CacheModifiedInputIterator
InputIteratorT>::Type // Directly use the supplied input iterator type
WrappedInputIteratorT;
// Parameterized BlockLoad type for data, configured entirely from the tuning
// policy (BLOCK_THREADS, ITEMS_PER_THREAD, LOAD_ALGORITHM).
typedef BlockLoad<
T,
AgentRlePolicyT::BLOCK_THREADS,
AgentRlePolicyT::ITEMS_PER_THREAD,
AgentRlePolicyT::LOAD_ALGORITHM>
BlockLoadT;
// Parameterized BlockDiscontinuity type for data
typedef BlockDiscontinuity<T, BLOCK_THREADS> BlockDiscontinuityT;
// Parameterized WarpScan type over LengthOffsetPair values
typedef WarpScan<LengthOffsetPair> WarpScanPairs;
// Reduce-length-by-run scan operator (ReduceBySegmentOp built on cub::Sum)
typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;
// Callback type for obtaining tile prefix during block scan; it combines
// LengthOffsetPair values with ReduceBySegmentOpT through ScanTileStateT.
typedef TilePrefixCallbackOp<
LengthOffsetPair,
ReduceBySegmentOpT,
ScanTileStateT>
TilePrefixCallbackOpT;
// Warp exchange types
typedef WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD> WarpExchangePairs;
// Pair-exchange storage is only materialized when STORE_WARP_TIME_SLICING is
// set; otherwise it collapses to NullType.
typedef typename If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>::Type WarpExchangePairsStorage;
typedef WarpExchange<OffsetT, ITEMS_PER_THREAD> WarpExchangeOffsets;
typedef WarpExchange<LengthT, ITEMS_PER_THREAD> WarpExchangeLengths;
// Array type holding one LengthOffsetPair per warp (WARPS entries)
typedef LengthOffsetPair WarpAggregates[WARPS];
// Shared memory type for this threadblock.
// The outer anonymous union overlays three scratch regions on the same bytes:
// the scan-related struct, the BlockLoad storage, and the exchange buffers used
// for two-phase scatter. tile_idx / tile_inclusive / tile_exclusive sit after
// the union and therefore never alias any of those regions.
// NOTE(review): an anonymous struct inside a union is a compiler extension
// (accepted by nvcc/gcc/clang), not ISO C++.
struct _TempStorage
{
union
{
// Members of this struct occupy distinct storage from each other, but the
// struct as a whole overlaps `load` and the exchange buffers below.
struct
{
typename BlockDiscontinuityT::TempStorage discontinuity; // Smem needed for discontinuity detection
typename WarpScanPairs::TempStorage warp_scan[WARPS]; // Smem needed for warp-synchronous scans
Uninitialized<LengthOffsetPair[WARPS]> warp_aggregates; // Smem needed for sharing warp-wide aggregates
typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback
};
// Smem needed for input loading
typename BlockLoadT::TempStorage load;
// Smem needed for two-phase scatter. The three exchange arrays share storage
// with each other; `align` presumably pads/aligns the union to an
// unsigned long long boundary — TODO confirm.
// exchange_pairs is NullType-typed unless STORE_WARP_TIME_SLICING is set.
union
{
unsigned long long align;
WarpExchangePairsStorage exchange_pairs[ACTIVE_EXCHANGE_WARPS];
typename WarpExchangeOffsets::TempStorage exchange_offsets[ACTIVE_EXCHANGE_WARPS];
typename WarpExchangeLengths::TempStorage exchange_lengths[ACTIVE_EXCHANGE_WARPS];
};
};
OffsetT tile_idx; // Shared tile index
LengthOffsetPair tile_inclusive; // Inclusive tile prefix
LengthOffsetPair tile_exclusive; // Exclusive tile prefix
};
// Alias wrapper allowing storage to be unioned: the public TempStorage wraps
// _TempStorage in Uninitialized<>, and the agent works on it via Alias().
struct TempStorage : Uninitialized<_TempStorage> {};
//---------------------------------------------------------------------
// Per-thread fields
//---------------------------------------------------------------------
_TempStorage& temp_storage; ///< Reference to temp_storage (non-owning; bound in the constructor via TempStorage::Alias())
WrappedInputIteratorT d_in; ///< Input iterator over the data items (cache-modified wrapper when the caller passed a native pointer)
OffsetsOutputIteratorT d_offsets_out; ///< Output run offsets
LengthsOutputIteratorT d_lengths_out; ///< Output run lengths
EqualityOpT equality_op; ///< T equality operator
ReduceBySegmentOpT scan_op; ///< Reduce-length-by-flag scan operator
OffsetT num_items; ///< Total number of input items
//---------------------------------------------------------------------
// Constructor
//---------------------------------------------------------------------
// Constructor
__device__ __forceinline__
AgentRle(
TempStorage &temp_storage, ///< [in] Reference to temp_storage
InputIteratorT d_in, ///< [in] Pointer to input sequence of data items
OffsetsOutputIteratorT d_offsets_out, ///< [out] Pointer to output sequence of run offsets
LengthsOutputIteratorT d_lengths_out, ///< [out] Pointer to output sequence of run lengths
EqualityOpT equality_op, ///< [in] T equality operator
OffsetT num_items) ///< [in] Total number of input items
:
temp_storage(temp_storage.Alias()),
d_in(d_in),
d_offsets_out(d_offsets_out),
d_lengths_out(d_lengths_out),
equality_op(equality_op),
( run in 0.698 second using v1.01-cache-2.11-cpan-39bf76dae61 )