Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/python-package/xgboost/training.py view on Meta::CPAN
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
# pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines."""
from __future__ import absolute_import
import warnings
import numpy as np
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import rabit
from . import callback
def _train_internal(params, dtrain,
num_boost_round=10, evals=(),
obj=None, feval=None,
xgb_model=None, callbacks=None):
"""internal training function"""
callbacks = [] if callbacks is None else callbacks
evals = list(evals)
if isinstance(params, dict) \
and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
bst = Booster(params, [dtrain] + [d[0] for d in evals])
nboost = 0
num_parallel_tree = 1
if xgb_model is not None:
if not isinstance(xgb_model, STRING_TYPES):
xgb_model = xgb_model.save_raw()
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
nboost = len(bst.get_dump())
_params = dict(params) if isinstance(params, list) else params
if 'num_parallel_tree' in _params:
num_parallel_tree = _params['num_parallel_tree']
nboost //= num_parallel_tree
if 'num_class' in _params:
nboost //= _params['num_class']
# Distributed code: Load the checkpoint from rabit.
version = bst.load_rabit_checkpoint()
assert(rabit.get_world_size() != 1 or version == 0)
rank = rabit.get_rank()
start_iteration = int(version / 2)
nboost += start_iteration
callbacks_before_iter = [
cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
callbacks_after_iter = [
cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
( run in 0.602 second using v1.01-cache-2.11-cpan-63c85eba8c4 )