Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala  view on Meta::CPAN

 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit

import java.net.{InetAddress, InetSocketAddress}

import akka.actor.ActorSystem
import akka.pattern.ask
import ml.dmlc.xgboost4j.java.IRabitTracker
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success, Try}

/**
  * Scala implementation of the Rabit tracker interface without Python dependency.
  * The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
  * (and thus any distributed tasks) from hanging indefinitely due to network issues or worker
  * node failures.
  *
  * Note that this implementation is currently experimental, and should be used at your own risk.
  *
  * Example usage:
  * {{{
  *   import scala.concurrent.duration._
  *
  *   val tracker = new RabitTracker(32)
  *   // allow up to 10 minutes for all workers to connect to the tracker.
  *   tracker.start(10.minutes.toMillis)
  *
  *   // ...
  *   // launching workers in parallel
  *   // ...
  *
  *   // wait for worker execution up to 6 hours.
  *   // providing a finite timeout prevents a long-running task from hanging forever in
  *   // catastrophic events, like the loss of an executor during model training.
  *   tracker.waitFor(6.hours.toMillis)
  * }}}
  *
  * @param numWorkers Number of distributed workers from which the tracker expects connections.
  * @param port The minimum port number that the tracker binds to.
  *             If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
  * @param maxPortTrials The maximum number of trials of socket binding, by sequentially
  *                      increasing the port number.
  */
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
                                  maxPortTrials: Int = 1000)
  extends IRabitTracker {

  import scala.collection.JavaConverters._

  require(numWorkers >= 1, "numWorkers must be greater than or equal to one (1).")

  val system = ActorSystem.create("RabitTracker")
  val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
  // Timeout applied to every ask (`?`) sent to the handler actor.
  implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30.seconds)
  // Upper bound on how long start() blocks waiting for the socket to be bound.
  private[this] val tcpBindingTimeout: Duration = 1.minute
  // Largest millisecond count representable as a FiniteDuration (limited to +-(2^63-1) ns).
  private[this] val maxFiniteMillis: Long = Long.MaxValue / 1000000L

  // Environment variables reported by the handler once the tracker is bound; empty until then.
  var workerEnvs: Map[String, String] = Map.empty

  /**
    * Forward an uncaught exception to the handler actor as an InterruptTracker message.
    * The reply Future of the ask is discarded.
    */
  override def uncaughtException(t: Thread, e: Throwable): Unit = {
    handler ? RabitTrackerHandler.InterruptTracker(e)
  }

  /**
    * Convert a millisecond count from the Java-facing API into a Duration.
    *
    * Non-positive values mean "wait indefinitely". Values too large to be represented as a
    * FiniteDuration are also treated as infinite instead of raising IllegalArgumentException.
    * The conversion is done in integer arithmetic, so no precision is lost.
    */
  private[this] def millisToDuration(millis: Long): Duration = {
    if (millis <= 0 || millis > maxFiniteMillis) Duration.Inf else millis.millis
  }

  /**
    * Start the Rabit tracker.
    *
    * @param timeout The timeout for awaiting connections from worker nodes.
    *      Note that when used in Spark applications, because all Spark transformations are
    *      lazily executed, the I/O time for loading RDDs/DataFrames from external sources
    *      (local disk, HDFS, S3 etc.) must be taken into account for the timeout value.
    *      If the timeout value is too small, the Rabit tracker will likely time out before
    *      workers establish connections to the tracker, due to the overhead of loading data.
    *      Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
    *      running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
    * @return Boolean flag indicating if the Rabit tracker starts successfully.
    */
  private def start(timeout: Duration): Boolean = {
    // The reply Future of this ask is discarded; binding is observed via RequestBoundFuture.
    handler ? RabitTrackerHandler.StartTracker(
      new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout)

    val boundEnvs: Try[Map[String, String]] = for {
      // block by waiting for the actor to hand over the Future tied to the port binding
      futurePortBound <- Try(
        Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
          .asInstanceOf[Future[Map[String, String]]])
      // The success of the Future is contingent on binding to an InetSocketAddress.
      // Await.result (rather than Await.ready) is used so that a Future completed with a
      // failure is reported as an unsuccessful start instead of being rethrown to the caller.
      envs <- Try(Await.result(futurePortBound, tcpBindingTimeout))
    } yield envs

    boundEnvs match {
      case Success(envs) =>
        workerEnvs = envs
        true
      case Failure(_) =>
        false
    }
  }

  /**
    * Start the Rabit tracker.
    *
    * @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
    *                                connections. If a non-positive value is provided, the tracker
    *                                waits for incoming worker connections indefinitely. Values
    *                                beyond the representable range of a finite duration
    *                                (ca. 292 years) are treated as indefinite as well.
    * @return Boolean flag indicating if the Rabit tracker starts successfully.
    */
  def start(connectionTimeoutMillis: Long): Boolean = {
    start(millisToDuration(connectionTimeoutMillis))
  }

  /**
    * Shut down the actor system backing the tracker, unless it has already terminated.
    */
  def stop(): Unit = {
    if (!system.isTerminated) {
      system.shutdown()
    }
  }

  /**
    * Get a Map of necessary environment variables to initiate Rabit workers.
    *
    * @return HashMap containing tracker information.
    */
  def getWorkerEnvs: java.util.Map[String, String] = {
    new java.util.HashMap((workerEnvs ++ Map(
        "DMLC_NUM_WORKER" -> numWorkers.toString,
        "DMLC_NUM_SERVER" -> "0"
    )).asJava)
  }

  /**
    * Await workers to complete assigned tasks for at most 'atMost'.
    * This method blocks until timeout or task completion, and shuts down the actor system
    * before returning a status code.
    *
    * @param atMost the maximum execution time for the workers. Duration.Inf makes the tracker
    *     wait for the workers indefinitely.
    * @return 0 if the tasks complete successfully, and non-zero otherwise.
    */
  private def waitFor(atMost: Duration): Int = {
    // request the completion Future from the tracker actor
    val statusCode = Try(
      Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
        .asInstanceOf[Future[Int]]) match {
      case Success(futureCompleted) =>
        // wait for all workers to complete synchronously.
        Try(Await.result(futureCompleted, atMost)) match {
          case Success(n) if n == numWorkers =>
            IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
          case Success(n) if n < numWorkers =>
            IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
          case _ =>
            // Either the await failed, or the handler reported more completions than
            // expected workers; the latter previously escaped as a MatchError.
            IRabitTracker.TrackerStatus.FAILURE.getStatusCode
        }
      case Failure(_) =>
        IRabitTracker.TrackerStatus.FAILURE.getStatusCode
    }
    stop()
    statusCode
  }

  /**
    * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
    * This method blocks until timeout or task completion.
    *
    * @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
    *                     non-positive number is given, the tracker waits indefinitely. Values
    *                     beyond the representable range of a finite duration (ca. 292 years)
    *                     are treated as indefinite as well.
    * @return 0 if the tasks complete successfully, and non-zero otherwise
    */
  def waitFor(atMostMillis: Long): Int = {
    waitFor(millisToDuration(atMostMillis))
  }
}

 view all matches for this distribution
 view release on metacpan -  search on metacpan

( run in 1.955 second using v1.00-cache-2.02-grep-82fe00e-cpan-2c419f77a38b )