Alien-XGBoost
view release on metacpan or search on metacpan
xgboost/plugin/updater_gpu/src/exact/split2node.cuh view on Meta::CPAN
/*
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../../../../src/tree/param.h"
#include "node.cuh"
namespace xgboost {
namespace tree {
namespace exact {
/**
 * @brief Fill in a single child node from the gradient sum routed to it.
 *
 * Stores the gradient sum, derives the node's gain score and weight from it
 * via CalcGain / CalcWeight, and records the node's own index.
 * @param nodes the nodes array in which the position at 'nid' will be updated
 * @param nid the nodeId in the 'nodes' array corresponding to this child node
 * @param grad gradient sum for this child node
 * @param param the training parameter struct, forwarded to CalcGain and
 * CalcWeight (the regularization settings are read from here rather than
 * passed individually)
 */
template <typename node_id_t>
DEV_INLINE void updateOneChildNode(Node<node_id_t>* nodes, int nid,
                                   const bst_gpair& grad,
                                   const TrainParam& param) {
  Node<node_id_t>& child = nodes[nid];
  // Raw gradient/hessian totals first; score and weight are derived from them.
  child.gradSum = grad;
  child.score = CalcGain(param, grad.grad, grad.hess);
  child.weight = CalcWeight(param, grad.grad, grad.hess);
  // Each node records its own index within 'nodes'.
  child.id = nid;
}
/**
 * @brief Populate both children of node 'pid' from their gradient sums.
 *
 * Children use the implicit binary-heap layout: the left child of 'pid' is at
 * index 2 * pid + 1 and the right child at 2 * pid + 2.
 * @param nodes the nodes array whose two child slots of 'pid' are overwritten
 * @param pid the nodeId of the parent
 * @param gradL gradient sum for the left child node
 * @param gradR gradient sum for the right child node
 * @param param the training parameter struct
 */
template <typename node_id_t>
DEV_INLINE void updateChildNodes(Node<node_id_t>* nodes, int pid,
                                 const bst_gpair& gradL, const bst_gpair& gradR,
                                 const TrainParam& param) {
  const int leftId = 2 * pid + 1;
  const int rightId = leftId + 1;
  updateOneChildNode(nodes, leftId, gradL, param);
  updateOneChildNode(nodes, rightId, gradR, param);
}
/**
 * @brief Apply a chosen split to node 'absNodeId'.
 *
 * Picks the default direction for missing values and fills in both child
 * nodes from their gradient sums. It then records the direction, feature
 * column and threshold on the parent node.
 * @param nodes the nodes array; the parent at 'absNodeId' and its children at
 * 2 * absNodeId + 1 and 2 * absNodeId + 2 are written
 * @param s the selected split (not read by this function)
 * @param n current state of the parent node. It may alias nodes[absNodeId]
 * (split2nodeKernel passes exactly that), so every read of 'n' happens
 * before the parent slot is written.
 * @param absNodeId index of the parent node in 'nodes'
 * @param colId feature (column) id stored on the parent as the split column
 * @param gradScan gradient sum assigned to the left child when missing values
 * go right
 * @param colSum gradient sum subtracted from n.gradSum to obtain the gradient
 * treated as "missing" for this column
 * @param thresh split threshold stored on the parent node
 * @param param the training parameter struct
 */
template <typename node_id_t>
DEV_INLINE void updateNodeAndChildren(Node<node_id_t>* nodes, const Split& s,
                                      const Node<node_id_t>& n, int absNodeId,
                                      int colId, const bst_gpair& gradScan,
                                      const bst_gpair& colSum, float thresh,
                                      const TrainParam& param) {
  // Gradient mass attributed to instances with no value in this column.
  bst_gpair missing = n.gradSum - colSum;
  // loss_chg_missing decides the default direction for missing values; it
  // updates missingLeft in place (presumably a reference out-parameter).
  bool missingLeft = true;
  loss_chg_missing(gradScan, missing, n.gradSum, n.score, param, missingLeft);
  // Reuse 'missing' rather than re-deriving gradScan + n.gradSum - colSum.
  // That expression evaluates as (gradScan + n.gradSum) - colSum, a different
  // floating-point association from the gradScan/missing pair handed to
  // loss_chg_missing above, and it repeated work already done.
  const bst_gpair lGradSum = missingLeft ? gradScan + missing : gradScan;
  const bst_gpair rGradSum = n.gradSum - lGradSum;
  updateChildNodes(nodes, absNodeId, lGradSum, rGradSum, param);
  // Update default-dir, threshold and feature id for the current node. This
  // comes last because 'n' may alias nodes[absNodeId].
  nodes[absNodeId].dir = missingLeft ? LeftDir : RightDir;
  nodes[absNodeId].colIdx = colId;
  nodes[absNodeId].threshold = thresh;
}
template <typename node_id_t, int BLKDIM = 256>
__global__ void split2nodeKernel(
Node<node_id_t>* nodes, const Split* nodeSplits, const bst_gpair* gradScans,
const bst_gpair* gradSums, const float* vals, const int* colIds,
const int* colOffsets, const node_id_t* nodeAssigns, int nUniqKeys,
node_id_t nodeStart, int nCols, const TrainParam param) {
int uid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (uid >= nUniqKeys) {
return;
}
int absNodeId = uid + nodeStart;
Split s = nodeSplits[uid];
if (s.isSplittable(param.min_split_loss)) {
int idx = s.index;
int nodeInstId =
abs2uniqKey(idx, nodeAssigns, colIds, nodeStart, nUniqKeys);
updateNodeAndChildren(nodes, s, nodes[absNodeId], absNodeId, colIds[idx],
gradScans[idx], gradSums[nodeInstId], vals[idx],
param);
} else {
// cannot be split further, so this node is a leaf!
( run in 0.568 second using v1.01-cache-2.11-cpan-39bf76dae61 )