Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/R-package/R/callbacks.R  view on Meta::CPAN

  # run some checks in the beginning
  init <- function(env) {
    # One-time validation of the parameter schedule.
    #
    # env: the frame of the training loop; the fields 'begin_iteration',
    #      'end_iteration', 'bst' and 'bst_folds' are read from it.
    #
    # Side effect: stores the total number of boosting rounds in the
    # enclosing scope's 'nrounds'. Signals an error when the frame holds
    # no model, when a protected parameter is scheduled for change, or when
    # a scheduled value is neither a two-argument function nor a
    # numeric/character vector of length 'nrounds'.
    nrounds <<- env$end_iteration - env$begin_iteration + 1

    if (is.null(env$bst) && is.null(env$bst_folds))
      stop("Parent frame has neither 'bst' nor 'bst_folds'")

    # Some parameters are not allowed to be changed,
    # since changing them would simply wreak havoc
    not_allowed <- pnames %in%
      c('num_class', 'num_output_group', 'size_leaf_vector', 'updater_seq')
    # stop() glues its arguments together without any separator, so the
    # offending names have to be collapsed explicitly to stay readable.
    if (any(not_allowed))
      stop('Parameters ', paste(pnames[not_allowed], collapse = ', '),
           " cannot be changed during boosting.")

    for (n in pnames) {
      p <- new_params[[n]]
      if (is.function(p)) {
        # a function schedule is later called as p(iteration, nrounds)
        if (length(formals(p)) != 2)
          stop("Parameter '", n, "' is a function but not of two arguments")
      } else if (is.numeric(p) || is.character(p)) {
        # a vector schedule must supply exactly one value per round
        if (length(p) != nrounds)
          stop("Length of '", n, "' has to be equal to 'nrounds'")
      } else {
        stop("Parameter '", n, "' is not a function or a vector")
      }
    }
  }
  
  callback <- function(env = parent.frame()) {
    # Pushes the parameter values scheduled for the current iteration
    # into the booster found in 'env' (or into every fold's booster).
    #
    # env: the calling frame; 'iteration', 'bst' and 'bst_folds' are read
    #      from it.

    # 'nrounds' stays NULL until init() has run, so this validates lazily
    # on the first invocation only.
    if (is.null(nrounds))
      init(env)

    iter <- env$iteration

    # A schedule is either a function of (iteration, nrounds) or a vector
    # indexed by the iteration number.
    resolve <- function(spec) {
      if (is.function(spec)) spec(iter, nrounds) else spec[iter]
    }
    current <- lapply(new_params, resolve)

    if (is.null(env$bst)) {
      # no single model in the frame: apply to each fold's booster
      for (fold in env$bst_folds)
        xgb.parameters(fold$bst) <- current
    } else {
      xgb.parameters(env$bst$handle) <- current
    }
  }
  attr(callback, 'is_pre_iteration') <- TRUE
  attr(callback, 'call') <- match.call()
  attr(callback, 'name') <- 'cb.reset.parameters'
  callback
}


#' Callback closure to activate the early stopping.
#' 
#' @param stopping_rounds The number of rounds with no improvement in 
#'        the evaluation metric in order to stop the training.
#' @param maximize whether to maximize the evaluation metric
#' @param metric_name the name of an evaluation column to use as a criterion for early
#'        stopping. If not set, the last column would be used.
#'        Let's say the test data in \code{watchlist} was labelled as \code{dtest}, 
#'        and one wants to use the AUC in test data for early stopping regardless of where 
#'        it is in the \code{watchlist}, then one of the following would need to be set:
#'        \code{metric_name='dtest-auc'} or \code{metric_name='dtest_auc'}.
#'        All dash '-' characters in metric names are considered equivalent to '_'.
#' @param verbose whether to print the early stopping information.
#' 
#' @details
#' This callback function determines the condition for early stopping 
#' by setting the \code{stop_condition = TRUE} flag in its calling frame.
#' 
#' The following additional fields are assigned to the model's R object:
#' \itemize{
#' \item \code{best_score} the evaluation score at the best iteration
#' \item \code{best_iteration} at which boosting iteration the best score has occurred (1-based index)
#' \item \code{best_ntreelimit} to use with the \code{ntreelimit} parameter in \code{predict}.
#'      It differs from \code{best_iteration} in multiclass or random forest settings.
#' }
#' 
#' The same values are also stored as xgb-attributes:
#' \itemize{
#' \item \code{best_iteration} is stored as a 0-based iteration index (for interoperability of binary models)
#' \item \code{best_msg} message string is also stored.
#' }
#' 
#' At least one data element is required in the evaluation watchlist for early stopping to work.
#'
#' Callback function expects the following values to be set in its calling frame:
#' \code{stop_condition},
#' \code{bst_evaluation},
#' \code{rank},
#' \code{bst} (or \code{bst_folds} and \code{basket}),
#' \code{iteration},
#' \code{begin_iteration},
#' \code{end_iteration},
#' \code{num_parallel_tree}.
#' 
#' @seealso
#' \code{\link{callbacks}},
#' \code{\link{xgb.attr}}
#' 
#' @export
cb.early.stop <- function(stopping_rounds, maximize = FALSE, 
                          metric_name = NULL, verbose = TRUE) {
  # state variables
  best_iteration <- -1
  best_ntreelimit <- -1
  best_score <- Inf
  best_msg <- NULL
  metric_idx <- 1
  
  init <- function(env) {
    if (length(env$bst_evaluation) == 0)
      stop("For early stopping, watchlist must have at least one element")
    
    eval_names <- gsub('-', '_', names(env$bst_evaluation))
    if (!is.null(metric_name)) {
      metric_idx <<- which(gsub('-', '_', metric_name) == eval_names)
      if (length(metric_idx) == 0)
        stop("'metric_name' for early stopping is not one of the following:\n",
             paste(eval_names, collapse = ' '), '\n')



( run in 0.325 second using v1.01-cache-2.11-cpan-d7a12ab2c7f )