Implement GPU predict leaf. #6187

Merged
merged 5 commits on Nov 11, 2020
include/xgboost/gbm.h: 2 changes (1 addition, 1 deletion)
@@ -144,7 +144,7 @@ class GradientBooster : public Model, public Configurable {
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void PredictLeaf(DMatrix* dmat,
std::vector<bst_float>* out_preds,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;

/*!
include/xgboost/predictor.h: 6 changes (1 addition, 5 deletions)
@@ -164,10 +164,6 @@ class Predictor {
unsigned ntree_limit = 0) = 0;

/**
* \fn virtual void Predictor::PredictLeaf(DMatrix* dmat,
* std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned
* ntree_limit = 0) = 0;
*
* \brief Predict the leaf index of each tree. The output is an nsample *
* ntree vector; this is only valid for the gbtree predictor.
*
@@ -177,7 +173,7 @@
* \param ntree_limit (Optional) The ntree limit.
*/

virtual void PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds,
virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) = 0;

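The interface change above replaces std::vector<bst_float> with HostDeviceVector<bst_float>, so leaf indices can be written on the device and copied back to the host only when a caller asks for them. A minimal sketch of the access pattern the rest of the changeset relies on (illustrative only, not part of this diff; the sizes are made up):

#include "xgboost/host_device_vector.h"

// Sketch of the HostDeviceVector round trip used by the new PredictLeaf overloads.
void LeafBufferSketch() {
  xgboost::HostDeviceVector<float> preds;
  preds.SetDevice(0);                             // bind the buffer to GPU 0
  preds.Resize(8 * 4);                            // nsample * ntree entries, as PredictLeaf allocates
  auto d_preds = preds.DeviceSpan();              // writable device view, filled by a CUDA kernel
  auto const& h_preds = preds.ConstHostVector();  // triggers the device-to-host copy for readers
  (void)d_preds;
  (void)h_preds;
}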
src/gbm/gblinear.cc: 4 changes (1 addition, 3 deletions)
@@ -147,9 +147,7 @@ class GBLinear : public GradientBooster {
}
}

void PredictLeaf(DMatrix*,
std::vector<bst_float>*,
unsigned) override {
void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned) override {
LOG(FATAL) << "gblinear does not support prediction of leaf index";
}

src/gbm/gbtree.h: 5 changes (2 additions, 3 deletions)
@@ -278,10 +278,9 @@ class GBTree : public GradientBooster {
}

void PredictLeaf(DMatrix* p_fmat,
std::vector<bst_float>* out_preds,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) override {
CHECK(configured_);
cpu_predictor_->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
}

void PredictContribution(DMatrix* p_fmat,
src/learner.cc: 2 changes (1 addition, 1 deletion)
@@ -1106,7 +1106,7 @@ class LearnerImpl : public LearnerIO {
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
gbm_->PredictLeaf(data.get(), out_preds, ntree_limit);
} else {
auto local_cache = this->GetPredictionCache();
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
src/predictor/cpu_predictor.cc: 4 changes (2 additions, 2 deletions)
@@ -345,7 +345,7 @@ class CPUPredictor : public Predictor {
}
}

void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
const int nthread = omp_get_max_threads();
InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
@@ -355,7 +355,7 @@
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size());
}
std::vector<bst_float>& preds = *out_preds;
std::vector<bst_float>& preds = out_preds->HostVector();
preds.resize(info.num_row_ * ntree_limit);
// start collecting the prediction
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
src/predictor/gpu_predictor.cu: 147 changes (125 additions, 22 deletions)
@@ -78,7 +78,7 @@ struct SparsePageLoader {
__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
bst_row_t num_rows, size_t entry_start)
: use_shared(use_shared),
data(data),
data(data),
entry_start(entry_start) {
extern __shared__ float _smem[];
smem = _smem;
@@ -169,7 +169,7 @@ struct DeviceAdapterLoader {
};

template <typename Loader>
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
common::Span<FeatureType const> split_types,
common::Span<RegTree::Segment const> d_cat_ptrs,
common::Span<uint32_t const> d_categories,
Expand Down Expand Up @@ -201,6 +201,49 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
return tree[nidx].LeafValue();
}

template <typename Loader>
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
Loader const& loader) {
bst_node_t nidx = 0;
RegTree::Node n = tree[nidx];
while (!n.IsLeaf()) {
float fvalue = loader.GetElement(ridx, n.SplitIndex());
// Missing value
if (isnan(fvalue)) {
nidx = n.DefaultChild();
n = tree[nidx];
} else {
if (fvalue < n.SplitCond()) {
nidx = n.LeftChild();
n = tree[nidx];
} else {
nidx = n.RightChild();
n = tree[nidx];
}
}
}
return nidx;
}

template <typename Loader, typename Data>
__global__ void PredictLeafKernel(Data data,
common::Span<const RegTree::Node> d_nodes,
common::Span<float> d_out_predictions,
common::Span<size_t const> d_tree_segments,
size_t tree_begin, size_t tree_end, size_t num_features,
size_t num_rows, size_t entry_start, bool use_shared) {
bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
if (ridx >= num_rows) {
return;
}
Loader loader(data, use_shared, num_features, num_rows, entry_start);
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
auto leaf = GetLeafIndex(ridx, d_tree, loader);
d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
}
}

template <typename Loader, typename Data>
__global__ void
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
@@ -437,6 +480,19 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
});
}

namespace {
template <size_t kBlockThreads>
size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
// max_shared_memory_bytes should never be 0 here.
CHECK_GT(max_shared_memory_bytes, 0);
size_t shared_memory_bytes =
static_cast<size_t>(sizeof(float) * cols * kBlockThreads);
if (shared_memory_bytes > max_shared_memory_bytes) {
shared_memory_bytes = 0;
}
return shared_memory_bytes;
}
} // anonymous namespace

class GPUPredictor : public xgboost::Predictor {
private:
@@ -450,13 +506,10 @@
size_t num_rows = batch.Size();
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));

auto shared_memory_bytes =
static_cast<size_t>(sizeof(float) * num_features * BLOCK_THREADS);
bool use_shared = true;
if (shared_memory_bytes > max_shared_memory_bytes_) {
shared_memory_bytes = 0;
use_shared = false;
}
size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes_);
bool use_shared = shared_memory_bytes != 0;

size_t entry_start = 0;
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
@@ -608,13 +661,9 @@
const uint32_t BLOCK_THREADS = 128;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));

auto shared_memory_bytes =
static_cast<size_t>(sizeof(float) * m->NumColumns() * BLOCK_THREADS);
bool use_shared = true;
if (shared_memory_bytes > max_shared_memory_bytes) {
shared_memory_bytes = 0;
use_shared = false;
}
size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(info.num_col_, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0;
size_t entry_start = 0;

dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
@@ -780,11 +829,65 @@
<< " is not implemented in GPU Predictor.";
}

void PredictLeaf(DMatrix*, std::vector<bst_float>*,
const gbm::GBTreeModel&,
unsigned) override {
LOG(FATAL) << "[Internal error]: " << __func__
<< " is not implemented in GPU Predictor.";
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
const gbm::GBTreeModel& model,
unsigned ntree_limit) override {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
ConfigureDevice(generic_param_->gpu_id);

const MetaInfo& info = p_fmat->Info();
constexpr uint32_t kBlockThreads = 128;
size_t shared_memory_bytes =
SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes_);
bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_;
bst_row_t num_rows = info.num_row_;
size_t entry_start = 0;

uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
}
predictions->SetDevice(generic_param_->gpu_id);
predictions->Resize(num_rows * real_ntree_limit);
model_.Init(model, 0, real_ntree_limit, generic_param_->gpu_id);

if (p_fmat->PageExists<SparsePage>()) {
// Accumulated across batches so external-memory data writes to the right offset.
bst_row_t batch_offset = 0;
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature};
size_t num_rows = batch.Size();
auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
model_.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model_.tree_segments.ConstDeviceSpan(),
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
entry_start, use_shared);
// Each row writes real_ntree_limit leaf indices.
batch_offset += batch.Size() * real_ntree_limit;
}
} else {
// Accumulated across batches so external-memory data writes to the right offset.
bst_row_t batch_offset = 0;
for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)};
size_t num_rows = batch.Size();
auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
model_.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model_.tree_segments.ConstDeviceSpan(),
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
entry_start, use_shared);
// Each row writes real_ntree_limit leaf indices.
batch_offset += batch.Size() * real_ntree_limit;
}
}
}

void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
@@ -801,7 +904,7 @@

std::mutex lock_;
DeviceModel model_;
size_t max_shared_memory_bytes_;
size_t max_shared_memory_bytes_ { 0 };
};

XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
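For reference, a host-side sketch of what PredictLeafKernel above computes: one thread per row walks every tree from the root, takes the default child on missing (NaN) values, and stores the index of the reached leaf at the row-major position ridx * num_trees + tree_idx. The Node struct below is a simplified stand-in, not xgboost's RegTree::Node, and the function is illustrative rather than part of this changeset.

#include <cmath>
#include <cstddef>
#include <vector>

// Simplified stand-in for a tree node; not xgboost's RegTree::Node.
struct Node {
  bool is_leaf;
  int left, right, default_child;  // child node indices
  int split_index;                 // feature used by the split
  float split_cond;                // split threshold
};

// Host-side reference of the per-row traversal done by PredictLeafKernel.
// rows[i][j] is the j-th feature of row i; NaN marks a missing value.
std::vector<float> PredictLeafReference(std::vector<std::vector<Node>> const& trees,
                                        std::vector<std::vector<float>> const& rows) {
  std::vector<float> out(rows.size() * trees.size());
  for (std::size_t ridx = 0; ridx < rows.size(); ++ridx) {
    for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
      auto const& tree = trees[tree_idx];
      int nidx = 0;  // start at the root
      while (!tree[nidx].is_leaf) {
        float fvalue = rows[ridx][tree[nidx].split_index];
        if (std::isnan(fvalue)) {
          nidx = tree[nidx].default_child;
        } else {
          nidx = fvalue < tree[nidx].split_cond ? tree[nidx].left : tree[nidx].right;
        }
      }
      // Same layout as the kernel: one leaf index per (row, tree), rows contiguous.
      out[ridx * trees.size() + tree_idx] = static_cast<float>(nidx);
    }
  }
  return out;
}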
tests/cpp/helpers.cc: 7 changes (7 additions, 0 deletions)
@@ -348,6 +348,13 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
gen.GenerateDense(&out->Info().labels_);
}
}
if (device_ >= 0) {
out->Info().labels_.SetDevice(device_);
for (auto const& page : out->GetBatches<SparsePage>()) {
page.data.SetDevice(device_);
page.offset.SetDevice(device_);
}
}
return out;
}

tests/cpp/predictor/test_cpu_predictor.cc: 12 changes (7 additions, 5 deletions)
@@ -46,9 +46,10 @@ TEST(CpuPredictor, Basic) {
}

// Test predict leaf
std::vector<float> leaf_out_predictions;
HostDeviceVector<float> leaf_out_predictions;
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
for (auto v : leaf_out_predictions) {
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
for (auto v : h_leaf_out_predictions) {
ASSERT_EQ(v, 0);
}

@@ -108,10 +109,11 @@ TEST(CpuPredictor, ExternalMemory) {
}

// Test predict leaf
std::vector<float> leaf_out_predictions;
HostDeviceVector<float> leaf_out_predictions;
cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
ASSERT_EQ(leaf_out_predictions.size(), dmat->Info().num_row_);
for (const auto& v : leaf_out_predictions) {
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
ASSERT_EQ(h_leaf_out_predictions.size(), dmat->Info().num_row_);
for (const auto& v : h_leaf_out_predictions) {
ASSERT_EQ(v, 0);
}

tests/cpp/predictor/test_gpu_predictor.cu: 24 changes (24 additions, 0 deletions)
@@ -190,6 +190,7 @@ TEST(GPUPredictor, ShapStump) {
EXPECT_EQ(phis[4], 0.0);
EXPECT_EQ(phis[5], param.base_score);
}

TEST(GPUPredictor, Shap) {
LearnerModelParam param;
param.num_feature = 1;
@@ -224,5 +225,28 @@
TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}

TEST(GPUPredictor, PredictLeafBasic) {
size_t constexpr kRows = 5, kCols = 5;
auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix();
auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
gpu_predictor->Configure({});

LearnerModelParam param;
param.num_feature = kCols;
param.base_score = 0.0;
param.num_output_group = 1;

gbm::GBTreeModel model = CreateTestModel(&param);

HostDeviceVector<float> leaf_out_predictions;
gpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model);
auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector();
for (auto v : h_leaf_out_predictions) {
ASSERT_EQ(v, 0);
}
}
} // namespace predictor
} // namespace xgboost
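A natural follow-up check, sketched below using the same test helpers (RandomDataGenerator, CreateTestModel, CreateEmptyGenericParam), would assert that the GPU predictor returns the same leaf indices as the CPU predictor on identical data. This test is not part of the changeset; it only illustrates how the two code paths could be compared.

// Would live alongside the tests above, inside namespace xgboost::predictor.
TEST(GPUPredictor, PredictLeafMatchesCpuSketch) {
  size_t constexpr kRows = 64, kCols = 8;
  auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix();

  LearnerModelParam param;
  param.num_feature = kCols;
  param.base_score = 0.0;
  param.num_output_group = 1;
  gbm::GBTreeModel model = CreateTestModel(&param);

  auto cpu_lparam = CreateEmptyGenericParam(-1);
  auto gpu_lparam = CreateEmptyGenericParam(0);
  std::unique_ptr<Predictor> cpu_predictor(Predictor::Create("cpu_predictor", &cpu_lparam));
  std::unique_ptr<Predictor> gpu_predictor(Predictor::Create("gpu_predictor", &gpu_lparam));
  cpu_predictor->Configure({});
  gpu_predictor->Configure({});

  HostDeviceVector<float> cpu_leaves, gpu_leaves;
  cpu_predictor->PredictLeaf(dmat.get(), &cpu_leaves, model);
  gpu_predictor->PredictLeaf(dmat.get(), &gpu_leaves, model);

  auto const& h_cpu = cpu_leaves.ConstHostVector();
  auto const& h_gpu = gpu_leaves.ConstHostVector();
  ASSERT_EQ(h_cpu.size(), h_gpu.size());
  for (size_t i = 0; i < h_cpu.size(); ++i) {
    ASSERT_EQ(h_cpu[i], h_gpu[i]);
  }
}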