Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala view on Meta::CPAN
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, TaskContext}
/**
 * Companion of [[TrackerConf]] that adds a zero-argument factory for the default
 * configuration: a worker-connection timeout of 0 (timeout disabled) and the "python" tracker.
 */
object TrackerConf {
  def apply(): TrackerConf =
    new TrackerConf(workerConnectionTimeout = 0L, trackerImpl = "python")
}
/**
 * Configuration of the Rabit tracker.
 *
 * @param workerConnectionTimeout how long, in milliseconds, the tracker waits for all workers to
 *                                connect. A value of zero disables the timeout; a finite, non-zero
 *                                value prevents the tracker from hanging indefinitely. Only the
 *                                "scala" implementation supports this setting.
 * @param trackerImpl which tracker implementation to use, "python" or "scala". "python" uses the
 *                    Java wrapper of the Python Rabit tracker (in dmlc-core); "scala" is
 *                    implemented in Scala without Python components and fully supports timeouts.
 *                    The Scala implementation is currently experimental; use it at your own risk.
 */
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
object XGBoost extends Serializable {
// Commons-logging logger for this object, registered under the name "XGBoostSpark"
// (used, for example, to report when the training set is repartitioned).
private val logger = LogFactory.getLog("XGBoostSpark")
/**
 * Drops every entry equal to `missing` from each labeled point, producing sparse points.
 *
 * For a point whose `indices` is null, the array position becomes the feature index of each
 * kept entry. Otherwise the point's existing index for that position is kept.
 * When `missing` is NaN, the input iterator is returned as-is.
 *
 * @param denseLabeledPoints the points to filter
 * @param missing the value marking a missing entry
 * @return an iterator of points with the `missing` entries removed
 */
private def fromDenseToSparseLabeledPoints(
    denseLabeledPoints: Iterator[XGBLabeledPoint],
    missing: Float): Iterator[XGBLabeledPoint] = {
  if (missing.isNaN) {
    // `value != NaN` holds for every entry, so nothing could be dropped: pass points through.
    denseLabeledPoints
  } else {
    denseLabeledPoints.map { point =>
      val srcValues = point.values
      val srcIndices = point.indices
      val keptIndices = new mutable.ArrayBuilder.ofInt()
      val keptValues = new mutable.ArrayBuilder.ofFloat()
      for (pos <- 0 until srcValues.length) {
        val v = srcValues(pos)
        if (v != missing) {
          // Dense input: the position is the feature index; sparse input: reuse its index.
          keptIndices += (if (srcIndices == null) pos else srcIndices(pos))
          keptValues += v
        }
      }
      point.copy(indices = keptIndices.result(), values = keptValues.result())
    }
  }
}
/**
 * Materializes a partition's base margins into an array.
 *
 * @param baseMargins the base margin of every row in the partition (NaN means "not set")
 * @return None when every value is NaN (including an empty partition), or Some(array of all
 *         values) when no value is NaN
 * @throws IllegalArgumentException when the partition mixes NaN and non-NaN base margins
 */
private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
  val defined = new mutable.ArrayBuilder.ofFloat()
  var seen = 0
  var nanCount = 0
  baseMargins.foreach { margin =>
    seen += 1
    if (margin.isNaN) {
      nanCount += 1 // NaNs are only counted, never stored: an all-NaN partition yields None.
    } else {
      defined += margin
    }
  }
  if (nanCount == seen) {
    None
  } else if (nanCount == 0) {
    Some(defined.result())
  } else {
    throw new IllegalArgumentException(
      s"Encountered a partition with $nanCount NaN base margin values. " +
        "If you want to specify base margin, ensure all values are non-NaN.")
  }
}
private[spark] def buildDistributedBoosters(
trainingSet: RDD[XGBLabeledPoint],
params: Map[String, Any],
rabitEnv: java.util.Map[String, String],
numWorkers: Int,
round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
useExternalMemory: Boolean,
missing: Float): RDD[Booster] = {
val partitionedTrainingSet = if (trainingSet.getNumPartitions != numWorkers) {
logger.info(s"repartitioning training set to $numWorkers partitions")
trainingSet.repartition(numWorkers)
} else {
xgboost/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala view on Meta::CPAN
val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
objective != null && {
val objStr = objective.toString
objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
objStr != "rank:pairwise")
}
}
/**
 * Trains an XGBoost model on RDD-represented data.
 *
 * Deprecated alias that forwards all of its arguments unchanged to `trainWithRDD`.
 *
 * @param trainingData the training set represented as an RDD
 * @param params Map containing the configuration entries
 * @param round the number of boosting iterations
 * @param nWorkers the number of XGBoost workers. This parameter has no default and must be
 *                 greater than 0; non-positive values are rejected before training starts
 * @param obj the user-defined objective function, null by default. When supplied, `params`
 *            must also define "obj_type"
 * @param eval the user-defined evaluation function, null by default
 * @param useExternalMemory whether to use an external-memory cache. Setting this flag to true
 *                          may reduce the RAM used when running XGBoost within Spark
 * @param missing the value that represents a missing entry in the dataset, Float.NaN by default
 * @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
 * @return the trained XGBoostModel
 */
@deprecated("Use XGBoost.trainWithRDD instead.")
def train(
    trainingData: RDD[MLLabeledPoint],
    params: Map[String, Any],
    round: Int,
    nWorkers: Int,
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null,
    useExternalMemory: Boolean = false,
    missing: Float = Float.NaN): XGBoostModel = {
  trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory,
    missing)
}
/**
 * Reconciles XGBoost's `nthread` parameter with Spark's per-task CPU allotment.
 *
 * `spark.task.cpus` is read from the SparkContext's configuration and defaults to 1 when unset.
 * If `params` already contains `nthread`, that value is checked to be no larger than
 * `spark.task.cpus`, and `params` is returned unchanged. Otherwise `nthread` is added with the
 * value of `spark.task.cpus`.
 *
 * @param params the user-supplied XGBoost parameters
 * @param sc the SparkContext whose configuration supplies `spark.task.cpus`
 * @return `params`, with `nthread` set to `spark.task.cpus` if it was absent
 * @throws IllegalArgumentException if the configured `nthread` exceeds `spark.task.cpus`
 * @throws NumberFormatException if the configured `nthread` does not render as an integer
 */
private def overrideParamsAccordingToTaskCPUs(
    params: Map[String, Any],
    sc: SparkContext): Map[String, Any] = {
  val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
  params.get("nthread") match {
    case Some(requested) =>
      // Values may arrive as Int or String, so normalize through their string form.
      val nThread = requested.toString.toInt
      require(nThread <= coresPerTask,
        s"the nthread configuration ($nThread) must be no larger than " +
          s"spark.task.cpus ($coresPerTask)")
      params
    case None =>
      params + ("nthread" -> coresPerTask)
  }
}
/**
 * Creates a Rabit tracker for `nWorkers` workers and starts it.
 *
 * `trackerConf.trackerImpl == "scala"` selects the Scala `RabitTracker`. Every other value,
 * including "python" and unrecognized strings, selects the Python-backed tracker.
 * NOTE(review): unrecognized implementation names fall back to Python silently; consider warning.
 *
 * @param nWorkers the number of workers the tracker coordinates
 * @param trackerConf the tracker implementation and worker-connection timeout to use
 * @return the started tracker
 * @throws IllegalArgumentException if the tracker reports that it failed to start
 */
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
  val useScalaTracker = trackerConf.trackerImpl == "scala"
  val tracker: IRabitTracker =
    if (useScalaTracker) new RabitTracker(nWorkers) else new PyRabitTracker(nWorkers)
  val started = tracker.start(trackerConf.workerConnectionTimeout)
  require(started, "FAULT: Failed to start tracker")
  tracker
}
/**
 * Trains an XGBoost model on an RDD of Spark ML labeled points.
 *
 * Each `MLLabeledPoint` is converted to XGBoost's labeled-point type through the `DataUtils`
 * conversion `asXGB`, and its `Double` label is narrowed to `Float`. The converted RDD is then
 * passed to `trainDistributed` together with the remaining arguments.
 *
 * @param trainingData the training set represented as an RDD
 * @param params Map containing the configuration entries
 * @param round the number of boosting iterations
 * @param nWorkers the number of XGBoost workers. This parameter has no default and must be
 *                 greater than 0; non-positive values are rejected before training starts
 * @param obj the user-defined objective function, null by default. When supplied, `params`
 *            must also define "obj_type"
 * @param eval the user-defined evaluation function, null by default
 * @param useExternalMemory whether to use an external-memory cache. Setting this flag to true
 *                          may reduce the RAM used when running XGBoost within Spark
 * @param missing the value that represents a missing entry in the dataset, Float.NaN by default
 * @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
 * @return the trained XGBoostModel
 */
@throws(classOf[XGBoostError])
def trainWithRDD(
    trainingData: RDD[MLLabeledPoint],
    params: Map[String, Any],
    round: Int,
    nWorkers: Int,
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null,
    useExternalMemory: Boolean = false,
    missing: Float = Float.NaN): XGBoostModel = {
  import DataUtils._
  val xgbTrainingData = trainingData.map { case MLLabeledPoint(label, features) =>
    features.asXGB.copy(label = label.toFloat)
  }
  trainDistributed(xgbTrainingData, params, round, nWorkers, obj, eval,
    useExternalMemory, missing)
}
@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
trainingData: RDD[XGBLabeledPoint],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
" for now")
}
require(nWorkers > 0, "you must specify more than 0 workers")
if (obj != null) {
require(params.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
" you have to specify the objective type as classification or regression with a" +
" customized objective function")
}
val trackerConf = params.get("tracker_conf") match {
case None => TrackerConf()
case Some(conf: TrackerConf) => conf
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
( run in 0.224 second using v1.01-cache-2.11-cpan-4d50c553e7e )