Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala  view on Meta::CPAN

import org.apache.spark.sql.Dataset
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, TaskContext}

/** Companion of the TrackerConf case class; supplies the default configuration. */
object TrackerConf {
  /** Default configuration: a zero worker-connection timeout and the "python" tracker. */
  def apply(): TrackerConf =
    new TrackerConf(workerConnectionTimeout = 0L, trackerImpl = "python")
}

/**
  * Rabit tracker configurations.
  *
  * @param workerConnectionTimeout The timeout, in milliseconds, for all workers to connect to the
  *                                tracker. Set the timeout to zero to disable it. Use a finite,
  *                                non-zero value to prevent the tracker from hanging
  *                                indefinitely. (Supported by the "scala" implementation only.)
  * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
  *                    the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
  *                    in Scala without Python components, and with full support of timeouts.
  *                    The Scala implementation is currently experimental, use at your own risk.
  */
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)

object XGBoost extends Serializable {
  private val logger = LogFactory.getLog("XGBoostSpark")

  private def fromDenseToSparseLabeledPoints(
      denseLabeledPoints: Iterator[XGBLabeledPoint],
      missing: Float): Iterator[XGBLabeledPoint] = {
    if (!missing.isNaN) {
      denseLabeledPoints.map { labeledPoint =>
        val indicesBuilder = new mutable.ArrayBuilder.ofInt()

xgboost/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala  view on Meta::CPAN

    overridedParams
  }

  /**
   * Creates the Rabit tracker selected by `trackerConf.trackerImpl` and starts it.
   *
   * @param nWorkers the number of workers passed to the tracker constructor
   * @param trackerConf tracker implementation choice and worker connection timeout
   * @return the started tracker
   * @throws IllegalArgumentException (via `require`) if the tracker reports a failed start
   */
  private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
    // Only the exact string "scala" selects the Scala tracker; "python" and every other value
    // select the Python tracker wrapper.
    val tracker: IRabitTracker =
      if (trackerConf.trackerImpl == "scala") {
        new RabitTracker(nWorkers)
      } else {
        new PyRabitTracker(nWorkers)
      }

    val started = tracker.start(trackerConf.workerConnectionTimeout)
    require(started, "FAULT: Failed to start tracker")
    tracker
  }

  /**
   * various of train()
   * @param trainingData the trainingset represented as RDD
   * @param params Map containing the configuration entries
   * @param round the number of iterations
   * @param nWorkers the number of xgboost workers, 0 by default which means that the number of
   *                 workers equals to the partition number of trainingData RDD

xgboost/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala  view on Meta::CPAN


  /**
   * Float parameter named "missing": the value treated as missing. default: Float.NaN
   */
  val missing = new FloatParam(this, "missing", "the value treated as missing")

  /**
    * Rabit tracker configurations. The parameter must be provided as an instance of the
    * TrackerConf class, which has the following definition:
    *
    *     case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String)
    *
    * See below for detailed explanations.
    *
    *   - trackerImpl: Select the implementation of Rabit tracker.
    *                  default: "python"
    *
    *        Choice between "python" or "scala". The former utilizes the Java wrapper of the
    *        Python Rabit tracker (in dmlc_core), and does not support timeout settings.
    *        The "scala" version removes Python components, and fully supports timeout settings.
    *
    *   - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker.
    *                             default: 0 millisecond (no timeout)
    *
    *        The timeout value should take the time of data loading and pre-processing into account,
    *        due to the lazy execution of Spark's operations. Alternatively, you may force Spark to
    *        perform data transformation before calling XGBoost.train(), so that this timeout truly
    *        reflects the connection delay. Set a reasonable timeout value to prevent model
    *        training/testing from hanging indefinitely, possibly due to network issues.
    *        Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
    *        Ignored if the tracker implementation is "python".
    */

xgboost/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala  view on Meta::CPAN

      override def run(): Unit = {
        // forces a Spark job.
        dummyTasks.foreachPartition(() => _)
      }
    }
    sparkThread.setUncaughtExceptionHandler(tracker)
    sparkThread.start()
    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
  }

  test("test Scala RabitTracker's workerConnectionTimeout") {
    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

    val tracker = new ScalaRabitTracker(numWorkers)
    tracker.start(500)
    val trackerEnvs = tracker.getWorkerEnvs

    val dummyTasks = rdd.mapPartitions { iter =>
      val index = iter.next()
      // simulate that the first worker cannot connect to tracker due to network issues.
      if (index != 1) {

xgboost/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java  view on Meta::CPAN

    /**
     * Creates a status value that carries the given numeric code.
     *
     * @param statusCode the integer code stored for this status
     */
    TrackerStatus(int statusCode) {
      this.statusCode = statusCode;
    }

    /**
     * Returns the numeric code stored for this status.
     *
     * @return the integer status code
     */
    public int getStatusCode() {
      return statusCode;
    }
  }

  /**
   * Returns the tracker's worker environment as string key/value pairs.
   * NOTE(review): presumably these entries are handed to the Rabit workers so they can reach
   * the tracker; confirm against the callers.
   */
  Map<String, String> getWorkerEnvs();
  /**
   * Starts the tracker.
   *
   * @param workerConnectionTimeout timeout for worker connections; the Scala RabitTracker
   *                                interprets it as milliseconds and waits indefinitely when
   *                                the value is non-positive.
   * @return true if the tracker started successfully, false otherwise.
   */
  boolean start(long workerConnectionTimeout);
  /** Stops the tracker. */
  void stop();
  /**
   * Waits for the workers and returns an integer status code; callers compare it against
   * {@code TrackerStatus#getStatusCode()} values.
   * taskExecutionTimeout has no effect in current version of XGBoost.
   */
  int waitFor(long taskExecutionTimeout);
}

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala  view on Meta::CPAN

private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
                                  maxPortTrials: Int = 1000)
  extends IRabitTracker {

  import scala.collection.JavaConverters._

  require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")

  val system = ActorSystem.create("RabitTracker")
  val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
  implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
  private[this] val tcpBindingTimeout: Duration = 1 minute

  var workerEnvs: Map[String, String] = Map.empty

  /**
    * Uncaught-exception callback: wraps the exception `e` in an InterruptTracker message and
    * sends it to the handler actor with the ask pattern. The reply Future is discarded.
    *
    * @param t the thread that raised the exception (not used here)
    * @param e the uncaught exception forwarded to the handler actor
    */
  override def uncaughtException(t: Thread, e: Throwable): Unit = {
    val interruption = RabitTrackerHandler.InterruptTracker(e)
    handler ? interruption
  }

  /**
    * Start the Rabit tracker.
    *

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala  view on Meta::CPAN

    *      establishing connections to the tracker, due to the overhead of loading data.
    *      Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
    *      running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
    * @return Boolean flag indicating if the Rabit tracker starts successfully.
    */
  private def start(timeout: Duration): Boolean = {
    // Ask the handler actor to bind. The direct reply is not inspected, because the outcome of
    // the bind is observed through the Future requested below.
    handler ? RabitTrackerHandler.StartTracker(
      new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout)

    // block by waiting for the actor to hand over the Future that completes once a port is bound
    Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
      .asInstanceOf[Future[Map[String, String]]]) match {
      case Success(futurePortBound) =>
        // Await.result is used instead of Await.ready on purpose: Await.ready returns normally
        // even when the Future completed with a failure, which previously made a failed bind
        // look successful and then throw out of this method instead of returning false.
        Try(Await.result(futurePortBound, tcpBindingTimeout)) match {
          case Success(boundEnvs) =>
            workerEnvs = boundEnvs
            true
          case Failure(_) =>
            // the bind timed out or the Future completed with a failure
            false
        }
      case Failure(_) =>
        // the handler actor did not answer within askTimeout
        false
    }
  }

  /**
    * Start the Rabit tracker.
    *
    * @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
    *                                connections. If a non-positive value is provided, the tracker
    *                                waits for incoming worker connections indefinitely.
    * @return Boolean flag indicating if the Rabit tracker starts successfully.
    */
  def start(connectionTimeoutMillis: Long): Boolean = {
    // A non-positive value means "no timeout"; otherwise convert milliseconds to nanoseconds.
    val connectionTimeout: Duration =
      if (connectionTimeoutMillis > 0) {
        Duration.fromNanos(connectionTimeoutMillis * 1e6)
      } else {
        Duration.Inf
      }
    // Typed as Duration, so this resolves to the private Duration-based overload.
    start(connectionTimeout)
  }

  /** Shuts down the tracker's actor system unless it has already terminated. */
  def stop(): Unit = {
    val alreadyTerminated = system.isTerminated
    if (!alreadyTerminated) {
      system.shutdown()
    }
  }

  /**

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala  view on Meta::CPAN

  /**
    * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
    * This method blocks until timeout or task completion.
    *
    * @param atMost the maximum execution time for the workers. By default,
    *     the tracker waits for the workers indefinitely.
    * @return 0 if the tasks complete successfully, and non-zero otherwise.
    */
  private def waitFor(atMost: Duration): Int = {
    // request the completion Future from the tracker actor
    Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
      .asInstanceOf[Future[Int]]) match {
      case Success(futureCompleted) =>
        // wait for all workers to complete synchronously.
        val statusCode = Try(Await.result(futureCompleted, atMost)) match {
          case Success(n) if n == numWorkers =>
            IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
          case Success(n) if n < numWorkers =>
            IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
          case Failure(e) =>
            IRabitTracker.TrackerStatus.FAILURE.getStatusCode

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala  view on Meta::CPAN

 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.handler

import java.net.InetSocketAddress
import java.util.UUID

import scala.concurrent.duration._
import scala.collection.mutable
import scala.concurrent.{Promise, TimeoutException}
import akka.io.{IO, Tcp}
import akka.actor._
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}

import scala.util.{Failure, Random, Success, Try}

/** The Akka actor for handling and coordinating Rabit worker connections.
  * This is the main actor for handling socket connections, interacting with the synchronous
  * tracker interface, and resolving tree/ring/parent dependencies between workers.

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala  view on Meta::CPAN


  // resolves worker connection dependency.
  val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")

  // workers that have sent "shutdown" signal
  private[this] val shutdownWorkers = mutable.Set.empty[Int]
  private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
  private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
  private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
  private[this] var maxPortTrials = 0
  private[this] var workerConnectionTimeout: Duration = Duration.Inf
  private[this] var portTrials = 0
  private[this] val startedWorkers = mutable.Set.empty[Int]

  val linkMap = new LinkMap(numWorkers)

  def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
    rank match {
      case r if r >= 0 => Some(r)
      case _ =>
        jobId match {

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala  view on Meta::CPAN

  /**
    * Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
    * by the RabitWorkerHandler.
    *
    * @param event Generic Tcp.Event
    */
  private def handleTcpEvents(event: Tcp.Event): Unit = event match {
    case Tcp.Bound(local) =>
      // expect all workers to connect within timeout
      log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
      log.info(s"Worker connection timeout is $workerConnectionTimeout.")

      context.setReceiveTimeout(workerConnectionTimeout)
      promisedWorkerEnvs.success(Map(
        "DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
        "DMLC_TRACKER_PORT" -> local.getPort.toString,
        // not required because the world size will be communicated to the
        // worker node after the rank is assigned.
        "rabit_world_size" -> numWorkers.toString
      ))

    case Tcp.CommandFailed(cmd: Tcp.Bind) =>
      if (portTrials < maxPortTrials) {

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala  view on Meta::CPAN

    * Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
    * to interact with the tracker interface.
    *
    * @param trackerMsg control messages sent by RabitTracker class.
    */
  private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
    trackerMsg match {

    case msg: StartTracker =>
      maxPortTrials = msg.maxPortTrials
      workerConnectionTimeout = msg.connectionTimeout

      // if the port number is missing, try binding to a random ephemeral port.
      if (msg.addr.getPort == 0) {
        tcpManager ! Tcp.Bind(self,
          new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
          backlog = 256)
      } else {
        tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
      }
      sender() ! true

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala  view on Meta::CPAN

      resolver forward req

    case _ =>
  }

  /**
    * Message dispatch of the actor: Tcp events, tracker control messages and Rabit worker
    * requests are delegated to their dedicated handlers. On a receive timeout the shutdown
    * promise is failed and the actor stops itself if fewer than `numWorkers` workers have
    * started; in every case the receive timeout is then switched off.
    */
  def receive: Actor.Receive = {
    case evt: Tcp.Event => handleTcpEvents(evt)
    case control: TrackerControlMessage => handleTrackerControlMessage(control)
    case request: RabitWorkerRequest => handleRabitWorkerMessage(request)

    case akka.actor.ReceiveTimeout =>
      val missingWorkers = numWorkers - startedWorkers.size
      if (missingWorkers > 0) {
        val reason = "Timed out waiting for workers to connect: " +
          s"$missingWorkers of $numWorkers did not start/connect."
        promisedShutdownWorkers.failure(new TimeoutException(reason))
        context.stop(self)
      }

      // disable further receive timeouts
      context.setReceiveTimeout(Duration.Undefined)
  }
}

/**
  * Resolve the dependency between nodes as they connect to the tracker.
  * The dependency is enforced that a worker of rank K depends on its neighbors (from the treeMap
  * and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
  * connections by workers, this dependency constraint assumes that a worker node connects first
  * is likely to finish setup first.
  */

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala  view on Meta::CPAN

      }
  }
}

/** Companion object: control-message protocol and Props factory of the handler actor. */
private[scala] object RabitTrackerHandler {
  /** Messages sent by RabitTracker to this RabitTrackerHandler actor. */
  trait TrackerControlMessage

  case object RequestCompletionFuture extends TrackerControlMessage

  case object RequestBoundFuture extends TrackerControlMessage

  /**
    * Start the Rabit tracker at given socket address awaiting worker connections.
    * All workers must connect to the tracker before connectionTimeout, otherwise the tracker
    * will shut down due to timeout.
    */
  case class StartTracker(
      addr: InetSocketAddress,
      maxPortTrials: Int,
      connectionTimeout: Duration)
    extends TrackerControlMessage

  /**
    * To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
    * driver for the distributed training.
    */
  case class InterruptTracker(e: Throwable) extends TrackerControlMessage

  /** Builds the Props used to create a RabitTrackerHandler actor for `numWorkers` workers. */
  def props(numWorkers: Int): Props = {
    Props(new RabitTrackerHandler(numWorkers))
  }
}



( run in 0.541 second using v1.01-cache-2.11-cpan-4d50c553e7e )