Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/src/learner.cc view on Meta::CPAN
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "./common/common.h"
#include "./common/io.h"
#include "./common/random.h"
namespace xgboost {
// implementation of base learner.
/*!
 * \brief Report whether lazy checkpointing is allowed.
 *
 * The learner has no policy of its own here: the answer is whatever the
 * underlying booster (gbm_) reports.
 *
 * \return the value returned by gbm_->AllowLazyCheckPoint().
 */
bool Learner::AllowLazyCheckPoint() const {
  const bool lazy_allowed = gbm_->AllowLazyCheckPoint();
  return lazy_allowed;
}
/*!
 * \brief Dump the model as a list of strings.
 *
 * All three arguments are handed to the underlying booster (gbm_) unchanged;
 * the learner adds nothing to the dump itself.
 *
 * \param fmap feature map forwarded to the booster.
 * \param with_stats forwarded to the booster as-is.
 * \param format forwarded to the booster as-is.
 * \return the strings produced by gbm_->DumpModel.
 */
std::vector<std::string> Learner::DumpModel(const FeatureMap& fmap,
                                            bool with_stats,
                                            std::string format) const {
  std::vector<std::string> dumped = gbm_->DumpModel(fmap, with_stats, format);
  return dumped;
}
/*!
 * \brief model-level parameters of the learner: global bias, feature count,
 *  class count, two content flags and a reserved block.
 *
 * Only base_score, num_feature and num_class are declared as configurable
 * fields below; the remaining members are plain data.
 *
 * NOTE(review): the zero-filled reserved block suggests this struct is stored
 * as a raw blob somewhere - confirm before changing member order or sizes.
 */
struct LearnerModelParam : public dmlc::Parameter<LearnerModelParam> {
  /*! \brief global bias */
  bst_float base_score;
  /*! \brief number of features */
  unsigned num_feature;
  /*! \brief number of classes, if it is multi-class classification */
  int num_class;
  /*! \brief flag: model contains additional properties */
  int contain_extra_attrs;
  /*! \brief flag: model contains eval metrics */
  int contain_eval_metrics;
  /*! \brief reserved field */
  int reserved[29];

  /*! \brief constructor: clear every byte of the object, then set the bias */
  LearnerModelParam() {
    // the zero-fill must come first: it also wipes base_score
    std::memset(this, 0, sizeof(*this));
    base_score = 0.5f;
  }

  // declare the fields that can be set through the parameter interface
  DMLC_DECLARE_PARAMETER(LearnerModelParam) {
    DMLC_DECLARE_FIELD(base_score)
        .set_default(0.5f)
        .describe("Global bias of the model.");

    DMLC_DECLARE_FIELD(num_feature)
        .set_default(0)
        .describe(
            "Number of features in training data,"
            " this parameter will be automatically detected by learner.");

    DMLC_DECLARE_FIELD(num_class)
        .set_default(0)
        .set_lower_bound(0)
        .describe(
            "Number of class option for multi-class classifier. "
            " By default equals 0 and corresponds to binary classifier.");
  }
};
/*!
 * \brief training-time settings of the learner (seeding, data split mode,
 *  tree construction method, batching, threading and debug output).
 *
 * Every member is declared as a configurable field, with its default, in
 * DMLC_DECLARE_PARAMETER below.
 */
struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
  /*! \brief stored random seed */
  int seed;
  /*! \brief whether to seed the PRNG on each iteration */
  bool seed_per_iteration;
  /*! \brief data split mode: 0 = auto, 1 = col, 2 = row */
  int dsplit;
  /*!
   * \brief tree construction method:
   *  0 = auto, 1 = approx, 2 = exact, 3 = hist, 4 = gpu_exact, 5 = gpu_hist
   */
  int tree_method;
  /*! \brief internal test flag */
  std::string test_flag;
  /*! \brief maximum buffered row portion, declared with range 0.0 to 1.0 */
  float prob_buffer_row;
  /*! \brief maximum row per batch */
  size_t max_row_perbatch;
  /*!
   * \brief number of threads to use if OpenMP is enabled;
   *  0 means use the system default
   */
  int nthread;
  /*! \brief flag to print out a detailed breakdown of runtime */
  int debug_verbose;

  // declare the configurable fields and their defaults
  DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
    DMLC_DECLARE_FIELD(seed)
        .set_default(0)
        .describe("Random number seed during training.");

    DMLC_DECLARE_FIELD(seed_per_iteration)
        .set_default(false)
        .describe(
            "Seed PRNG determnisticly via iterator number, "
            "this option will be switched on automatically on distributed "
            "mode.");

    // the numeric values here are the ones tested against elsewhere
    // (e.g. tparam.dsplit == 1), so the name/value pairs must stay in sync
    DMLC_DECLARE_FIELD(dsplit)
        .set_default(0)
        .add_enum("auto", 0)
        .add_enum("col", 1)
        .add_enum("row", 2)
        .describe("Data split mode for distributed training.");

    DMLC_DECLARE_FIELD(tree_method)
        .set_default(0)
        .add_enum("auto", 0)
        .add_enum("approx", 1)
        .add_enum("exact", 2)
        .add_enum("hist", 3)
        .add_enum("gpu_exact", 4)
        .add_enum("gpu_hist", 5)
        .describe("Choice of tree construction method.");

    DMLC_DECLARE_FIELD(test_flag)
        .set_default("")
        .describe("Internal test flag");

    DMLC_DECLARE_FIELD(prob_buffer_row)
        .set_default(1.0f)
        .set_range(0.0f, 1.0f)
        .describe("Maximum buffered row portion");

    DMLC_DECLARE_FIELD(max_row_perbatch)
        .set_default(std::numeric_limits<size_t>::max())
        .describe("maximum row per batch.");

    DMLC_DECLARE_FIELD(nthread)
        .set_default(0)
        .describe("Number of threads to use.");

    DMLC_DECLARE_FIELD(debug_verbose)
        .set_lower_bound(0)
        .set_default(0)
        .describe("flag to print out detailed breakdown of runtime");
  }
};
// Register the two parameter structs declared above, one invocation each.
// NOTE(review): DMLC_REGISTER_PARAMETER is defined outside this file
// (presumably in dmlc-core's parameter header); it looks like the companion
// of DMLC_DECLARE_PARAMETER - confirm what it expands to before moving or
// duplicating these lines.
DMLC_REGISTER_PARAMETER(LearnerModelParam);
DMLC_REGISTER_PARAMETER(LearnerTrainParam);
/*!
* \brief learner that performs gradient boosting for a specific objective
* function. It does training and prediction.
*/
class LearnerImpl : public Learner {
public:
explicit LearnerImpl(const std::vector<std::shared_ptr<DMatrix> >& cache)
: cache_(cache) {
// boosted tree
name_obj_ = "reg:linear";
name_gbm_ = "gbtree";
}
void ConfigureUpdaters() {
if (tparam.tree_method == 0 || tparam.tree_method == 1 ||
tparam.tree_method == 2) {
if (cfg_.count("updater") == 0) {
if (tparam.dsplit == 1) {
cfg_["updater"] = "distcol";
} else if (tparam.dsplit == 2) {
cfg_["updater"] = "grow_histmaker,prune";
}
if (tparam.prob_buffer_row != 1.0f) {
cfg_["updater"] = "grow_histmaker,refresh,prune";
}
}
} else if (tparam.tree_method == 3) {
/* histogram-based algorithm */
LOG(CONSOLE) << "Tree method is selected to be \'hist\', which uses a "
"single updater "
( run in 1.175 second using v1.01-cache-2.11-cpan-efa8479b9fe )