Skip to content

Commit

Permalink
Refactoring monotone constraints (linked to #2305) (#2717)
Browse files Browse the repository at this point in the history
* Move monotone constraints to the monotone_constraints files.

* Add checks for debug mode.

* Refactored FindBestSplitsFromHistograms.

* Add headers.

* fix

* Update data_parallel_tree_learner.cpp

* simplify ComputeBestSplitForFeature

* Fix min / max issue.

* Remove duplicated check.

Co-authored-by: Guolin Ke <[email protected]>
  • Loading branch information
CharlesAuguste and guolinke authored Feb 10, 2020
1 parent d8a34df commit 3670e47
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 248 deletions.
6 changes: 6 additions & 0 deletions src/boosting/gbdt_model_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
}
}
Expand All @@ -570,6 +573,9 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
#ifdef DEBUG
CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
}
}
Expand Down
60 changes: 24 additions & 36 deletions src/treelearner/data_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,67 +187,55 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->train_data_->FixHistogram(feature_index,
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_histogram_array_[feature_index].RawData());
SplitInfo smaller_split;
// find best threshold for smaller child
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
this->smaller_leaf_splits_->min_constraint(),
this->smaller_leaf_splits_->max_constraint(),
&smaller_split);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid] && smaller_node_used_features[feature_index]) {
smaller_bests_per_thread[tid] = smaller_split;
}

this->ComputeBestSplitForFeature(
this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
smaller_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->leaf_index()),
this->smaller_leaf_splits_.get(),
&smaller_bests_per_thread[tid]);

// only root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) continue;
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) continue;

// construct histograms for the larger leaf; we init the larger leaf as the parent, so we can just subtract the smaller leaf's histograms
this->larger_leaf_histogram_array_[feature_index].Subtract(
this->smaller_leaf_histogram_array_[feature_index]);
SplitInfo larger_split;
// find best threshold for larger child
this->larger_leaf_histogram_array_[feature_index].FindBestThreshold(
this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()),
this->larger_leaf_splits_->min_constraint(),
this->larger_leaf_splits_->max_constraint(),
&larger_split);
larger_split.feature = real_feature_index;
if (larger_split > larger_bests_per_thread[tid] && larger_node_used_features[feature_index]) {
larger_bests_per_thread[tid] = larger_split;
}

this->ComputeBestSplitForFeature(
this->larger_leaf_histogram_array_, feature_index, real_feature_index,
larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->leaf_index()),
this->larger_leaf_splits_.get(),
&larger_bests_per_thread[tid]);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();

auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex();
int leaf = this->smaller_leaf_splits_->leaf_index();
this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];

if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex();
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
leaf = this->larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
}

SplitInfo smaller_best_split, larger_best_split;
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
// find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
if (this->larger_leaf_splits_->leaf_index() >= 0) {
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
}

// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);

// set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split;
this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()] = smaller_best_split;
if (this->larger_leaf_splits_->leaf_index() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()] = larger_best_split;
}
}

Expand Down
Loading

0 comments on commit 3670e47

Please sign in to comment.