Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/rabit/src/allreduce_base.h  view on Meta::CPAN

/*!
 *  Copyright (c) 2014 by Contributors
 * \file allreduce_base.h
 * \brief Basic implementation of AllReduce
 *   using non-blocking TCP sockets and tree-shaped reduction.
 *
 *   This implementation provides basic utility of AllReduce and Broadcast
 *   without considering node failure
 *
 * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
 */
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_

#include <vector>
#include <string>
#include <algorithm>
#include "../include/rabit/internal/utils.h"
#include "../include/rabit/internal/engine.h"
#include "./socket.h"

namespace MPI {
/*!
 * \brief minimal MPI-style data type descriptor, kept so that code written
 *        against the existing MPI interface keeps compiling.
 *        It only records how many bytes one element of the type occupies.
 */
class Datatype {
 public:
  /*! \brief size in bytes of a single element of this type */
  size_t type_size;
  /*!
   * \brief construct a descriptor for elements of the given size
   * \param nbytes size in bytes of one element
   */
  explicit Datatype(size_t nbytes) : type_size(nbytes) {}
};
}  // namespace MPI
namespace rabit {
namespace engine {
/*! \brief implementation of basic Allreduce engine */
class AllreduceBase : public IEngine {
 public:
  // magic number to verify server
  static const int kMagic = 0xff99;
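  // default constructor (NOTE(review): a previous comment here described a one-byte
  // out-of-band message constant used to signal errors, but no such constant is declared here -- likely stale)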
  AllreduceBase(void);
  virtual ~AllreduceBase() {
    // intentionally empty; declared virtual so derived engines are destroyed
    // correctly when deleted through an AllreduceBase pointer
  }
  // initialize the manager
  virtual void Init(int argc, char* argv[]);
  // shutdown the engine
  virtual void Shutdown(void);
  /*!
   * \brief set parameters to the engine
   * \param name parameter name
   * \param val parameter value
   */
  virtual void SetParam(const char *name, const char *val);
  /*!
   * \brief print the msg in the tracker,
   *    this function can be used to communicate the information of the progress to
   *    the user who monitors the tracker
   * \param msg message to be printed in the tracker
   */
  virtual void TrackerPrint(const std::string &msg);
  /*! \brief get rank */
  virtual int GetRank(void) const {
    return rank;
  }
  /*! \brief get world size; a world_size of -1 is reported as 1 */
  virtual int GetWorldSize(void) const {
    // a world_size of -1 is reported as 1 (presumably "not yet configured",
    // i.e. a single worker -- Allreduce treats -1 the same as 1)
    return world_size == -1 ? 1 : world_size;
  }
  /*! \brief whether the engine runs in distributed mode (tracker URI is not "NULL") */
  virtual bool IsDistributed(void) const {
    // the sentinel string "NULL" means no tracker is set; any other tracker
    // URI means the engine is running in distributed mode
    const bool has_tracker = !(tracker_uri == "NULL");
    return has_tracker;
  }
  /*! \brief get the host URI of this engine */
  virtual std::string GetHost(void) const {
    return host_uri;
  }
  /*!
   * \brief perform in-place allreduce, on sendrecvbuf
   *        this function is NOT thread-safe
   * \param sendrecvbuf_ buffer used both for sending and for receiving data
   * \param type_nbytes number of bytes occupied by a single element of the data type
   * \param count number of elements to be reduced
   * \param reducer reduce function
   * \param prepare_fun lazy preprocessing function; prepare_fun(prepare_arg)
   *                    is called by this function before performing Allreduce, to initialize the data in sendrecvbuf_.
   *                    If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
   * \param prepare_arg argument passed into the lazy preprocessing function
   */
  virtual void Allreduce(void *sendrecvbuf_,
                         size_t type_nbytes,
                         size_t count,
                         ReduceFunction reducer,
                         PreprocFunction prepare_fun = NULL,
                         void *prepare_arg = NULL) {
    if (prepare_fun != NULL) prepare_fun(prepare_arg);
    if (world_size == 1 || world_size == -1) return;
    utils::Assert(TryAllreduce(sendrecvbuf_,
                               type_nbytes, count, reducer) == kSuccess,



( run in 0.583 second using v1.01-cache-2.11-cpan-13bb782fe5a )