Alien-XGBoost

xgboost/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitTrackerRobustnessSuite.scala

       To prevent unit tests from crashing, deterministic delays were introduced to make sure
       that the exception is thrown last, ideally after all worker connections have been
       established. For the same reason, the Java RabitTracker class delays killing the Python
       tracker process so that pending worker connections are still handled.
     */
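    // Each partition joins the Rabit ring; the staggered sleeps let the worker with the
    // highest index throw only after the other workers have connected to the tracker.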
    val dummyTasks = rdd.mapPartitions { iter =>
      Rabit.init(trackerEnvs)
      val index = iter.next()
      Thread.sleep(100 + index * 10)
      if (index == workerCount) {
        // kill the worker by throwing an exception
        throw new RuntimeException("Worker exception.")
      }
      Rabit.shutdown()
      Iterator(index)
    }.cache()

    val sparkThread = new Thread() {
      override def run(): Unit = {
        // forces a Spark job.
        dummyTasks.foreachPartition(_ => ())
      }
    }

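    // The tracker is registered as the thread's UncaughtExceptionHandler, so a crash in
    // the Spark job thread is reported to the tracker and unblocks waitFor() below.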
    sparkThread.setUncaughtExceptionHandler(tracker)
    sparkThread.start()
    assert(tracker.waitFor(0) != 0)
  }

  test("test Scala RabitTracker's exception handling: it should not hang forever.") {
    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

    val tracker = new ScalaRabitTracker(numWorkers)
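    // A timeout of 0 disables the worker connection timeout; this test exercises the
    // exception path instead (compare the workerConnectionTimeout test below).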
    tracker.start(0)
    val trackerEnvs = tracker.getWorkerEnvs

    val workerCount: Int = numWorkers
    val dummyTasks = rdd.mapPartitions { iter =>
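      // Rabit.init connects this task to the tracker using the env passed from the driver;
      // the throwing worker never reaches Rabit.shutdown, which is what the test exercises.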
      Rabit.init(trackerEnvs)
      val index = iter.next()
      Thread.sleep(100 + index * 10)
      if (index == workerCount) {
        // kill the worker by throwing an exception
        throw new RuntimeException("Worker exception.")
      }
      Rabit.shutdown()
      Iterator(index)
    }.cache()

    val sparkThread = new Thread() {
      override def run(): Unit = {
        // forces a Spark job.
        dummyTasks.foreachPartition(_ => ())
      }
    }
    sparkThread.setUncaughtExceptionHandler(tracker)
    sparkThread.start()
    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
  }

  test("test Scala RabitTracker's workerConnectionTimeout") {
    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

    val tracker = new ScalaRabitTracker(numWorkers)
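    // Give workers only 500 ms to connect; worker 1 below never calls Rabit.init, so the
    // tracker is expected to time out and report failure.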
    tracker.start(500)
    val trackerEnvs = tracker.getWorkerEnvs

    val dummyTasks = rdd.mapPartitions { iter =>
      val index = iter.next()
      // simulate that the first worker cannot connect to the tracker due to network issues.
      if (index != 1) {
        Rabit.init(trackerEnvs)
        Thread.sleep(1000)
        Rabit.shutdown()
      }

      Iterator(index)
    }.cache()

    val sparkThread = new Thread() {
      override def run(): Unit = {
        // forces a Spark job.
        dummyTasks.foreachPartition(_ => ())
      }
    }
    sparkThread.setUncaughtExceptionHandler(tracker)
    sparkThread.start()
    // should fail due to connection timeout
    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
  }
}
