Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/nccl/src/primitives.h view on Meta::CPAN
/*************************************************************************
* Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef PRIMITIVES_H_
#define PRIMITIVES_H_
#include <type_traits>
#include "copy_kernel.h" // for FuncPassA
#include "reduce_kernel.h" // for reduction funcs
/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy.
*
* In order to reduce the repetition of template arguments, the operations
* are bundled as static methods of the Primitives class.
*
* Each primitive operation copies/reduces a contiguous buffer and syncs
* an optional set of flags against a sub-step counter. The sync value is
* based on the step parameter. Sync flags must be of type WaitFlag or
* PostFlag. The primitive routines wait for all WaitFlag args to attain
* at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of
* corresponding substep by previous step) before executing the transfer.
* After each substep is transferred, all PostFlag arguments get updated to
* the value SUBSTEPS*step+substep+1.
*/
// Wraps a flag location that is only ever read through this object.
// wait(val) spins until the flag holds at least (val + shift), where
// shift is the constant offset supplied at construction.
class WaitFlag {
  volatile int * const flag_;  // location polled by wait()
  const int shift_;            // constant offset added to every waited value

 public:
  __device__ __forceinline__
  WaitFlag(volatile int * const flag, const int shift)
    : flag_(flag), shift_(shift) { }

  // Busy-waits until the flag value is >= val + shift.
  __device__ __forceinline__
  void wait(int val) {
    // The threshold does not change while spinning; only the flag is re-read.
    const int target = val + shift_;
    while (*flag_ < target) {
      /*SPIN*/
    }
  }
};
// Wraps a flag location that is only ever written through this object.
// post(val) stores (val + shift) into the flag, where shift is the
// constant offset supplied at construction.
class PostFlag {
  volatile int * const flag_;  // location written by post()
  const int shift_;            // constant offset added to every posted value

 public:
  __device__ __forceinline__
  PostFlag(volatile int* const flag, const int shift)
    : flag_(flag), shift_(shift) { }

  // Writes val + shift to the flag.
  __device__ __forceinline__
  void post(int val) {
    const int published = val + shift_;
    *flag_ = published;
  }
};
// Helper to check whether any argument is of type T,
// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...).
// Base case: with no arguments there is nothing of type T.
template<typename T>
__device__ __forceinline__
bool AnyAre() {
  return false;
}
// Recursive case: true if the first argument's type is exactly T,
// otherwise the answer is whatever the remaining arguments yield.
// Only the argument types matter; their values are never inspected.
template<typename T, typename Head, typename... Rest>
__device__ __forceinline__
bool AnyAre(Head first, Rest... tail) {
  if (std::is_same<T, Head>::value) {
    return true;
  }
  return AnyAre<T>(tail...);
}
// Wait on all WaitFlags, ignore PostFlags.
// Base case: no flags remain, so there is nothing to wait on.
__device__ __forceinline__
void WaitOnFlags(int val) {
}
// WaitFlag case: wait on the leading flag with the given value, then
// continue with the remaining flags.
template <typename... Rest>
__device__ __forceinline__
void WaitOnFlags(int val, WaitFlag flag, Rest... tail) {
  flag.wait(val);
  WaitOnFlags(val, tail...);
}
// PostFlag case: a leading PostFlag is not waited on; it is dropped and
// processing continues with the remaining flags.
template <typename... Rest>
__device__ __forceinline__
void WaitOnFlags(int val, PostFlag, Rest... tail) {
  WaitOnFlags(val, tail...);
}
// Post to all PostFlags, ignore WaitFlags.
// Base case: no flags remain, so there is nothing to post.
__device__ __forceinline__
void PostToFlags(int val) {
}
template <typename... TAIL_Ts> __device__ __forceinline__
void PostToFlags(int val, WaitFlag flag, TAIL_Ts... tail) {
( run in 1.356 second using v1.01-cache-2.11-cpan-39bf76dae61 )