Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/rabit/src/allreduce_base.h view on Meta::CPAN
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.h
* \brief Basic implementation of AllReduce
* using TCP non-block socket and tree-shape reduction.
*
* This implementation provides basic utility of AllReduce and Broadcast
* without considering node failure
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <vector>
#include <string>
#include <algorithm>
#include "../include/rabit/internal/utils.h"
#include "../include/rabit/internal/engine.h"
#include "./socket.h"
namespace MPI {
/*!
 * \brief minimal MPI data type descriptor, provided so that code written
 *  against the existing MPI interface stays compatible
 */
class Datatype {
 public:
  /*!
   * \brief construct a descriptor that records the given size
   * \param type_size size value to store (presumably the size of one
   *  element in bytes -- confirm against callers)
   */
  explicit Datatype(size_t type_size) {
    this->type_size = type_size;
  }
  /*! \brief the size handed to the constructor */
  size_t type_size;
};
}  // namespace MPI
namespace rabit {
namespace engine {
/*! \brief implementation of basic Allreduce engine */
class AllreduceBase : public IEngine {
public:
// magic number to verify server
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
// NOTE(review): the comment above does not describe the constructor declared
// next, and no such constant is visible here; it looks like a leftover -- confirm
/*! \brief default constructor (declaration only) */
AllreduceBase(void);
/*! \brief virtual destructor; its body is empty */
virtual ~AllreduceBase(void) {}
// initialize the manager
virtual void Init(int argc, char* argv[]);
// shutdown the engine
virtual void Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
/*! \brief get rank */
virtual int GetRank(void) const {
return rank;
}
/*!
 * \brief get the world size (total number of nodes);
 *  a stored world_size of -1 is reported as 1
 */
virtual int GetWorldSize(void) const {
  return world_size == -1 ? 1 : world_size;
}
/*!
 * \brief whether the engine runs in distributed mode
 * \return true when tracker_uri differs from the placeholder "NULL"
 */
virtual bool IsDistributed(void) const {
  const bool has_tracker = (tracker_uri != "NULL");
  return has_tracker;
}
/*! \brief get rank */
virtual std::string GetHost(void) const {
return host_uri;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the number of bytes of one element of the type
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_fun Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument passed into the lazy preprocessing function
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
( run in 0.583 second using v1.01-cache-2.11-cpan-13bb782fe5a )