Skip to content

Commit

Permalink
Implement GPU predict leaf.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 30, 2020
1 parent f0c6390 commit 69036af
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 36 deletions.
2 changes: 1 addition & 1 deletion include/xgboost/gbm.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class GradientBooster : public Model, public Configurable {
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void PredictLeaf(DMatrix* dmat,
std::vector<bst_float>* out_preds,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;

/*!
Expand Down
6 changes: 1 addition & 5 deletions include/xgboost/predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,6 @@ class Predictor {
unsigned ntree_limit = 0) = 0;

/**
* \fn virtual void Predictor::PredictLeaf(DMatrix* dmat,
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned
* ntree_limit = 0) = 0;
*
* \brief predict the leaf index of each tree, the output will be nsample *
* ntree vector this is only valid in gbtree predictor.
*
Expand All @@ -177,7 +173,7 @@ class Predictor {
* \param ntree_limit (Optional) The ntree limit.
*/

virtual void PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds,
virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) = 0;

Expand Down
2 changes: 1 addition & 1 deletion src/gbm/gblinear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class GBLinear : public GradientBooster {
}

void PredictLeaf(DMatrix *p_fmat,
std::vector<bst_float> *out_preds,
HostDeviceVector<bst_float> *out_preds,
unsigned ntree_limit) override {
LOG(FATAL) << "gblinear does not support prediction of leaf index";
}
Expand Down
5 changes: 2 additions & 3 deletions src/gbm/gbtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,9 @@ class GBTree : public GradientBooster {
}

void PredictLeaf(DMatrix* p_fmat,
std::vector<bst_float>* out_preds,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) override {
CHECK(configured_);
cpu_predictor_->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
}

void PredictContribution(DMatrix* p_fmat,
Expand Down
2 changes: 1 addition & 1 deletion src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ class LearnerImpl : public LearnerIO {
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
gbm_->PredictLeaf(data.get(), out_preds, ntree_limit);
} else {
auto local_cache = this->GetPredictionCache();
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
Expand Down
4 changes: 2 additions & 2 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class CPUPredictor : public Predictor {
}
}

void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
const int nthread = omp_get_max_threads();
InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
Expand All @@ -331,7 +331,7 @@ class CPUPredictor : public Predictor {
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size());
}
std::vector<bst_float>& preds = *out_preds;
std::vector<bst_float>& preds = out_preds->HostVector();
preds.resize(info.num_row_ * ntree_limit);
// start collecting the prediction
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
Expand Down
136 changes: 118 additions & 18 deletions src/predictor/gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ struct DeviceAdapterLoader {
};

template <typename Loader>
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
common::Span<FeatureType const> split_types,
common::Span<RegTree::Segment const> d_cat_ptrs,
common::Span<uint32_t const> d_categories,
Expand Down Expand Up @@ -202,6 +202,51 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
return tree[nidx].LeafValue();
}

template <typename Loader>
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
Loader const& loader) {
bst_node_t nidx = 0;
RegTree::Node n = tree[nidx];
while (!n.IsLeaf()) {
float fvalue = loader.GetElement(ridx, n.SplitIndex());
// Missing value
if (isnan(fvalue)) {
nidx = n.DefaultChild();
n = tree[nidx];
} else {
if (fvalue < n.SplitCond()) {
nidx = n.LeftChild();
n = tree[nidx];
} else {
nidx = n.RightChild();
n = tree[nidx];
}
}
}
return nidx;
}

template <typename Loader, typename Data>
__global__ void PredictLeafKernel(Data data,
common::Span<const RegTree::Node> d_nodes,
common::Span<float> d_out_predictions,
common::Span<size_t const> d_tree_segments,
common::Span<int32_t const> d_tree_group,
size_t tree_begin, size_t tree_end, size_t num_features,
size_t num_rows, size_t entry_start,
bool use_shared, int num_group) {
bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
if (ridx >= num_rows) {
return;
}
Loader loader(data, use_shared, num_features, num_rows, entry_start);
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
auto leaf = GetLeafIndex(ridx, d_tree, loader);
d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
}
}

template <typename Loader, typename Data>
__global__ void
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
Expand Down Expand Up @@ -438,6 +483,17 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
});
}

namespace {
template <size_t kBlockThreads>
size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
size_t shared_memory_bytes =
static_cast<size_t>(sizeof(float) * cols * kBlockThreads);
if (shared_memory_bytes > max_shared_memory_bytes) {
shared_memory_bytes = 0;
}
return shared_memory_bytes;
}
} // anonymous namespace

class GPUPredictor : public xgboost::Predictor {
private:
Expand All @@ -451,13 +507,10 @@ class GPUPredictor : public xgboost::Predictor {
size_t num_rows = batch.Size();
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));

auto shared_memory_bytes =
static_cast<size_t>(sizeof(float) * num_features * BLOCK_THREADS);
bool use_shared = true;
if (shared_memory_bytes > max_shared_memory_bytes_) {
shared_memory_bytes = 0;
use_shared = false;
}
size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes_);
bool use_shared = shared_memory_bytes != 0;

size_t entry_start = 0;
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
Expand Down Expand Up @@ -609,13 +662,9 @@ class GPUPredictor : public xgboost::Predictor {
const uint32_t BLOCK_THREADS = 128;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));

auto shared_memory_bytes =
static_cast<size_t>(sizeof(float) * m->NumColumns() * BLOCK_THREADS);
bool use_shared = true;
if (shared_memory_bytes > max_shared_memory_bytes) {
shared_memory_bytes = 0;
use_shared = false;
}
size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(info.num_col_, max_shared_memory_bytes_);
bool use_shared = shared_memory_bytes != 0;
size_t entry_start = 0;

dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
Expand Down Expand Up @@ -781,11 +830,62 @@ class GPUPredictor : public xgboost::Predictor {
<< " is not implemented in GPU Predictor.";
}

void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
const gbm::GBTreeModel& model,
unsigned ntree_limit) override {
LOG(FATAL) << "[Internal error]: " << __func__
<< " is not implemented in GPU Predictor.";
const MetaInfo& info = p_fmat->Info();
constexpr uint32_t kBlockThreads = 128;
size_t shared_memory_bytes =
SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes_);
bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_;
bst_row_t num_rows = info.num_row_;
size_t entry_start = 0;

uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
}
predictions->SetDevice(generic_param_->gpu_id);
predictions->Resize(num_rows * real_ntree_limit);
model_.Init(model, 0, real_ntree_limit, generic_param_->gpu_id);

if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
bst_row_t batch_offset = 0;
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature};
size_t num_rows = batch.Size();
auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
model_.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
entry_start, use_shared, model_.num_group);
batch_offset += batch.Size();
}
} else {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
bst_row_t batch_offset = 0;
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)};
size_t num_rows = batch.Size();
auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
model_.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
entry_start, use_shared, model_.num_group);
batch_offset += batch.Size();
}
}
}

void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
Expand Down
7 changes: 7 additions & 0 deletions tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,13 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
gen.GenerateDense(&out->Info().labels_);
}
}
if (device_ >= 0) {
out->Info().labels_.SetDevice(device_);
for (auto const& page : out->GetBatches<SparsePage>()) {
page.data.SetDevice(device_);
page.offset.SetDevice(device_);
}
}
return out;
}

Expand Down
12 changes: 7 additions & 5 deletions tests/cpp/predictor/test_cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ TEST(CpuPredictor, Basic) {
}

// Test predict leaf
std::vector<float> leaf_out_predictions;
HostDeviceVector<float> leaf_out_predictions;
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
for (auto v : leaf_out_predictions) {
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
for (auto v : h_leaf_out_predictions) {
ASSERT_EQ(v, 0);
}

Expand Down Expand Up @@ -108,10 +109,11 @@ TEST(CpuPredictor, ExternalMemory) {
}

// Test predict leaf
std::vector<float> leaf_out_predictions;
HostDeviceVector<float> leaf_out_predictions;
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
ASSERT_EQ(leaf_out_predictions.size(), dmat->Info().num_row_);
for (const auto& v : leaf_out_predictions) {
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
ASSERT_EQ(h_leaf_out_predictions.size(), dmat->Info().num_row_);
for (const auto& v : h_leaf_out_predictions) {
ASSERT_EQ(v, 0);
}

Expand Down
24 changes: 24 additions & 0 deletions tests/cpp/predictor/test_gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ TEST(GPUPredictor, ShapStump) {
EXPECT_EQ(phis[4], 0.0);
EXPECT_EQ(phis[5], param.base_score);
}

TEST(GPUPredictor, Shap) {
LearnerModelParam param;
param.num_feature = 1;
Expand Down Expand Up @@ -224,5 +225,28 @@ TEST(GPUPredictor, Shap) {
TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}

TEST(GPUPredictor, PredictLeafBasic) {
size_t constexpr kRows = 5, kCols = 5;
auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix();
auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
gpu_predictor->Configure({});

LearnerModelParam param;
param.num_feature = kCols;
param.base_score = 0.0;
param.num_output_group = 1;

gbm::GBTreeModel model = CreateTestModel(&param);

HostDeviceVector<float> leaf_out_predictions;
gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
for (auto v : h_leaf_out_predictions) {
ASSERT_EQ(v, 0);
}
}
} // namespace predictor
} // namespace xgboost
4 changes: 4 additions & 0 deletions tests/python-gpu/test_gpu_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
sys.path.append("tests/python")
import testing as tm
from test_predict import run_threaded_predict # noqa
from test_predict import run_predict_leaf # noqa

rng = np.random.RandomState(1994)

Expand Down Expand Up @@ -222,3 +223,6 @@ def test_shap_interactions(self, num_rounds, dataset, param):
assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin,
1e-3, 1e-3)

def test_predict_leaf(self):
run_predict_leaf('gpu_predictor')
38 changes: 38 additions & 0 deletions tests/python/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,44 @@ def run_threaded_predict(X, rows, predict_func):
assert f.result()


def run_predict_leaf(predictor):
rows = 100
cols = 4
classes = 5
num_parallel_tree = 4
num_boost_round = 10
rng = np.random.RandomState(1994)
X = rng.randn(rows, cols)
y = rng.randint(low=0, high=classes, size=rows)
m = xgb.DMatrix(X, y)
booster = xgb.train(
{'num_parallel_tree': num_parallel_tree, 'num_class': classes,
'predictor': predictor}, m,
num_boost_round=num_boost_round)

leaf = booster.predict(m, pred_leaf=True)
assert leaf.shape[0] == rows
assert leaf.shape[1] == classes * num_parallel_tree * num_boost_round

for i in range(rows):
row = leaf[i, ...]
for j in range(num_boost_round):
start = classes * num_parallel_tree * j
end = classes * num_parallel_tree * (j + 1)
layer = row[start: end]
for c in range(classes):
tree_group = layer[c * num_parallel_tree:
(c+1) * num_parallel_tree]
assert tree_group.shape[0] == num_parallel_tree
# no subsampling so tree in same forest should output same
# leaf.
assert np.all(tree_group == tree_group[0])


def test_predict_leaf():
run_predict_leaf('cpu_predictor')


class TestInplacePredict(unittest.TestCase):
'''Tests for running inplace prediction'''
def test_predict(self):
Expand Down

0 comments on commit 69036af

Please sign in to comment.