Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/cub/cub/device/dispatch/dispatch_spmv_orig.cuh view on Meta::CPAN
}
}
/**
 * SpMV kernel entry point.
 *
 * Specializes AgentSpmv from the kernel's template parameters, lets that agent
 * consume the merge tile(s) described by \p d_tile_coordinates, and then
 * initializes \p tile_state for \p num_segment_fixup_tiles tiles.
 *
 * Launch bounds: at most SpmvPolicyT::BLOCK_THREADS threads per block.
 * Shared memory: one statically declared AgentSpmv TempStorage per block.
 *
 * NOTE(review): the HAS_ALPHA / HAS_BETA descriptions below are carried over
 * from the existing comments; their exact polarity should be confirmed against
 * AgentSpmv.
 */
template <
    typename    SpmvPolicyT,        ///< Parameterized SpmvPolicy tuning policy type
    typename    ScanTileStateT,     ///< Tile status interface type
    typename    ValueT,             ///< Matrix and vector value type
    typename    OffsetT,            ///< Signed integer type for sequence offsets
    typename    CoordinateT,        ///< Merge path coordinate type
    bool        HAS_ALPHA,          ///< Whether the input parameter Alpha is 1
    bool        HAS_BETA>           ///< Whether the input parameter Beta is 0
__launch_bounds__ (int(SpmvPolicyT::BLOCK_THREADS))
__global__ void DeviceSpmvKernel(
    SpmvParams<ValueT, OffsetT>     spmv_params,                ///< [in] SpMV input parameter bundle
    CoordinateT*                    d_tile_coordinates,         ///< [in] Temporary array of tile starting coordinates
    KeyValuePair<OffsetT,ValueT>*   d_tile_carry_pairs,         ///< [out] Temporary array of carry-out dot product row-ids, one per block
    int                             num_tiles,                  ///< [in] Number of merge tiles
    ScanTileStateT                  tile_state,                 ///< [in] Tile status interface for the fixup reduce-by-key kernel
    int                             num_segment_fixup_tiles)    ///< [in] Number of reduce-by-key tiles (fixup grid size)
{
    // Agent specialization that does the per-block SpMV work
    typedef AgentSpmv<SpmvPolicyT, ValueT, OffsetT, HAS_ALPHA, HAS_BETA> SpmvAgentT;

    // Block-wide scratch storage required by the agent
    __shared__ typename SpmvAgentT::TempStorage agent_storage;

    // Run the agent; the inner scope ends its lifetime before the status init below
    {
        SpmvAgentT spmv_agent(agent_storage, spmv_params);
        spmv_agent.ConsumeTile(d_tile_coordinates, d_tile_carry_pairs, num_tiles);
    }

    // Prepare the tile status descriptors used by the fixup pass
    tile_state.InitializeStatus(num_segment_fixup_tiles);
}
/**
 * Multi-block reduce-by-key sweep (segment fixup) kernel entry point.
 *
 * Specializes AgentSegmentFixup with cub::Equality as the key comparison and
 * cub::Sum as the reduction, then has that agent consume the input range.
 *
 * Launch bounds: at most AgentSegmentFixupPolicyT::BLOCK_THREADS threads per block.
 * Shared memory: one statically declared AgentSegmentFixup TempStorage per block.
 */
template <
    typename    AgentSegmentFixupPolicyT,       ///< Parameterized AgentSegmentFixupPolicy tuning policy type
    typename    PairsInputIteratorT,            ///< Random-access input iterator type for keys
    typename    AggregatesOutputIteratorT,      ///< Random-access output iterator type for values
    typename    OffsetT,                        ///< Signed integer type for global offsets
    typename    ScanTileStateT>                 ///< Tile status interface type
__launch_bounds__ (int(AgentSegmentFixupPolicyT::BLOCK_THREADS))
__global__ void DeviceSegmentFixupKernel(
    PairsInputIteratorT         d_pairs_in,         ///< [in] Array of carry-out dot product row-ids, one per spmv block
    AggregatesOutputIteratorT   d_aggregates_out,   ///< [in,out] Output value aggregates
    OffsetT                     num_items,          ///< [in] Total number of input items handed to ConsumeRange (presumably the carry-out pair count -- confirm against the caller)
    int                         num_tiles,          ///< [in] Total number of tiles for the entire problem
    ScanTileStateT              tile_state)         ///< [in] Tile status interface
{
    // Agent specialization that reduces tiles of value segments
    typedef AgentSegmentFixup<
            AgentSegmentFixupPolicyT,
            PairsInputIteratorT,
            AggregatesOutputIteratorT,
            cub::Equality,
            cub::Sum,
            OffsetT>
        FixupAgentT;

    // Block-wide scratch storage required by the agent
    __shared__ typename FixupAgentT::TempStorage fixup_storage;

    // Functors handed to the agent: key equality test and value reduction
    cub::Equality   equality_op;
    cub::Sum        reduction_op;

    // Build the agent and process this block's share of the tiles
    FixupAgentT fixup_agent(fixup_storage, d_pairs_in, d_aggregates_out, equality_op, reduction_op);
    fixup_agent.ConsumeRange(num_items, num_tiles, tile_state);
}
/******************************************************************************
* Dispatch
******************************************************************************/
/**
* Utility class for dispatching the appropriately-tuned kernels for DeviceSpmv
*/
template <
typename ValueT, ///< Matrix and vector value type
typename OffsetT> ///< Signed integer type for global offsets
struct DispatchSpmv
{
//---------------------------------------------------------------------
// Constants and Types
//---------------------------------------------------------------------
enum
{
INIT_KERNEL_THREADS = 128
};
// SpmvParams bundle type
typedef SpmvParams<ValueT, OffsetT> SpmvParamsT;
// 2D merge path coordinate type
typedef typename CubVector<OffsetT, 2>::Type CoordinateT;
// Tile status descriptor interface type
typedef ReduceByKeyScanTileState<ValueT, OffsetT> ScanTileStateT;
// Tuple type for scanning (pairs accumulated segment-value with segment-index)
typedef KeyValuePair<OffsetT, ValueT> KeyValuePairT;
//---------------------------------------------------------------------
// Tuning policies
( run in 1.140 second using v1.01-cache-2.11-cpan-13bb782fe5a )