Skip to content

Commit

Permalink
fix subset bug (#2748)
Browse files Browse the repository at this point in the history
* fix subset bug

* typo

* add fixme tag

* bin mapper

* fix test

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
guolinke and StrikerRUS authored Feb 10, 2020
1 parent 1c1a276 commit d8a34df
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 29 deletions.
72 changes: 53 additions & 19 deletions include/LightGBM/feature_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,21 @@ class FeatureGroup {
num_total_bin_ += num_bin;
bin_offsets_.emplace_back(num_total_bin_);
}
if (is_multi_val_) {
multi_bin_data_.clear();
for (int i = 0; i < num_feature_; ++i) {
int addi = bin_mappers_[i]->GetMostFreqBin() == 0 ? 0 : 1;
if (bin_mappers_[i]->sparse_rate() >= kSparseThreshold) {
multi_bin_data_.emplace_back(Bin::CreateSparseBin(num_data, bin_mappers_[i]->num_bin() + addi));
} else {
multi_bin_data_.emplace_back(Bin::CreateDenseBin(num_data, bin_mappers_[i]->num_bin() + addi));
}
}
} else {
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
CreateBinData(num_data, is_multi_val_, true, false);
}

FeatureGroup(const FeatureGroup& other, int num_data) {
num_feature_ = other.num_feature_;
is_multi_val_ = other.is_multi_val_;
is_sparse_ = other.is_sparse_;
num_total_bin_ = other.num_total_bin_;
bin_offsets_ = other.bin_offsets_;

bin_mappers_.reserve(other.bin_mappers_.size());
for (auto& bin_mapper : other.bin_mappers_) {
bin_mappers_.emplace_back(new BinMapper(*bin_mapper));
}
CreateBinData(num_data, is_multi_val_, !is_sparse_, is_sparse_);
}

FeatureGroup(std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
Expand All @@ -76,13 +78,7 @@ class FeatureGroup {
num_total_bin_ += num_bin;
bin_offsets_.emplace_back(num_total_bin_);
}
if (bin_mappers_[0]->sparse_rate() >= kSparseThreshold) {
is_sparse_ = true;
bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
} else {
is_sparse_ = false;
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
}
CreateBinData(num_data, false, false, false);
}

/*!
Expand Down Expand Up @@ -167,6 +163,16 @@ class FeatureGroup {
}
}

/*!
* \brief Resize the underlying bin storage to hold `num_data` rows.
* \param num_data New number of data rows
*/
void ReSize(int num_data) {
  if (is_multi_val_) {
    // Multi-value layout keeps one bin container per sub-feature.
    for (int fid = 0; fid < num_feature_; ++fid) {
      multi_bin_data_[fid]->ReSize(num_data);
    }
  } else {
    // Single shared bin container for the whole group.
    bin_data_->ReSize(num_data);
  }
}

inline void CopySubset(const FeatureGroup* full_feature, const data_size_t* used_indices, data_size_t num_used_indices) {
if (!is_multi_val_) {
bin_data_->CopySubset(full_feature->bin_data_.get(), used_indices, num_used_indices);
Expand Down Expand Up @@ -327,6 +333,34 @@ class FeatureGroup {
}

private:

/*!
* \brief Allocate the bin storage for this group and record its layout.
*        Updates is_multi_val_ (and, for single-value groups, is_sparse_)
*        to reflect the storage actually created.
* \param num_data Number of data rows the storage must hold
* \param is_multi_val If true, create one bin container per sub-feature
* \param force_dense If true, never auto-select the sparse representation
* \param force_sparse If true, always use the sparse representation
*/
void CreateBinData(int num_data, bool is_multi_val, bool force_dense, bool force_sparse) {
  if (is_multi_val) {
    is_multi_val_ = true;
    multi_bin_data_.clear();
    for (int fid = 0; fid < num_feature_; ++fid) {
      // A non-zero most-frequent bin needs one extra bin slot.
      const int extra_bin =
          bin_mappers_[fid]->GetMostFreqBin() == 0 ? 0 : 1;
      const int n_bin = bin_mappers_[fid]->num_bin() + extra_bin;
      if (bin_mappers_[fid]->sparse_rate() >= kSparseThreshold) {
        multi_bin_data_.emplace_back(Bin::CreateSparseBin(num_data, n_bin));
      } else {
        multi_bin_data_.emplace_back(Bin::CreateDenseBin(num_data, n_bin));
      }
    }
    return;
  }
  is_multi_val_ = false;
  // Auto-pick sparse only for a single-feature group whose sparse rate
  // passes the threshold; the force_* flags override this choice.
  const bool auto_sparse =
      num_feature_ == 1 &&
      bin_mappers_[0]->sparse_rate() >= kSparseThreshold;
  is_sparse_ = force_sparse || (!force_dense && auto_sparse);
  if (is_sparse_) {
    bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
  } else {
    bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
  }
}

/*! \brief Number of features */
int num_feature_;
/*! \brief Bin mapper for sub features */
Expand Down
14 changes: 4 additions & 10 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,15 +720,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
num_groups_ = dataset->num_groups_;
// copy feature bin mapper data
for (int i = 0; i < num_groups_; ++i) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers;
for (int j = 0; j < dataset->feature_groups_[i]->num_feature_; ++j) {
bin_mappers.emplace_back(new BinMapper(*(dataset->feature_groups_[i]->bin_mappers_[j])));
}
feature_groups_.emplace_back(new FeatureGroup(
dataset->feature_groups_[i]->num_feature_,
dataset->feature_groups_[i]->is_multi_val_,
&bin_mappers,
num_data_));
feature_groups_.emplace_back(new FeatureGroup(*dataset->feature_groups_[i], num_data_));
}
feature_groups_.shrink_to_fit();
used_feature_map_ = dataset->used_feature_map_;
Expand Down Expand Up @@ -806,7 +798,7 @@ void Dataset::ReSize(data_size_t num_data) {
#pragma omp parallel for schedule(static)
for (int group = 0; group < num_groups_; ++group) {
OMP_LOOP_EX_BEGIN();
feature_groups_[group]->bin_data_->ReSize(num_data_);
feature_groups_[group]->ReSize(num_data_);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -1399,6 +1391,8 @@ void Dataset::AddFeaturesFrom(Dataset* other) {
PushVector(&group_feature_cnt_, other->group_feature_cnt_);
PushVector(&forced_bin_bounds_, other->forced_bin_bounds_);
feature_groups_.reserve(other->feature_groups_.size());
// FIXME: fix the multiple multi-val feature groups, they need to be merged
// into one multi-val group
for (auto& fg : other->feature_groups_) {
feature_groups_.emplace_back(new FeatureGroup(*fg));
}
Expand Down

0 comments on commit d8a34df

Please sign in to comment.