Alien-XGBoost

view release on metacpan or search on metacpan

xgboost/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java  view on Meta::CPAN

/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.InputStream;
import java.util.*;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * trainer for xgboost
 *
 * @author hzx
 */
public class XGBoost {
  /** Commons-logging logger for this class, obtained via {@link LogFactory#getLog(Class)}. */
  private static final Log logger = LogFactory.getLog(XGBoost.class);

  /**
   * Loads a previously saved booster from a file path.
   *
   * <p>This is a thin facade over {@code Booster.loadModel(String)}.
   *
   * @param modelPath path of a model file produced by {@code Booster.saveModel}
   * @return the booster restored from {@code modelPath}
   * @throws XGBoostError propagated from {@code Booster.loadModel} (a native-side error)
   */
  public static Booster loadModel(String modelPath) throws XGBoostError {
    final Booster restored = Booster.loadModel(modelPath);
    return restored;
  }

  /**
   * Loads a booster from an input stream that holds exactly one serialized XGBoost model.
   *
   * <p>This can be used to read booster models saved by other XGBoost language bindings.
   * It is a thin facade over {@code Booster.loadModel(InputStream)}.
   *
   * @param in stream containing the model bytes; it will be closed after this call
   * @return the booster created from the contents of {@code in}
   * @throws XGBoostError propagated from {@code Booster.loadModel} (a native-side error)
   * @throws IOException propagated from {@code Booster.loadModel} while consuming {@code in}
   */
  public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
    final Booster restored = Booster.loadModel(in);
    return restored;
  }

  /**
   * Trains a booster without supplying a metrics array.
   *
   * <p>This delegates to the seven-argument {@code train} overload with {@code metrics} set to
   * {@code null}. All other arguments are forwarded unchanged.
   *
   * @param dtrain training matrix
   * @param params booster parameters, forwarded unchanged
   * @param round requested number of training rounds, forwarded unchanged
   * @param watches evaluation matrices keyed by name, forwarded unchanged
   * @param obj custom objective, forwarded unchanged
   * @param eval custom evaluation, forwarded unchanged
   * @return the booster produced by the seven-argument overload
   * @throws XGBoostError propagated from the seven-argument overload
   */
  public static Booster train(DMatrix dtrain, Map<String, Object> params, int round,
      Map<String, DMatrix> watches, IObjective obj, IEvaluation eval) throws XGBoostError {
    // This overload never passes a metrics array to the full overload.
    final float[][] noMetrics = null;
    return train(dtrain, params, round, watches, noMetrics, obj, eval);
  }

  public static Booster train(
          DMatrix dtrain,
          Map<String, Object> params,
          int round,
          Map<String, DMatrix> watches,
          float[][] metrics,
          IObjective obj,
          IEvaluation eval) throws XGBoostError {

    //collect eval matrixs
    String[] evalNames;
    DMatrix[] evalMats;
    List<String> names = new ArrayList<String>();
    List<DMatrix> mats = new ArrayList<DMatrix>();

    for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
      names.add(evalEntry.getKey());
      mats.add(evalEntry.getValue());
    }

    evalNames = names.toArray(new String[names.size()]);
    evalMats = mats.toArray(new DMatrix[mats.size()]);

    //collect all data matrixs
    DMatrix[] allMats;
    if (evalMats.length > 0) {
      allMats = new DMatrix[evalMats.length + 1];
      allMats[0] = dtrain;
      System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
    } else {
      allMats = new DMatrix[1];
      allMats[0] = dtrain;
    }

    //initialize booster
    Booster booster = new Booster(params, allMats);

    int version = booster.loadRabitCheckpoint();



( run in 0.621 second using v1.01-cache-2.11-cpan-2398b32b56e )