Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java view on Meta::CPAN
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* trainer for xgboost
*
* @author hzx
*/
public class XGBoost {
private static final Log logger = LogFactory.getLog(XGBoost.class);
/**
 * Loads a booster model from a file on disk.
 *
 * <p>Thin convenience wrapper that delegates to {@link Booster#loadModel(String)}.
 *
 * @param modelPath path of a model file produced by {@code Booster.saveModel}
 * @return the booster restored from {@code modelPath}
 * @throws XGBoostError if the native library reports an error while loading
 */
public static Booster loadModel(String modelPath) throws XGBoostError {
  Booster restored = Booster.loadModel(modelPath);
  return restored;
}
/**
 * Loads a booster model from an already opened input stream.
 *
 * <p>The stream is assumed to contain exactly one XGBoost model. This makes it possible to
 * load booster models that were saved by other xgboost language bindings.
 *
 * <p>The stream will be closed after this call. This method does not close it itself; closing
 * is delegated to {@link Booster#loadModel(InputStream)} (NOTE(review): confirm there).
 *
 * @param in input stream over the serialized model file
 * @return the created booster
 * @throws XGBoostError if the native library reports an error while loading
 * @throws IOException propagated from {@link Booster#loadModel(InputStream)}, presumably when
 *     reading the stream fails
 */
public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
  Booster restored = Booster.loadModel(in);
  return restored;
}
/**
 * Trains a booster without a caller-supplied metrics array.
 *
 * <p>Equivalent to calling the seven-argument {@code train} overload with {@code metrics}
 * set to {@code null}.
 *
 * @param dtrain training data
 * @param params booster parameters, forwarded unchanged
 * @param round number of boosting rounds (NOTE(review): inferred from the name; confirm in the
 *     seven-argument overload)
 * @param watches named data sets to evaluate; must not be {@code null}. Each entry's matrix is
 *     handed to the booster together with {@code dtrain}
 * @param obj custom objective (NOTE(review): presumably {@code null} means the built-in
 *     objective; confirm)
 * @param eval custom evaluation (same caveat as {@code obj})
 * @return the trained booster
 * @throws XGBoostError if the native library reports an error
 */
public static Booster train(
    DMatrix dtrain,
    Map<String, Object> params,
    int round,
    Map<String, DMatrix> watches,
    IObjective obj,
    IEvaluation eval) throws XGBoostError {
  // The caller did not ask for per-round metrics, so no output buffer is supplied.
  final float[][] noMetrics = null;
  return train(dtrain, params, round, watches, noMetrics, obj, eval);
}
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
names.add(evalEntry.getKey());
mats.add(evalEntry.getValue());
}
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
//collect all data matrixs
DMatrix[] allMats;
if (evalMats.length > 0) {
allMats = new DMatrix[evalMats.length + 1];
allMats[0] = dtrain;
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
} else {
allMats = new DMatrix[1];
allMats[0] = dtrain;
}
//initialize booster
Booster booster = new Booster(params, allMats);
int version = booster.loadRabitCheckpoint();
( run in 0.621 second using v1.01-cache-2.11-cpan-2398b32b56e )