Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala  view on Meta::CPAN


object XGBoost {
  /**
    * Flink partition function that runs one distributed XGBoost worker per partition.
    *
    * Each parallel subtask joins the Rabit session described by `workerEnvs`, converts its
    * partition of [[LabeledVector]]s into a [[DMatrix]], trains a booster on it and emits the
    * result as an [[XGBoostModel]].
    *
    * @param paramMap   XGBoost training parameters, passed unchanged to `XGBoostScala.train`.
    * @param round      Number of boosting rounds to train.
    * @param workerEnvs Rabit worker environment obtained from the tracker; `DMLC_TASK_ID` is
    *                   filled in per subtask before `Rabit.init` is called.
    */
  private class MapFunction(paramMap: Map[String, Any],
                            round: Int,
                            workerEnvs: java.util.Map[String, String])
    extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
    val logger = LogFactory.getLog(this.getClass)

    def mapPartition(it: java.lang.Iterable[LabeledVector],
                     collector: Collector[XGBoostModel]): Unit = {
      // Identify this worker to Rabit by the index of the current Flink subtask.
      workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
      logger.info("start with env" + workerEnvs.toString)
      Rabit.init(workerEnvs)
      val booster = try {
        // Convert Flink's sparse (index, value) pairs into XGBoost's LabeledPoint.
        val mapper = (x: LabeledVector) => {
          val (index, value) = x.vector.toSeq.unzip
          LabeledPoint(x.label.toFloat, index.toArray, value.map(_.toFloat).toArray)
        }
        val dataIter = it.iterator().asScala.map(mapper)
        val trainMat = new DMatrix(dataIter, null)
        try {
          val watches = Map("train" -> trainMat)
          // Train for the caller-supplied number of rounds (this was previously shadowed
          // by a hard-coded `val round = 2`, silently ignoring the argument).
          XGBoostScala.train(trainMat, paramMap, round, watches, null, null)
        } finally {
          // The trained booster does not need the input matrix; free its native memory.
          trainMat.delete()
        }
      } finally {
        // Always close the Rabit session opened by Rabit.init above, even when building
        // the matrix or training throws.
        Rabit.shutdown()
      }
      collector.collect(new XGBoostModel(booster))
    }
  }

  val logger = LogFactory.getLog(this.getClass)

  /**
    * Load XGBoost model from path, using Hadoop Filesystem API.
    *
    * The input stream opened on `modelPath` is closed once the model has been read,
    * whether or not loading succeeds.
    *
    * @param modelPath The path that is accessible by hadoop filesystem API.
    * @return The loaded model
    */
  def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
    val in = FileSystem.get(new Configuration).open(new Path(modelPath))
    try {
      new XGBoostModel(XGBoostScala.loadModel(in))
    } finally {
      in.close()
    }
  }

  /**
    * Train a xgboost model with Flink.
    *
    * Starts a Rabit tracker sized to the execution environment's parallelism, runs one
    * training worker per partition of `dtrain` and returns one of the resulting models.
    *
    * @param dtrain The training data.
    * @param params The parameters to XGBoost.
    * @param round Number of rounds to train.
    * @return The trained model.
    * @throws Error if the Rabit tracker cannot be started.
    */
  def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
      XGBoostModel = {
    val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
    // NOTE(review): the tracker is started but never stopped or awaited here — confirm
    // whether RabitTracker requires explicit cleanup after the job finishes.
    if (!tracker.start(0L)) {
      throw new Error("Tracker cannot be started")
    }
    // Workers synchronise through Rabit, so presumably every partition yields an
    // equivalent booster; keep an arbitrary one. TODO confirm against Rabit semantics.
    dtrain
      .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
      .reduce((x, y) => x).collect().head
  }
}



( run in 0.516 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )