Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala view on Meta::CPAN
object XGBoost {
  /**
    * Partition-level function that trains a booster inside each parallel subtask.
    *
    * For every partition it registers the subtask index as `DMLC_TASK_ID`, initialises
    * Rabit with the resulting environment, converts the partition's `LabeledVector`s to
    * `LabeledPoint`s, trains for `round` rounds and emits a single `XGBoostModel`.
    *
    * @param paramMap   The parameters passed through to the XGBoost trainer.
    * @param round      Number of boosting rounds to run.
    * @param workerEnvs Worker environment obtained from the tracker; `DMLC_TASK_ID` is
    *                   added to this map before it is handed to `Rabit.init`.
    */
  private class MapFunction(paramMap: Map[String, Any],
                            round: Int,
                            workerEnvs: java.util.Map[String, String])
    extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
    val logger = LogFactory.getLog(this.getClass)

    /**
      * Trains on one partition and emits the resulting model.
      *
      * @param it        The labeled vectors of this partition.
      * @param collector Receives exactly one `XGBoostModel` when training succeeds.
      */
    def mapPartition(it: java.lang.Iterable[LabeledVector],
                     collector: Collector[XGBoostModel]): Unit = {
      workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
      logger.info("start with env" + workerEnvs.toString)
      Rabit.init(workerEnvs)
      val booster = try {
        // Split each vector into parallel index/value sequences for LabeledPoint.
        val toLabeledPoint = (x: LabeledVector) => {
          val (index, value) = x.vector.toSeq.unzip
          LabeledPoint(x.label.toFloat, index.toArray, value.map(_.toFloat).toArray)
        }
        val dataIter = it.iterator().asScala.map(toLabeledPoint)
        val trainMat = new DMatrix(dataIter, null)
        val watches = Map("train" -> trainMat)
        // Use the caller-supplied round count; it was previously shadowed by a
        // hard-coded local `round = 2`, which silently ignored the argument.
        XGBoostScala.train(trainMat, paramMap, round, watches, null, null)
      } finally {
        // Shut Rabit down even when matrix construction or training fails.
        Rabit.shutdown()
      }
      collector.collect(new XGBoostModel(booster))
    }
  }

  val logger = LogFactory.getLog(this.getClass)

  /**
    * Load XGBoost model from path, using Hadoop Filesystem API.
    *
    * The file is opened through `FileSystem.get(new Configuration)` and the stream is
    * closed once loading finishes, whether it succeeds or throws.
    *
    * @param modelPath The path that is accessible by hadoop filesystem API.
    * @return The loaded model
    */
  def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
    val in = FileSystem.get(new Configuration).open(new Path(modelPath))
    // NOTE(review): assumes loadModel consumes the stream completely before returning,
    // so closing it here cannot affect the returned booster -- confirm.
    try {
      new XGBoostModel(XGBoostScala.loadModel(in))
    } finally {
      in.close()
    }
  }

  /**
    * Train a xgboost model with Flink.
    *
    * Starts a `RabitTracker` sized to the parallelism of the data set's execution
    * environment, trains one model per partition and returns the first model of the
    * collected reduce result (the reduce keeps its left operand).
    *
    * @param dtrain The training data.
    * @param params The parameters to XGBoost.
    * @param round Number of rounds to train.
    * @return The trained model.
    * @throws java.lang.Error if the tracker cannot be started.
    */
  def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
      XGBoostModel = {
    val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
    if (!tracker.start(0L)) {
      throw new Error("Tracker cannot be started")
    }
    dtrain
      .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
      .reduce((x, y) => x)
      .collect()
      .head
  }
}
( run in 0.516 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )