Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/nccl/src/primitives.h  view on Meta::CPAN

/*************************************************************************
 * Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#ifndef PRIMITIVES_H_
#define PRIMITIVES_H_

#include <type_traits>
#include "copy_kernel.h" // for FuncPassA
#include "reduce_kernel.h" // for reduction funcs


/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy.
 *
 * To reduce the repetition of template arguments, the operations
 * are bundled as static methods of the Primitives class.
 *
 * Each primitive operation copies/reduces a contiguous buffer and syncs
 * an optional set of flags against a sub-step counter. The sync value is
 * based on the step parameter. Sync flags must be of type WaitFlag or
 * PostFlag. The primitive routines wait for all WaitFlag args to attain
 * at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of the
 * corresponding substep by the previous step) before executing the transfer.
 * After each substep is transferred, all PostFlag arguments get updated to
 * the value SUBSTEPS*step+substep+1. (Each flag's own `shift` is added to
 * these values before comparing or storing.)
 */


// Handle that the primitives wait on: a pointer to a volatile int counter
// plus a constant offset applied to every value passed to wait().
class WaitFlag {
  public:
  // ptr:    counter to poll; read through a volatile pointer, so every check
  //         performs a fresh load.
  // offset: constant added to every value passed to wait().
  __device__ __forceinline__
  WaitFlag(volatile int * const ptr, const int offset) : flag(ptr), shift(offset) { }

  // Busy-wait until *flag >= val + shift. There is no timeout: if the counter
  // never reaches the target, the calling thread spins forever.
  __device__ __forceinline__
  void wait(int val) {
    const int target = val + shift;
    for (;;) {
      if (*flag >= target) break;  // volatile load re-issued every iteration
    }
  }

  private:
  volatile int * const flag;
  const int shift;
};


// Handle that the primitives post to: a pointer to a volatile int counter
// plus a constant offset applied to every value passed to post().
class PostFlag {
  public:
  // ptr:    counter to write; written through a volatile pointer, so every
  //         post() performs a real store.
  // offset: constant added to every value passed to post().
  __device__ __forceinline__
  PostFlag(volatile int* const ptr, const int offset) : flag(ptr), shift(offset) { }

  // Store val + shift into *flag. This is a plain volatile store and no fence
  // is issued here. NOTE(review): if earlier data writes must be visible
  // before the flag changes, the caller presumably fences first — confirm at
  // call sites.
  __device__ __forceinline__
  void post(int val) {
    const int published = val + shift;
    *flag = published;
  }

  private:
  volatile int * const flag;
  const int shift;
};


// Helper to check if any argument is of type T.
// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...)
template<typename T> __device__ __forceinline__
bool AnyAre() { return false; }

template<typename T, typename FIRST_T, typename... TAIL_Ts>
__device__ __forceinline__
bool AnyAre(FIRST_T first, TAIL_Ts... tail) {
  return std::is_same<T, FIRST_T>::value || AnyAre<T>(tail...);
}


// WaitOnFlags(val, flags...): walk the flag list in order and call
// wait(val) on every WaitFlag; PostFlag arguments are skipped. The dispatch
// is by overload on the type of the first remaining argument.

// Forward declarations, so each recursive call sees both overloads through
// ordinary lookup regardless of definition order. Without them, the
// PostFlag overload is reachable from the WaitFlag one only through ADL at
// instantiation.
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(int val, WaitFlag waiter, TAIL_Ts... rest);
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(int val, PostFlag, TAIL_Ts... rest);

// Base case: no flags left to process.
__device__ __forceinline__
void WaitOnFlags(int /*val*/) { }

// Head is a WaitFlag: spin on it, then process the remaining flags.
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(int val, WaitFlag waiter, TAIL_Ts... rest) {
  waiter.wait(val);
  WaitOnFlags(val, rest...);
}

// Head is a PostFlag: nothing to wait on here; process the remaining flags.
template <typename... TAIL_Ts> __device__ __forceinline__
void WaitOnFlags(int val, PostFlag, TAIL_Ts... rest) {
  WaitOnFlags(val, rest...);
}


// PostToFlags(val, flags...): post to all PostFlags and ignore WaitFlags.
// This overload is the base case of the recursion: with no flags left,
// there is nothing to do.
__device__ __forceinline__
void PostToFlags(int /*val*/) { }

template <typename... TAIL_Ts> __device__ __forceinline__
void PostToFlags(int val, WaitFlag flag, TAIL_Ts... tail) {



( run in 1.356 second using v1.01-cache-2.11-cpan-39bf76dae61 )