Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/R-package/R/callbacks.R view on Meta::CPAN
# Run the one-time sanity checks at the beginning of training.
#
# Invoked lazily by `callback` on its first call. It records the number of
# boosting rounds of the current training call in the enclosing `nrounds`
# and validates every entry of the enclosing `new_params` schedule:
#   * the parameters listed in `not_allowed` below may not be rescheduled;
#   * a function-valued entry must take exactly two arguments;
#   * a numeric/character entry must have exactly `nrounds` elements.
# Any violation raises an error via stop().
#
# @param env the calling frame of the training loop; it must provide
#   `begin_iteration`, `end_iteration`, and either `bst` or `bst_folds`.
init <- function(env) {
  # Superassignment: `nrounds` lives in the enclosing closure so that
  # `callback` sees it as non-NULL (and skips re-initialisation) later on.
  nrounds <<- env$end_iteration - env$begin_iteration + 1
  if (is.null(env$bst) && is.null(env$bst_folds))
    stop("Parent frame has neither 'bst' nor 'bst_folds'")

  # Some parameters are not allowed to be changed, since changing them
  # mid-training would only wreak chaos.
  not_allowed <- pnames %in%
    c('num_class', 'num_output_group', 'size_leaf_vector', 'updater_seq')
  if (any(not_allowed)) {
    # collapse = ", " keeps several offending names readable; without it
    # stop() glues them together with no separator at all.
    stop('Parameters ', paste(pnames[not_allowed], collapse = ", "),
         " cannot be changed during boosting.")
  }

  for (k in seq_along(pnames)) {
    n <- pnames[k]
    # Look the value up by position rather than by `n`: if `pnames` holds
    # normalized names (e.g. dots replaced by underscores), a name lookup
    # would return NULL for the user's original key.
    # NOTE(review): assumes `pnames` is built element-wise from
    # names(new_params) in the enclosing function — confirm.
    p <- new_params[[k]]
    if (is.function(p)) {
      if (length(formals(p)) != 2)
        stop("Parameter '", n, "' is a function but not of two arguments")
    } else if (is.numeric(p) || is.character(p)) {
      if (length(p) != nrounds)
        stop("Length of '", n, "' has to be equal to 'nrounds'")
    } else {
      stop("Parameter '", n, "' is not a function or a vector")
    }
  }
}
# Pre-iteration hook that pushes the scheduled parameter values for the
# current boosting round into the booster(s) before that round runs.
#
# On its first call (`nrounds` still NULL) it runs `init` to record `nrounds`
# and validate the schedule. Function-valued entries of `new_params` are
# called as `p(iteration, nrounds)`. Vector-valued entries (length `nrounds`,
# enforced by `init`) are indexed by the round's 1-based position relative to
# `begin_iteration`. Training that starts at `begin_iteration > 1` therefore
# still picks element 1..nrounds instead of reading past the end (NA).
#
# @param env the training loop's frame; it must provide `iteration`,
#   `begin_iteration`, and either `bst` or `bst_folds`.
callback <- function(env = parent.frame()) {
  if (is.null(nrounds))
    init(env)
  i <- env$iteration
  # Position of this round within the current training call; equals `i`
  # whenever begin_iteration == 1.
  pos <- i - env$begin_iteration + 1
  pars <- lapply(new_params, function(p) {
    # NOTE(review): functions still receive the absolute iteration `i`
    # (unchanged contract); only vector lookups use the relative `pos`.
    if (is.function(p))
      return(p(i, nrounds))
    p[pos]
  })
  if (!is.null(env$bst)) {
    xgb.parameters(env$bst$handle) <- pars
  } else {
    for (fd in env$bst_folds) {
      xgb.parameters(fd$bst) <- pars
    }
  }
}
attr(callback, 'is_pre_iteration') <- TRUE
attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.reset.parameters'
callback
}
#' Callback closure to activate the early stopping.
#'
#' @param stopping_rounds The number of rounds with no improvement in
#' the evaluation metric in order to stop the training.
#' @param maximize whether to maximize the evaluation metric
#' @param metric_name the name of an evaluation column to use as the criterion for early
#' stopping. If not set, the last column would be used.
#' Let's say the test data in \code{watchlist} was labelled as \code{dtest},
#' and one wants to use the AUC in test data for early stopping regardless of where
#' it is in the \code{watchlist}, then one of the following would need to be set:
#' \code{metric_name='dtest-auc'} or \code{metric_name='dtest_auc'}.
#' All dash '-' characters in metric names are considered equivalent to '_'.
#' @param verbose whether to print the early stopping information.
#'
#' @details
#' This callback function determines the condition for early stopping
#' by setting the \code{stop_condition = TRUE} flag in its calling frame.
#'
#' The following additional fields are assigned to the model's R object:
#' \itemize{
#' \item \code{best_score} the evaluation score at the best iteration
#' \item \code{best_iteration} at which boosting iteration the best score has occurred (1-based index)
#' \item \code{best_ntreelimit} to use with the \code{ntreelimit} parameter in \code{predict}.
#' It differs from \code{best_iteration} in multiclass or random forest settings.
#' }
#'
#' The same values are also stored as xgb-attributes:
#' \itemize{
#' \item \code{best_iteration} is stored as a 0-based iteration index (for interoperability of binary models)
#' \item \code{best_msg} message string is also stored.
#' }
#'
#' At least one data element is required in the evaluation watchlist for early stopping to work.
#'
#' Callback function expects the following values to be set in its calling frame:
#' \code{stop_condition},
#' \code{bst_evaluation},
#' \code{rank},
#' \code{bst} (or \code{bst_folds} and \code{basket}),
#' \code{iteration},
#' \code{begin_iteration},
#' \code{end_iteration},
#' \code{num_parallel_tree}.
#'
#' @seealso
#' \code{\link{callbacks}},
#' \code{\link{xgb.attr}}
#'
#' @export
cb.early.stop <- function(stopping_rounds, maximize = FALSE,
metric_name = NULL, verbose = TRUE) {
# state variables
best_iteration <- -1
best_ntreelimit <- -1
best_score <- Inf
best_msg <- NULL
metric_idx <- 1
init <- function(env) {
if (length(env$bst_evaluation) == 0)
stop("For early stopping, watchlist must have at least one element")
eval_names <- gsub('-', '_', names(env$bst_evaluation))
if (!is.null(metric_name)) {
metric_idx <<- which(gsub('-', '_', metric_name) == eval_names)
if (length(metric_idx) == 0)
stop("'metric_name' for early stopping is not one of the following:\n",
paste(eval_names, collapse = ' '), '\n')
( run in 0.325 second using v1.01-cache-2.11-cpan-d7a12ab2c7f )