Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/dmlc-core/tracker/yarn/src/main/java/org/apache/hadoop/yarn/dmlc/ApplicationMaster.java  view on Meta::CPAN

    // HDFS handle, used to stat the files shipped to containers
    private FileSystem dfs;

    // number of virtual cores requested for each worker task
    private int workerCores = 1;
    // number of virtual cores requested for each server task
    private int serverCores = 1;
    // memory (MB) requested for each worker task
    private int workerMemoryMB = 10;
    // memory (MB) requested for each server task
    private int serverMemoryMB = 10;
    // priority attached to every container request
    private int appPriority = 0;
    // total number of workers
    private int numWorker = 1;
    // total number of servers
    private int numServer = 0;
    // total number of tasks (workers + servers)
    private int numTasks;
    // maximum number of attempts to try in each task
    private int maxNumAttempt = 3;
    // command to launch, built from the non-option command-line arguments
    private String command = "";

    // short name of the user running the AM
    private String userName = "";
    // credentials of the user running the AM
    private Credentials credentials = null;
    // hostname reported to the ResourceManager on registration
    private String appHostName = "";
    // tracking URL passed to the ResourceManager on register/unregister
    private String appTrackerUrl = "";
    // tracker port reported to the ResourceManager on registration
    private int appTrackerPort = 0;

    // whether we started to abort the application, due to whatever fatal reason
    private boolean startAbort = false;
    // localized resources (name -> resource) shipped to every container
    private Map<String, LocalResource> workerResources = new java.util.HashMap<String, LocalResource>();
    // the reason for aborting, appended to the final diagnostics
    private String abortDiagnosis = "";
    // ResourceManager client
    private AMRMClientAsync<ContainerRequest> rmClient = null;
    // NodeManager client
    private NMClientAsync nmClient = null;

    // tasks waiting for a container to be allocated
    private final Queue<TaskRecord> pendingTasks = new java.util.LinkedList<TaskRecord>();
    // containerId -> task record of the tasks that are running
    private final Map<ContainerId, TaskRecord> runningTasks = new java.util.HashMap<ContainerId, TaskRecord>();
    // tasks that finished
    private final Collection<TaskRecord> finishedTasks = new java.util.LinkedList<TaskRecord>();
    // tasks that were killed
    private final Collection<TaskRecord> killedTasks = new java.util.LinkedList<TaskRecord>();
    // environment variables for the worker containers
    private final Map<String, String> env = new java.util.HashMap<String, String>();

    // blacklist entries; presumably hostnames to avoid (populated outside this view) -- confirm
    private Collection<String> blackList = new java.util.HashSet<String>();

    /**
     * Process entry point for the application master: builds the AM and
     * runs it to completion.
     *
     * @param args command-line arguments forwarded to {@code run}
     * @throws Exception propagated from construction or from {@code run}
     */
    public static void main(String[] args) throws Exception {
        final ApplicationMaster master = new ApplicationMaster();
        master.run(args);
    }

    private ApplicationMaster() throws IOException {
        dfs = FileSystem.get(conf);
        userName = UserGroupInformation.getCurrentUser().getShortUserName();
        credentials = UserGroupInformation.getCurrentUser().getCredentials();
    }


    /**
     * Serialize the current user's credential tokens so they can be attached
     * to a container launch context.
     *
     * @return a ByteBuffer wrapping the serialized token storage
     * @throws RuntimeException if the tokens cannot be serialized (the
     *         underlying IOException is kept as the cause)
     */
    private ByteBuffer setupTokens() {
        DataOutputBuffer dob = new DataOutputBuffer();
        try {
            credentials.writeTokenStorageToStream(dob);
        } catch (IOException e) {
            throw new RuntimeException("failed to serialize credential tokens", e);
        }
        // wrap() already returns a fresh buffer, so no duplicate() is needed
        return ByteBuffer.wrap(dob.getData(), 0, dob.getLength());
    }


    /**
     * Read an integer from an environment variable.
     *
     * @param name
     *            name of the environment variable
     * @param required
     *            whether a missing variable is an error
     * @param defv
     *            value returned when the variable is not set and not required
     * @return the parsed value, or {@code defv}
     * @throws IOException if the variable is required but unset, or its value
     *             is not a valid integer
     */
    private int getEnvInteger(String name, boolean required, int defv)
            throws IOException {
        String value = System.getenv(name);
        if (value == null) {
            if (required) {
                throw new IOException("environment variable " + name
                        + " not set");
            }
            return defv;
        }
        try {
            // trim so that values with stray whitespace from launch scripts still parse
            return Integer.parseInt(value.trim());
        } catch (NumberFormatException e) {
            throw new IOException("environment variable " + name
                    + " is not a valid integer: \"" + value + "\"", e);
        }
    }

    /**
     * Initialize from command-line arguments and DMLC_* environment variables.
     *
     * Recognized arguments:
     *   -file PATH[#NAME]  ship PATH to every container, localized as NAME
     *                      (defaults to the last component of PATH)
     *   -env KEY[=VALUE]   add KEY=VALUE (VALUE defaults to "") to the worker env
     *   anything else      appended, space separated, to the launch command
     *
     * @param args command-line arguments passed to the AM
     * @throws IOException if an option is missing its value, a shipped file
     *             cannot be stat-ed, or a required environment variable is
     *             missing or malformed
     */
    private void initArgs(String args[]) throws IOException {
        LOG.info("Start AM as user=" + this.userName);
        // get user name
        userName = UserGroupInformation.getCurrentUser().getShortUserName();
        // localized name -> source path of files shipped with each container
        Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
        StringBuilder cmd = new StringBuilder(this.command);
        for (int i = 0; i < args.length; ++i) {
            if (args[i].equals("-file")) {
                String[] arr = optionValue(args, ++i, "-file").split("#");
                Path path = new Path(arr[0]);
                if (arr.length == 1) {
                    cacheFiles.put(path.getName(), path);
                } else {
                    cacheFiles.put(arr[1], path);
                }
            } else if (args[i].equals("-env")) {
                String[] pair = optionValue(args, ++i, "-env").split("=", 2);
                env.put(pair[0], (pair.length == 1) ? "" : pair[1]);
            } else {
                cmd.append(args[i]).append(' ');
            }
        }
        this.command = cmd.toString();
        for (Map.Entry<String, Path> e : cacheFiles.entrySet()) {
            LocalResource r = Records.newRecord(LocalResource.class);
            FileStatus status = dfs.getFileStatus(e.getValue());
            r.setResource(ConverterUtils.getYarnUrlFromPath(e.getValue()));
            r.setSize(status.getLen());
            r.setTimestamp(status.getModificationTime());
            r.setType(LocalResourceType.FILE);
            r.setVisibility(LocalResourceVisibility.APPLICATION);
            workerResources.put(e.getKey(), r);
        }
        workerCores = this.getEnvInteger("DMLC_WORKER_CORES", true, workerCores);
        serverCores = this.getEnvInteger("DMLC_SERVER_CORES", true, serverCores);
        workerMemoryMB = this.getEnvInteger("DMLC_WORKER_MEMORY_MB", true, workerMemoryMB);
        serverMemoryMB = this.getEnvInteger("DMLC_SERVER_MEMORY_MB", true, serverMemoryMB);
        numWorker = this.getEnvInteger("DMLC_NUM_WORKER", true, numWorker);
        numServer = this.getEnvInteger("DMLC_NUM_SERVER", true, numServer);
        numTasks = numWorker + numServer;
        maxNumAttempt = this.getEnvInteger("DMLC_MAX_ATTEMPT", false,
                                           maxNumAttempt);
        LOG.info("Try to start " + numServer + " Servers and " + numWorker + " Workers");
    }

    /**
     * Return the value that follows an option flag, failing with a clear
     * message instead of an ArrayIndexOutOfBoundsException when it is missing.
     */
    private static String optionValue(String[] args, int index, String option)
            throws IOException {
        if (index >= args.length) {
            throw new IOException("missing value for argument " + option);
        }
        return args[index];
    }

    /**
     * Run the application.
     *
     * The steps are:
     *   1. Parse arguments.
     *   2. Register with the ResourceManager.
     *   3. Request a container for every worker/server task, capping each
     *      request at the cluster maximum.
     *   4. Wait until no task is pending or running.
     *   5. Unregister with the final status.
     *
     * @param args command-line arguments (see initArgs)
     * @throws Exception if not every task finished, or if talking to the
     *             ResourceManager fails
     */
    private void run(String args[]) throws Exception {
        this.initArgs(args);
        this.rmClient = AMRMClientAsync.createAMRMClientAsync(1000,
                new RMCallbackHandler());
        this.nmClient = NMClientAsync
                .createNMClientAsync(new NMCallbackHandler());
        this.rmClient.init(conf);
        this.rmClient.start();
        this.nmClient.init(conf);
        this.nmClient.start();
        RegisterApplicationMasterResponse response = this.rmClient
                .registerApplicationMaster(this.appHostName,
                        this.appTrackerPort, this.appTrackerUrl);

        boolean success = false;
        String diagnostics = "";
        try {
            // list of tasks that wait to be submitted
            java.util.Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
            for (int i = 0; i < this.numWorker; ++i) {
                tasks.add(new TaskRecord(i, "worker"));
            }
            for (int i = 0; i < this.numServer; ++i) {
                tasks.add(new TaskRecord(i, "server"));
            }
            // shrink per-task requests that exceed what the cluster can grant
            Resource maxResource = response.getMaximumResourceCapability();
            this.serverMemoryMB = capToMaximum(this.serverMemoryMB, maxResource.getMemory(), "memory");
            this.workerMemoryMB = capToMaximum(this.workerMemoryMB, maxResource.getMemory(), "memory");
            this.workerCores = capToMaximum(this.workerCores, maxResource.getVirtualCores(), "cores");
            this.serverCores = capToMaximum(this.serverCores, maxResource.getVirtualCores(), "cores");
            this.submitTasks(tasks);
            LOG.info("[DMLC] ApplicationMaster started");
            while (!this.doneAllJobs()) {
                try {
                    Thread.sleep(100);
                } catch (InterruptedException ignored) {
                    // Deliberately keep polling: the AM must wait for every
                    // task to leave the pending/running state before it can
                    // report a final status.
                }
            }
            assert (killedTasks.size() + finishedTasks.size() == numTasks);
            success = finishedTasks.size() == numTasks;
            LOG.info("Application completed. Stopping running containers");
            diagnostics = "Diagnostics. num_tasks=" + this.numTasks
                + ", finished=" + this.finishedTasks.size() + ", failed="
                + this.killedTasks.size() + "\n" + this.abortDiagnosis;
            LOG.info(diagnostics);
        } catch (Exception e) {
            LOG.error("[DMLC] ApplicationMaster failed", e);
            diagnostics = e.toString();
        } finally {
            // stop the NM client on the failure path too, without letting a
            // stop failure prevent unregistration below
            try {
                nmClient.stop();
            } catch (Exception e) {
                LOG.warn("[DMLC] failed to stop NMClient: " + e);
            }
        }
        try {
            rmClient.unregisterApplicationMaster(
                    success ? FinalApplicationStatus.SUCCEEDED
                            : FinalApplicationStatus.FAILED, diagnostics,
                    appTrackerUrl);
        } finally {
            // release the RM client (and its heartbeat machinery) once done
            rmClient.stop();
        }
        if (!success)
            throw new Exception("Application not successful");
    }

    /**
     * Cap a per-task resource request at the cluster maximum, logging a
     * warning when the request has to be reduced.
     *
     * @param requested the requested amount
     * @param maximum the maximum the ResourceManager will grant
     * @param what label used in the warning ("memory" or "cores")
     * @return {@code min(requested, maximum)}
     */
    private int capToMaximum(int requested, int maximum, String what) {
        if (maximum < requested) {
            LOG.warn("[DMLC] " + what + " requested exceed bound " + maximum);
            return maximum;
        }
        return requested;
    }

    /**
     * Report whether every task has left both the pending and the running state.
     *
     * @return true once no task waits for a container and none is still running
     */
    private synchronized boolean doneAllJobs() {
        boolean nothingPending = pendingTasks.isEmpty();
        boolean nothingRunning = runningTasks.isEmpty();
        return nothingPending && nothingRunning;
    }

    /**
     * Request a container for each task and queue the task as pending.
     *
     * Tasks whose role is "server" request the server memory/cores; every
     * other task requests the worker memory/cores.
     *
     * @param tasks
     *            a collection of tasks we want to ask containers for
     */
    private synchronized void submitTasks(Collection<TaskRecord> tasks) {
        for (TaskRecord r : tasks) {
            Resource resource = Records.newRecord(Resource.class);
            // compare by value: == on strings only works for interned literals
            if ("server".equals(r.taskRole)) {
                resource.setMemory(serverMemoryMB);
                resource.setVirtualCores(serverCores);
            } else {
                resource.setMemory(workerMemoryMB);
                resource.setVirtualCores(workerCores);
            }
            Priority priority = Records.newRecord(Priority.class);
            priority.setPriority(this.appPriority);
            r.containerRequest = new ContainerRequest(resource, null, null,
                    priority);
            rmClient.addContainerRequest(r.containerRequest);
            pendingTasks.add(r);
        }
    }



    /**
     * Start {@code ./launcher.py} on the given container.
     *
     * The launch context has stdout/stderr redirected into the container log
     * directory, the current user's tokens attached, and the worker resources
     * localized.
     *
     * NOTE(review): unlike the worker env map, no environment is set on this
     * launch context -- confirm that is intended.
     *
     * @param container the allocated container to launch on
     */
    private synchronized void launchDummyTask(Container container) {
        ContainerLaunchContext ctx = Records.newRecord(ContainerLaunchContext.class);
        String launchCommand = "./launcher.py";
        String cmd = launchCommand + " 1>"
            + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
            + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
            + "/stderr";
        ctx.setCommands(Collections.singletonList(cmd));
        ctx.setTokens(setupTokens());
        ctx.setLocalResources(this.workerResources);
        // the method already holds the monitor on this; no nested synchronized block needed
        this.nmClient.startContainerAsync(container, ctx);
    }
    /**
     * launch the task on container
     *
     * @param container
     *            container to run the task
     * @param task
     *            the task
     */



( run in 2.860 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )