Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/src/tree/updater_fast_hist.cc  view on Meta::CPAN

      }

      while (!qexpand_->empty()) {
        const ExpandEntry candidate = qexpand_->top();
        const int nid = candidate.nid;
        qexpand_->pop();
        if (candidate.loss_chg <= rt_eps
            || (param.max_depth > 0 && candidate.depth == param.max_depth)
            || (param.max_leaves > 0 && num_leaves == param.max_leaves) ) {
          (*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
        } else {
          tstart = dmlc::GetTime();
          this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
          time_apply_split += dmlc::GetTime() - tstart;

          tstart = dmlc::GetTime();
          const int cleft = (*p_tree)[nid].cleft();
          const int cright = (*p_tree)[nid].cright();
          hist_.AddHistRow(cleft);
          hist_.AddHistRow(cright);
          if (row_set_collection_[cleft].size() < row_set_collection_[cright].size()) {
            BuildHist(gpair, row_set_collection_[cleft], gmat, gmatb, feat_set, hist_[cleft]);
            SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
          } else {
            BuildHist(gpair, row_set_collection_[cright], gmat, gmatb, feat_set, hist_[cright]);
            SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
          }
          time_build_hist += dmlc::GetTime() - tstart;

          tstart = dmlc::GetTime();
          this->InitNewNode(cleft, gmat, gpair, *p_fmat, *p_tree);
          this->InitNewNode(cright, gmat, gpair, *p_fmat, *p_tree);
          time_init_new_node += dmlc::GetTime() - tstart;

          tstart = dmlc::GetTime();
          this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree, feat_set);
          this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree, feat_set);
          time_evaluate_split += dmlc::GetTime() - tstart;

          qexpand_->push(ExpandEntry(cleft, p_tree->GetDepth(cleft),
                                     snode[cleft].best.loss_chg,
                                     timestamp++));
          qexpand_->push(ExpandEntry(cright, p_tree->GetDepth(cright),
                                     snode[cright].best.loss_chg,
                                     timestamp++));

          ++num_leaves;  // give two and take one, as parent is no longer a leaf
        }
      }

      // Set all remaining expanding nodes to leaves.
      // This post-condition is not needed in the current code, but may become
      // necessary when a stopping rule leaves qexpand_ non-empty.
      while (!qexpand_->empty()) {
        const int nid = qexpand_->top().nid;
        qexpand_->pop();
        (*p_tree)[nid].set_leaf(snode[nid].weight * param.learning_rate);
      }
      // remember auxiliary statistics in the tree node
      for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) {
        p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
        p_tree->stat(nid).base_weight = snode[nid].weight;
        p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
        snode[nid].stats.SetLeafVec(param, p_tree->leafvec(nid));
      }

      pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});

      if (param.debug_verbose > 0) {
        double total_time = dmlc::GetTime() - gstart;
        LOG(INFO) << "\nInitData:          "
                  << std::fixed << std::setw(6) << std::setprecision(4) << time_init_data
                  << " (" << std::fixed << std::setw(5) << std::setprecision(2)
                  << time_init_data / total_time * 100 << "%)\n"
                  << "InitNewNode:       "
                  << std::fixed << std::setw(6) << std::setprecision(4) << time_init_new_node
                  << " (" << std::fixed << std::setw(5) << std::setprecision(2)
                  << time_init_new_node / total_time * 100 << "%)\n"
                  << "BuildHist:         "
                  << std::fixed << std::setw(6) << std::setprecision(4) << time_build_hist
                  << " (" << std::fixed << std::setw(5) << std::setprecision(2)
                  << time_build_hist / total_time * 100 << "%)\n"
                  << "EvaluateSplit:     "
                  << std::fixed << std::setw(6) << std::setprecision(4) << time_evaluate_split
                  << " (" << std::fixed << std::setw(5) << std::setprecision(2)
                  << time_evaluate_split / total_time * 100 << "%)\n"
                  << "ApplySplit:        "
                  << std::fixed << std::setw(6) << std::setprecision(4) << time_apply_split
                  << " (" << std::fixed << std::setw(5) << std::setprecision(2)
                  << time_apply_split / total_time * 100 << "%)\n"
                  << "========================================\n"
                  << "Total:             "
                  << std::fixed << std::setw(6) << std::setprecision(4) << total_time;
      }
    }

    // Builds a histogram for the given rows by forwarding to hist_builder_.
    //
    // When fhparam.enable_feature_grouping is positive the block matrix
    // (gmatb) is handed to BuildBlockHist(); otherwise the plain index matrix
    // (gmat) is handed to BuildHist(). All remaining arguments are passed
    // through unchanged in both cases.
    //
    //   gpair        gradient pairs forwarded to the builder
    //   row_indices  row set forwarded to the builder
    //   gmat         index matrix, used only when feature grouping is off
    //   gmatb        block index matrix, used only when feature grouping is on
    //   feat_set     feature set forwarded to the builder
    //   hist         histogram row forwarded to the builder
    //                (presumably the output buffer -- confirm in the builder)
    inline void BuildHist(const std::vector<bst_gpair>& gpair,
                          const RowSetCollection::Elem row_indices,
                          const GHistIndexMatrix& gmat,
                          const GHistIndexBlockMatrix& gmatb,
                          const std::vector<bst_uint>& feat_set,
                          GHistRow hist) {
      const bool use_feature_groups = fhparam.enable_feature_grouping > 0;
      if (!use_feature_groups) {
        // Ungrouped path: work directly on the index matrix.
        hist_builder_.BuildHist(gpair, row_indices, gmat, feat_set, hist);
        return;
      }
      // Grouped path: work on the block matrix.
      hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, feat_set, hist);
    }

    // Forwards the three histogram rows, in the order (self, sibling, parent),
    // to hist_builder_.SubtractionTrick().
    // NOTE(review): presumably fills `self` as `parent` minus `sibling`;
    // confirm against the histogram builder implementation.
    inline void SubtractionTrick(GHistRow self,
                                 GHistRow sibling,
                                 GHistRow parent) {
      this->hist_builder_.SubtractionTrick(self, sibling, parent);
    }

    inline bool UpdatePredictionCache(const DMatrix* data,
                                      std::vector<bst_float>* p_out_preds) {
      std::vector<bst_float>& out_preds = *p_out_preds;

      // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
      // conjunction with Update().
      if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
        return false;
      }



( run in 0.642 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )