Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/R-package/R/xgb.importance.R  view on Meta::CPAN

#'                booster = "gblinear", eta = 0.2, nthread = 1, nrounds = 15,
#'                objective = "multi:softprob", num_class = nclass)
#' xgb.importance(model = mbst)
#'
#' @export
xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL,
                           data = NULL, label = NULL, target = NULL) {

  # 'data', 'label' and 'target' are not used for anything below; supplying
  # any of them only triggers the deprecation warning.
  uses_deprecated_arg <- !is.null(data) || !is.null(label) || !is.null(target)
  if (uses_deprecated_arg) {
    warning("xgb.importance: parameters 'data', 'label' and 'target' are deprecated")
  }

  if (!inherits(model, "xgb.Booster")) {
    stop("model: must be an object of class xgb.Booster")
  }

  # When the caller gives no names, fall back to the ones stored on the model.
  if (is.null(feature_names) && !is.null(model$feature_names)) {
    feature_names <- model$feature_names
  }

  if (!is.null(feature_names) && !is.character(feature_names)) {
    stop("feature_names: Has to be a character vector")
  }

  dump_lines <- xgb.dump(model = model, with_stats = TRUE)

  # A dump whose second line is "bias:" is treated as a linear model;
  # anything else is handled as a tree model further down.
  is_linear <- dump_lines[2] == "bias:"
  if (is_linear) {
    # The coefficients are every line following the "weight:" marker.
    first_weight <- which(dump_lines == "weight:") + 1
    weights <- as.numeric(dump_lines[first_weight:length(dump_lines)])

    num_class <- NVL(model$params$num_class, 1)
    if (is.null(feature_names)) {
      # No names available anywhere: label the features 0, 1, 2, ...
      feature_names <- seq(to = length(weights) / num_class) - 1
    }
    if (length(feature_names) * num_class != length(weights)) {
      stop("feature_names length does not match the number of features used in the model")
    }

    if (num_class == 1) {
      # Single set of coefficients, largest absolute weight first.
      linear_dt <- data.table(Feature = feature_names, Weight = weights)
      return(linear_dt[order(-abs(Weight))])
    }

    # Several classes: each feature name is repeated num_class times and the
    # class ids 0..num_class-1 are recycled along the weight vector. Rows are
    # ordered by class, then by decreasing absolute weight.
    linear_dt <- data.table(Feature = rep(feature_names, each = num_class),
                            Weight = weights,
                            Class = seq_len(num_class) - 1)
    return(linear_dt[order(Class, -abs(Weight))])
  }

  # Tree model: parse the dump into one row per node.
  nodes <- xgb.model.dt.tree(feature_names = feature_names,
                             text = dump_lines,
                             trees = trees)

  # Aggregate the non-leaf rows per feature: summed Quality, summed Cover and
  # the number of rows in which the feature appears.
  importance <- nodes[Feature != "Leaf",
                      .(Gain = sum(Quality),
                        Cover = sum(Cover),
                        Frequency = .N),
                      by = Feature]

  # Rescale every measure by its column total (in place).
  importance[, `:=`(Gain = Gain / sum(Gain),
                    Cover = Cover / sum(Cover),
                    Frequency = Frequency / sum(Frequency))]

  # Highest gain first.
  importance[order(Gain, decreasing = TRUE)]
}

# Avoid "no visible binding for global variable" notes during CRAN check.
# None of these names is ever declared as a variable: they are data.table
# symbols and column names that are only used inside data.table expressions
# in xgb.importance() (Weight and Class in the linear-model branch, Quality
# and the others in the tree-model branch).
globalVariables(c(".", ".N", "Gain", "Cover", "Frequency", "Feature",
                  "Weight", "Class", "Quality"))



( run in 2.480 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )