From 2b24411d7344053c10c6434bdd1015a1af4596e2 Mon Sep 17 00:00:00 2001 From: fis Date: Sun, 5 Jul 2020 02:18:12 +0800 Subject: [PATCH 1/5] Implement GPU predict leaf. --- include/xgboost/gbm.h | 2 +- include/xgboost/predictor.h | 6 +- src/gbm/gblinear.cc | 4 +- src/gbm/gbtree.h | 5 +- src/learner.cc | 2 +- src/predictor/cpu_predictor.cc | 4 +- src/predictor/gpu_predictor.cu | 140 ++++++++++++++++++---- tests/cpp/helpers.cc | 7 ++ tests/cpp/predictor/test_cpu_predictor.cc | 12 +- tests/cpp/predictor/test_gpu_predictor.cu | 24 ++++ tests/python-gpu/test_gpu_prediction.py | 38 ++++++ tests/python/test_predict.py | 43 +++++++ 12 files changed, 246 insertions(+), 41 deletions(-) diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 20b2fbf11218..389593d7f098 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -144,7 +144,7 @@ class GradientBooster : public Model, public Configurable { * we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear */ virtual void PredictLeaf(DMatrix* dmat, - std::vector* out_preds, + HostDeviceVector* out_preds, unsigned ntree_limit = 0) = 0; /*! diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 9e448c2f1645..42a5275e1dbb 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -164,10 +164,6 @@ class Predictor { unsigned ntree_limit = 0) = 0; /** - * \fn virtual void Predictor::PredictLeaf(DMatrix* dmat, - * std::vector* out_preds, const gbm::GBTreeModel& model, unsigned - * ntree_limit = 0) = 0; - * * \brief predict the leaf index of each tree, the output will be nsample * * ntree vector this is only valid in gbtree predictor. * @@ -177,7 +173,7 @@ class Predictor { * \param ntree_limit (Optional) The ntree limit. 
*/ - virtual void PredictLeaf(DMatrix* dmat, std::vector* out_preds, + virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0; diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 81de5bea6a5e..b6ba17269c23 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -147,9 +147,7 @@ class GBLinear : public GradientBooster { } } - void PredictLeaf(DMatrix*, - std::vector*, - unsigned) override { + void PredictLeaf(DMatrix *, HostDeviceVector *, unsigned) override { LOG(FATAL) << "gblinear does not support prediction of leaf index"; } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index b2a990dbe304..d67f94c2c75b 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -278,10 +278,9 @@ class GBTree : public GradientBooster { } void PredictLeaf(DMatrix* p_fmat, - std::vector* out_preds, + HostDeviceVector* out_preds, unsigned ntree_limit) override { - CHECK(configured_); - cpu_predictor_->PredictLeaf(p_fmat, out_preds, model_, ntree_limit); + this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit); } void PredictContribution(DMatrix* p_fmat, diff --git a/src/learner.cc b/src/learner.cc index 19f06f270209..4e75dd7ea94b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1106,7 +1106,7 @@ class LearnerImpl : public LearnerIO { gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit, approx_contribs); } else if (pred_leaf) { - gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit); + gbm_->PredictLeaf(data.get(), out_preds, ntree_limit); } else { auto local_cache = this->GetPredictionCache(); auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 7da841cc15f4..69fa3c53fa6a 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -345,7 +345,7 @@ class CPUPredictor : public Predictor { } } - void 
PredictLeaf(DMatrix* p_fmat, std::vector* out_preds, + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit) override { const int nthread = omp_get_max_threads(); InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_); @@ -355,7 +355,7 @@ class CPUPredictor : public Predictor { if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } - std::vector& preds = *out_preds; + std::vector& preds = out_preds->HostVector(); preds.resize(info.num_row_ * ntree_limit); // start collecting the prediction for (const auto &batch : p_fmat->GetBatches()) { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 0431f70bde1b..30f9a81827de 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -78,7 +78,7 @@ struct SparsePageLoader { __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features, bst_row_t num_rows, size_t entry_start) : use_shared(use_shared), - data(data), + data(data), entry_start(entry_start) { extern __shared__ float _smem[]; smem = _smem; @@ -169,7 +169,7 @@ struct DeviceAdapterLoader { }; template -__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, +__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree, common::Span split_types, common::Span d_cat_ptrs, common::Span d_categories, @@ -201,6 +201,49 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, return tree[nidx].LeafValue(); } +template +__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree, + Loader const& loader) { + bst_node_t nidx = 0; + RegTree::Node n = tree[nidx]; + while (!n.IsLeaf()) { + float fvalue = loader.GetElement(ridx, n.SplitIndex()); + // Missing value + if (isnan(fvalue)) { + nidx = n.DefaultChild(); + n = tree[nidx]; + } else { + if (fvalue < n.SplitCond()) { + nidx = 
n.LeftChild(); + n = tree[nidx]; + } else { + nidx = n.RightChild(); + n = tree[nidx]; + } + } + } + return nidx; +} + +template +__global__ void PredictLeafKernel(Data data, + common::Span d_nodes, + common::Span d_out_predictions, + common::Span d_tree_segments, + size_t tree_begin, size_t tree_end, size_t num_features, + size_t num_rows, size_t entry_start, bool use_shared) { + bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x; + if (ridx >= num_rows) { + return; + } + Loader loader(data, use_shared, num_features, num_rows, entry_start); + for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; + auto leaf = GetLeafIndex(ridx, d_tree, loader); + d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf; + } +} + template __global__ void PredictKernel(Data data, common::Span d_nodes, @@ -437,6 +480,17 @@ void ExtractPaths(dh::device_vector* paths, }); } +namespace { +template +size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) { + size_t shared_memory_bytes = + static_cast(sizeof(float) * cols * kBlockThreads); + if (shared_memory_bytes > max_shared_memory_bytes) { + shared_memory_bytes = 0; + } + return shared_memory_bytes; +} +} // anonymous namespace class GPUPredictor : public xgboost::Predictor { private: @@ -450,13 +504,10 @@ class GPUPredictor : public xgboost::Predictor { size_t num_rows = batch.Size(); auto GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); - auto shared_memory_bytes = - static_cast(sizeof(float) * num_features * BLOCK_THREADS); - bool use_shared = true; - if (shared_memory_bytes > max_shared_memory_bytes_) { - shared_memory_bytes = 0; - use_shared = false; - } + size_t shared_memory_bytes = + SharedMemoryBytes(num_features, max_shared_memory_bytes_); + bool use_shared = shared_memory_bytes != 0; + size_t entry_start = 0; SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), 
num_features); @@ -608,13 +659,9 @@ class GPUPredictor : public xgboost::Predictor { const uint32_t BLOCK_THREADS = 128; auto GRID_SIZE = static_cast(common::DivRoundUp(info.num_row_, BLOCK_THREADS)); - auto shared_memory_bytes = - static_cast(sizeof(float) * m->NumColumns() * BLOCK_THREADS); - bool use_shared = true; - if (shared_memory_bytes > max_shared_memory_bytes) { - shared_memory_bytes = 0; - use_shared = false; - } + size_t shared_memory_bytes = + SharedMemoryBytes(info.num_col_, max_shared_memory_bytes_); + bool use_shared = shared_memory_bytes != 0; size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( @@ -780,11 +827,62 @@ class GPUPredictor : public xgboost::Predictor { << " is not implemented in GPU Predictor."; } - void PredictLeaf(DMatrix*, std::vector*, - const gbm::GBTreeModel&, - unsigned) override { - LOG(FATAL) << "[Internal error]: " << __func__ - << " is not implemented in GPU Predictor."; + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* predictions, + const gbm::GBTreeModel& model, + unsigned ntree_limit) override { + const MetaInfo& info = p_fmat->Info(); + constexpr uint32_t kBlockThreads = 128; + size_t shared_memory_bytes = + SharedMemoryBytes(info.num_col_, max_shared_memory_bytes_); + bool use_shared = shared_memory_bytes != 0; + bst_feature_t num_features = info.num_col_; + bst_row_t num_rows = info.num_row_; + size_t entry_start = 0; + + uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group; + if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { + real_ntree_limit = static_cast(model.trees.size()); + } + predictions->SetDevice(generic_param_->gpu_id); + predictions->Resize(num_rows * real_ntree_limit); + model_.Init(model, 0, real_ntree_limit, generic_param_->gpu_id); + + if (p_fmat->PageExists()) { + for (auto const& batch : p_fmat->GetBatches()) { + batch.data.SetDevice(generic_param_->gpu_id); + 
batch.offset.SetDevice(generic_param_->gpu_id); + bst_row_t batch_offset = 0; + SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(), + model.learner_model_param->num_feature}; + size_t num_rows = batch.Size(); + auto grid = + static_cast(common::DivRoundUp(num_rows, kBlockThreads)); + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} ( + PredictLeafKernel, data, + model_.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), + model_.tree_beg_, model_.tree_end_, num_features, num_rows, + entry_start, use_shared); + batch_offset += batch.Size(); + } + } else { + for (auto const& batch : p_fmat->GetBatches()) { + bst_row_t batch_offset = 0; + EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)}; + size_t num_rows = batch.Size(); + auto grid = + static_cast(common::DivRoundUp(num_rows, kBlockThreads)); + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} ( + PredictLeafKernel, data, + model_.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), + model_.tree_beg_, model_.tree_end_, num_features, num_rows, + entry_start, use_shared); + batch_offset += batch.Size(); + } + } } void Configure(const std::vector>& cfg) override { diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 1b319a8873ff..9191dea61eae 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -348,6 +348,13 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label, gen.GenerateDense(&out->Info().labels_); } } + if (device_ >= 0) { + out->Info().labels_.SetDevice(device_); + for (auto const& page : out->GetBatches()) { + page.data.SetDevice(device_); + page.offset.SetDevice(device_); + } + } return out; } diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 4191682eb763..63242f68d636 100644 --- 
a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -46,9 +46,10 @@ TEST(CpuPredictor, Basic) { } // Test predict leaf - std::vector leaf_out_predictions; + HostDeviceVector leaf_out_predictions; cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); - for (auto v : leaf_out_predictions) { + auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); + for (auto v : h_leaf_out_predictions) { ASSERT_EQ(v, 0); } @@ -108,10 +109,11 @@ TEST(CpuPredictor, ExternalMemory) { } // Test predict leaf - std::vector leaf_out_predictions; + HostDeviceVector leaf_out_predictions; cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); - ASSERT_EQ(leaf_out_predictions.size(), dmat->Info().num_row_); - for (const auto& v : leaf_out_predictions) { + auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); + ASSERT_EQ(h_leaf_out_predictions.size(), dmat->Info().num_row_); + for (const auto& v : h_leaf_out_predictions) { ASSERT_EQ(v, 0); } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index b48e490864b1..b0d9aff5bfe2 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -190,6 +190,7 @@ TEST(GPUPredictor, ShapStump) { EXPECT_EQ(phis[4], 0.0); EXPECT_EQ(phis[5], param.base_score); } + TEST(GPUPredictor, Shap) { LearnerModelParam param; param.num_feature = 1; @@ -224,5 +225,28 @@ TEST(GPUPredictor, Shap) { TEST(GPUPredictor, CategoricalPrediction) { TestCategoricalPrediction("gpu_predictor"); } + +TEST(GPUPredictor, PredictLeafBasic) { + size_t constexpr kRows = 5, kCols = 5; + auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix(); + auto lparam = CreateEmptyGenericParam(GPUIDX); + std::unique_ptr gpu_predictor = + std::unique_ptr(Predictor::Create("gpu_predictor", &lparam)); + gpu_predictor->Configure({}); + + LearnerModelParam param; + 
param.num_feature = kCols; + param.base_score = 0.0; + param.num_output_group = 1; + + gbm::GBTreeModel model = CreateTestModel(&param); + + HostDeviceVector leaf_out_predictions; + gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); + auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); + for (auto v : h_leaf_out_predictions) { + ASSERT_EQ(v, 0); + } +} } // namespace predictor } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py index be7c1fdbe4c9..8cbf5e04f52d 100644 --- a/tests/python-gpu/test_gpu_prediction.py +++ b/tests/python-gpu/test_gpu_prediction.py @@ -1,4 +1,5 @@ import sys +import json import unittest import pytest @@ -9,6 +10,7 @@ sys.path.append("tests/python") import testing as tm from test_predict import run_threaded_predict # noqa +from test_predict import run_predict_leaf # noqa rng = np.random.RandomState(1994) @@ -18,6 +20,11 @@ 'num_parallel_tree': strategies.sampled_from([1, 10]), }).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0) +predict_parameter_strategy = strategies.fixed_dictionaries({ + 'max_depth': strategies.integers(1, 8), + 'num_parallel_tree': strategies.sampled_from([1, 4]), +}) + class TestGPUPredict(unittest.TestCase): def test_predict(self): @@ -223,3 +230,34 @@ def test_shap_interactions(self, num_rounds, dataset, param): assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin, 1e-3, 1e-3) + + def test_predict_leaf_basic(self): + gpu_leaf = run_predict_leaf('gpu_predictor') + cpu_leaf = run_predict_leaf('cpu_predictor') + np.testing.assert_equal(gpu_leaf, cpu_leaf) + + def run_predict_leaf_booster(self, param, num_rounds, dataset): + param = dataset.set_params(param) + m = dataset.get_dmat() + booster = xgb.train(param, dtrain=dataset.get_dmat(), num_boost_round=num_rounds) + booster.set_param({'predictor': 'cpu_predictor'}) + cpu_leaf = booster.predict(m, pred_leaf=True) + + 
booster.set_param({'predictor': 'gpu_predictor'}) + gpu_leaf = booster.predict(m, pred_leaf=True) + + np.testing.assert_equal(cpu_leaf, gpu_leaf) + + @given(predict_parameter_strategy, tm.dataset_strategy) + @settings(deadline=None) + def test_predict_leaf_gbtree(self, param, dataset): + param['booster'] = 'gbtree' + param['tree_method'] = 'gpu_hist' + self.run_predict_leaf_booster(param, 10, dataset) + + @given(predict_parameter_strategy, tm.dataset_strategy) + @settings(deadline=None) + def test_predict_leaf_dart(self, param, dataset): + param['booster'] = 'dart' + param['tree_method'] = 'gpu_hist' + self.run_predict_leaf_booster(param, 10, dataset) diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index 1fff11d9e150..ddde67022dea 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -23,6 +23,49 @@ def run_threaded_predict(X, rows, predict_func): assert f.result() +def run_predict_leaf(predictor): + rows = 100 + cols = 4 + classes = 5 + num_parallel_tree = 4 + num_boost_round = 10 + rng = np.random.RandomState(1994) + X = rng.randn(rows, cols) + y = rng.randint(low=0, high=classes, size=rows) + m = xgb.DMatrix(X, y) + booster = xgb.train( + {'num_parallel_tree': num_parallel_tree, 'num_class': classes, + 'predictor': predictor, 'tree_method': 'hist'}, m, + num_boost_round=num_boost_round) + + empty = xgb.DMatrix(np.ones(shape=(0, cols))) + empty_leaf = booster.predict(empty, pred_leaf=True) + assert empty_leaf.shape[0] == 0 + + leaf = booster.predict(m, pred_leaf=True) + assert leaf.shape[0] == rows + assert leaf.shape[1] == classes * num_parallel_tree * num_boost_round + + for i in range(rows): + row = leaf[i, ...] 
+ for j in range(num_boost_round): + start = classes * num_parallel_tree * j + end = classes * num_parallel_tree * (j + 1) + layer = row[start: end] + for c in range(classes): + tree_group = layer[c * num_parallel_tree: + (c+1) * num_parallel_tree] + assert tree_group.shape[0] == num_parallel_tree + # no subsampling so tree in same forest should output same + # leaf. + assert np.all(tree_group == tree_group[0]) + return leaf + + +def test_predict_leaf(): + run_predict_leaf('cpu_predictor') + + class TestInplacePredict(unittest.TestCase): '''Tests for running inplace prediction''' def test_predict(self): From c96e97b2a748dde2804d62bf9bbb74649363aad1 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 6 Nov 2020 15:21:35 +0800 Subject: [PATCH 2/5] Set device. --- src/predictor/gpu_predictor.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 30f9a81827de..5956e5e8d264 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -830,6 +830,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* predictions, const gbm::GBTreeModel& model, unsigned ntree_limit) override { + dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); const MetaInfo& info = p_fmat->Info(); constexpr uint32_t kBlockThreads = 128; size_t shared_memory_bytes = From ae83622d5b150ee9ce8d0b1649c3b1cec4cbe962 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 11 Nov 2020 02:59:04 +0800 Subject: [PATCH 3/5] Configure device. 
--- src/predictor/gpu_predictor.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 5956e5e8d264..09c3d8ef8a20 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -831,6 +831,8 @@ class GPUPredictor : public xgboost::Predictor { const gbm::GBTreeModel& model, unsigned ntree_limit) override { dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); + ConfigureDevice(generic_param_->gpu_id); + const MetaInfo& info = p_fmat->Info(); constexpr uint32_t kBlockThreads = 128; size_t shared_memory_bytes = From fb8e414d7b3a4277ef69171c26059c1a9fe27ba4 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 11 Nov 2020 03:16:15 +0800 Subject: [PATCH 4/5] Correct initialization. --- src/predictor/gpu_predictor.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 09c3d8ef8a20..9b18b3763cf6 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -483,6 +483,8 @@ void ExtractPaths(dh::device_vector* paths, namespace { template size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) { + // No way max_shared_memory_bytes that is equal to 0. + CHECK_GT(max_shared_memory_bytes, 0); size_t shared_memory_bytes = static_cast(sizeof(float) * cols * kBlockThreads); if (shared_memory_bytes > max_shared_memory_bytes) { @@ -902,7 +904,7 @@ class GPUPredictor : public xgboost::Predictor { std::mutex lock_; DeviceModel model_; - size_t max_shared_memory_bytes_; + size_t max_shared_memory_bytes_ { 0 }; }; XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") From 3550f0ea8816a7d62182dec829cc3d5d2dff17d9 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 11 Nov 2020 12:43:51 +0800 Subject: [PATCH 5/5] Fix. 
--- src/predictor/gpu_predictor.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 9b18b3763cf6..8bbd814bce93 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -662,7 +662,7 @@ class GPUPredictor : public xgboost::Predictor { auto GRID_SIZE = static_cast(common::DivRoundUp(info.num_row_, BLOCK_THREADS)); size_t shared_memory_bytes = - SharedMemoryBytes(info.num_col_, max_shared_memory_bytes_); + SharedMemoryBytes(info.num_col_, max_shared_memory_bytes); bool use_shared = shared_memory_bytes != 0; size_t entry_start = 0;