Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/src/common/hist_util.cc view on Meta::CPAN
}
}
}
} else {
for (size_t i = 0; i < column.len; ++i) {
if (mark[column.row_ind[i]]) {
++ret;
if (ret > max_cnt) {
return max_cnt + 1;
}
}
}
}
return ret;
}
// Sets the mark bit of every row in which `column` holds an entry.
//   p_mark : per-row flags, written through the pointer; bits are only ever
//            set to true here, never cleared.
//   column : column whose occupied rows are recorded.
// No bounds checking is performed on the mark vector.
template <typename T>
inline void
MarkUsed(std::vector<bool>* p_mark, const Column<T>& column) {
  std::vector<bool>& used_rows = *p_mark;

  if (column.type != xgboost::common::kDenseColumn) {
    // non-dense layout: row_ind[k] supplies the row id of the k-th entry
    for (size_t k = 0; k < column.len; ++k) {
      used_rows[column.row_ind[k]] = true;
    }
    return;
  }

  // dense layout: position k is used directly as the row id; slots equal to
  // numeric_limits<T>::max() are skipped (presumably the "missing" sentinel)
  const T skip_value = std::numeric_limits<T>::max();
  for (size_t k = 0; k < column.len; ++k) {
    if (column.index[k] == skip_value) {
      continue;
    }
    used_rows[k] = true;
  }
}
// Greedily assigns every feature in `feature_list` to exactly one group and
// returns the groups (each a list of feature ids).
//   feature_list : feature ids, processed in the given order
//   feature_nnz  : entry count per feature, indexed by feature id
//   colmat       : column storage; GetColumn<T>(fid) yields the column to test
//   nrow         : number of rows; sizes the per-group mark vectors
//   param        : supplies max_conflict_rate and max_search_group
template <typename T>
inline std::vector<std::vector<unsigned>>
FindGroups_(const std::vector<unsigned>& feature_list,
            const std::vector<size_t>& feature_nnz,
            const ColumnMatrix& colmat,
            size_t nrow,
            const FastHistParam& param) {
  /* Goal: bundle together features that have little or no "overlap", i.e.
           only a few data points should have nonzero values for more than
           one member feature.
     Note that one-hot encoded features will be grouped together. */
  std::vector<std::vector<unsigned>> bundles;
  std::vector<std::vector<bool>> row_marks;    // per bundle: rows marked by its members
  std::vector<size_t> bundle_nnz;              // per bundle: entries counted so far
  std::vector<size_t> bundle_conflicts;        // per bundle: conflicts accepted so far

  // total number of conflicts a single bundle may accumulate
  const size_t conflict_budget
    = static_cast<size_t>(param.max_conflict_rate * nrow);
  // a bundle is only worth testing if adding the feature stays under this
  const size_t nnz_ceiling = nrow + conflict_budget;

  for (const auto fid : feature_list) {
    const Column<T>& column = colmat.GetColumn<T>(fid);
    const size_t fid_nnz = feature_nnz[fid];

    // collect the bundles with enough room, then visit them in random order
    std::vector<unsigned> candidates;
    for (size_t g = 0; g < bundles.size(); ++g) {
      if (bundle_nnz[g] + fid_nnz > nnz_ceiling) {
        continue;
      }
      candidates.push_back(static_cast<unsigned>(g));
    }
    std::shuffle(candidates.begin(), candidates.end(), common::GlobalRandom());
    // a positive max_search_group caps how many candidates get examined
    if (param.max_search_group > 0 && candidates.size() > param.max_search_group) {
      candidates.resize(param.max_search_group);
    }

    // take the first candidate whose remaining conflict budget covers fid
    bool placed = false;
    for (const auto g : candidates) {
      const size_t budget_left = conflict_budget - bundle_conflicts[g];
      const size_t clashes = GetConflictCount(row_marks[g], column, budget_left);
      if (clashes > budget_left) {
        continue;
      }
      bundles[g].push_back(fid);
      bundle_conflicts[g] += clashes;
      bundle_nnz[g] += fid_nnz - clashes;  // conflicting entries are not counted again
      MarkUsed(&row_marks[g], column);
      placed = true;
      break;
    }

    // nothing accepted fid: open a fresh bundle containing only fid
    if (!placed) {
      bundles.push_back({fid});
      bundle_conflicts.push_back(0);
      bundle_nnz.push_back(fid_nnz);
      row_marks.emplace_back(nrow, false);
      MarkUsed(&row_marks.back(), column);
    }
  }

  return bundles;
}
// Non-template entry point for feature grouping: forwards all arguments
// unchanged to FindGroups_<DType> from inside XGBOOST_TYPE_SWITCH, which is
// keyed on colmat.dtype, and returns that call's result.
// NOTE(review): XGBOOST_TYPE_SWITCH is defined outside this view; presumably
// it binds DType to the element type matching colmat.dtype -- confirm.
inline std::vector<std::vector<unsigned>>
FindGroups(const std::vector<unsigned>& feature_list,
const std::vector<size_t>& feature_nnz,
const ColumnMatrix& colmat,
size_t nrow,
const FastHistParam& param) {
XGBOOST_TYPE_SWITCH(colmat.dtype, {
return FindGroups_<DType>(feature_list, feature_nnz, colmat, nrow, param);
});
// Fallback empty result after the switch; present only so the compiler
// does not warn about a missing return.
return std::vector<std::vector<unsigned>>(); // to avoid warning message
}
inline std::vector<std::vector<unsigned>>
FastFeatureGrouping(const GHistIndexMatrix& gmat,
const ColumnMatrix& colmat,
const FastHistParam& param) {
const size_t nrow = gmat.row_ptr.size() - 1;
const size_t nfeature = gmat.cut->row_ptr.size() - 1;
std::vector<unsigned> feature_list(nfeature);
std::iota(feature_list.begin(), feature_list.end(), 0);
// sort features by nonzero counts, descending order
std::vector<size_t> feature_nnz(nfeature);
std::vector<unsigned> features_by_nnz(feature_list);
gmat.GetFeatureCounts(&feature_nnz[0]);
std::sort(features_by_nnz.begin(), features_by_nnz.end(),
[&feature_nnz](unsigned a, unsigned b) {
return feature_nnz[a] > feature_nnz[b];
});
auto groups_alt1 = FindGroups(feature_list, feature_nnz, colmat, nrow, param);
auto groups_alt2 = FindGroups(features_by_nnz, feature_nnz, colmat, nrow, param);
auto& groups = (groups_alt1.size() > groups_alt2.size()) ? groups_alt2 : groups_alt1;
// take apart small, sparse groups, as they won't help speed
{
std::vector<std::vector<unsigned>> ret;
for (const auto& group : groups) {
if (group.size() <= 1 || group.size() >= 5) {
ret.push_back(group); // keep singleton groups and large (5+) groups
} else {
( run in 2.205 seconds using v1.01-cache-2.11-cpan-ceb78f64989 )