Skip to content

Commit

Permalink
separate flags for nominal and continuous missing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Dec 26, 2024
1 parent 50ab4ba commit a12b61a
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 40 deletions.
17 changes: 12 additions & 5 deletions python/interpret-core/interpret/glassbox/_ebm/_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,20 @@ def boost(
)

if missing == "low":
term_boost_flags_local |= Native.TermBoostFlags_MissingLow
term_boost_flags_local |= (
Native.TermBoostFlags_MissingLow
| Native.TermBoostFlags_MissingCategory
)
elif missing == "high":
term_boost_flags_local |= Native.TermBoostFlags_MissingHigh
term_boost_flags_local |= (
Native.TermBoostFlags_MissingHigh
| Native.TermBoostFlags_MissingCategory
)
elif missing == "separate":
term_boost_flags_local |= Native.TermBoostFlags_MissingSeparate
elif missing == "drop":
term_boost_flags_local |= Native.TermBoostFlags_MissingDrop
term_boost_flags_local |= (
Native.TermBoostFlags_MissingSeparate
| Native.TermBoostFlags_MissingCategory
)
elif missing != "gain":
msg = f"Unrecognized missing option {missing}."
raise Exception(msg)
Expand Down
4 changes: 0 additions & 4 deletions python/interpret-core/interpret/glassbox/_ebm/_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2784,8 +2784,6 @@ class ExplainableBoostingClassifier(ClassifierMixin, EBMModel):
- `'separate'`: Place the missing bin in its own leaf during each boosting step,
effectively making it location-agnostic. This can lead to overfitting, especially
when the proportion of missing values is small.
- `'drop'`: Ignore the contribution of the missing bin, or split the feature into two leaves based on gain:
one for missing values and one for non-missing values.
- `'gain'`: Choose the best leaf for the missing value contribution at each boosting step, based on gain.
max_leaves : int, default=3
Maximum number of leaves allowed in each tree.
Expand Down Expand Up @@ -3158,8 +3156,6 @@ class ExplainableBoostingRegressor(RegressorMixin, EBMModel):
- `'separate'`: Place the missing bin in its own leaf during each boosting step,
effectively making it location-agnostic. This can lead to overfitting, especially
when the proportion of missing values is small.
- `'drop'`: Ignore the contribution of the missing bin, or split the feature into two leaves based on gain:
one for missing values and one for non-missing values.
- `'gain'`: Choose the best leaf for the missing value contribution at each boosting step, based on gain.
max_leaves : int, default=2
Maximum number of leaves allowed in each tree.
Expand Down
2 changes: 1 addition & 1 deletion python/interpret-core/interpret/utils/_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Native:
TermBoostFlags_MissingLow = 0x00000080
TermBoostFlags_MissingHigh = 0x00000100
TermBoostFlags_MissingSeparate = 0x00000200
TermBoostFlags_MissingDrop = 0x00000400
TermBoostFlags_MissingCategory = 0x00000400

# CreateInteractionFlags
CreateInteractionFlags_Default = 0x00000000
Expand Down
11 changes: 3 additions & 8 deletions shared/libebm/GenerateTermUpdate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,22 +749,17 @@ EBM_API_BODY ErrorEbm EBM_CALLING_CONVENTION GenerateTermUpdate(void* rng,
~(TermBoostFlags_PurifyGain | TermBoostFlags_DisableNewtonGain | TermBoostFlags_DisableCategorical |
TermBoostFlags_PurifyUpdate | TermBoostFlags_DisableNewtonUpdate | TermBoostFlags_GradientSums |
TermBoostFlags_RandomSplits | TermBoostFlags_MissingLow | TermBoostFlags_MissingHigh |
TermBoostFlags_MissingSeparate | TermBoostFlags_MissingDrop)) {
TermBoostFlags_MissingSeparate | TermBoostFlags_MissingCategory)) {
LOG_0(Trace_Error, "ERROR GenerateTermUpdate flags contains unknown flags. Ignoring extras.");
}

if(TermBoostFlags_MissingLow & flags) {
if((TermBoostFlags_MissingHigh | TermBoostFlags_MissingSeparate | TermBoostFlags_MissingDrop) & flags) {
if((TermBoostFlags_MissingHigh | TermBoostFlags_MissingSeparate) & flags) {
LOG_0(Trace_Error, "ERROR GenerateTermUpdate flags contains multiple Missing value flags.");
return Error_IllegalParamVal;
}
} else if(TermBoostFlags_MissingHigh & flags) {
if((TermBoostFlags_MissingSeparate | TermBoostFlags_MissingDrop) & flags) {
LOG_0(Trace_Error, "ERROR GenerateTermUpdate flags contains multiple Missing value flags.");
return Error_IllegalParamVal;
}
} else if(TermBoostFlags_MissingSeparate & flags) {
if(TermBoostFlags_MissingDrop & flags) {
if(TermBoostFlags_MissingSeparate & flags) {
LOG_0(Trace_Error, "ERROR GenerateTermUpdate flags contains multiple Missing value flags.");
return Error_IllegalParamVal;
}
Expand Down
41 changes: 27 additions & 14 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,22 +895,35 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
bool bMissingIsolated = false;

const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
if(TermBoostFlags_MissingLow & flags) {
if(bMissing && !bNominal) {
pMissingBin = pBin;
}
} else if(TermBoostFlags_MissingHigh & flags) {
if(bMissing && !bNominal) {
pMissingBin = pBin;
// the concept of TermBoostFlags_MissingHigh does not exist for nominals
pBin = IndexBin(pBin, cBytesPerBin);

if(bNominal) {
if(TermBoostFlags_MissingCategory & flags) {
// nothing to do
} else {
if(bMissing) {
pMissingValueTreeNode = pRootTreeNode;
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
// region.
pBin = IndexBin(pBin, cBytesPerBin);
}
}
} else {
if(bMissing) {
pMissingValueTreeNode = pRootTreeNode;
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
// region.
pBin = IndexBin(pBin, cBytesPerBin);
if(TermBoostFlags_MissingLow & flags) {
if(bMissing) {
pMissingBin = pBin;
}
} else if(TermBoostFlags_MissingHigh & flags) {
if(bMissing) {
pMissingBin = pBin;
pBin = IndexBin(pBin, cBytesPerBin);
}
} else {
if(bMissing) {
pMissingValueTreeNode = pRootTreeNode;
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
// region.
pBin = IndexBin(pBin, cBytesPerBin);
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion shared/libebm/inc/libebm.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ typedef struct _InteractionHandle {
#define TermBoostFlags_MissingLow (TERM_BOOST_FLAGS_CAST(0x00000080))
#define TermBoostFlags_MissingHigh (TERM_BOOST_FLAGS_CAST(0x00000100))
#define TermBoostFlags_MissingSeparate (TERM_BOOST_FLAGS_CAST(0x00000200))
#define TermBoostFlags_MissingDrop (TERM_BOOST_FLAGS_CAST(0x00000400))
#define TermBoostFlags_MissingCategory (TERM_BOOST_FLAGS_CAST(0x00000400))

#define CreateInteractionFlags_Default (CREATE_INTERACTION_FLAGS_CAST(0x00000000))
#define CreateInteractionFlags_DifferentialPrivacy (CREATE_INTERACTION_FLAGS_CAST(0x00000001))
Expand Down
12 changes: 5 additions & 7 deletions shared/libebm/tests/boosting_unusual_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2096,12 +2096,10 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
TermBoostFlags_PurifyUpdate,
// TermBoostFlags_GradientSums, // does not return a metric
TermBoostFlags_DisableNewtonUpdate,
TermBoostFlags_RandomSplits};
std::vector<IntEbm> boostFlagsChoose{TermBoostFlags_Default,
TermBoostFlags_MissingLow,
TermBoostFlags_MissingHigh,
TermBoostFlags_MissingSeparate,
TermBoostFlags_MissingDrop};
TermBoostFlags_RandomSplits,
TermBoostFlags_MissingCategory};
std::vector<IntEbm> boostFlagsChoose{
TermBoostFlags_Default, TermBoostFlags_MissingLow, TermBoostFlags_MissingHigh, TermBoostFlags_MissingSeparate};

double validationMetric = 1.0;
for(IntEbm classesCount = Task_Regression; classesCount < 5; ++classesCount) {
Expand Down Expand Up @@ -2175,7 +2173,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
}

TEST_CASE("stress test, boosting") {
const double expected = 26746562197367.172;
const double expected = 14939439873840.908;

double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE);
CHECK(validationMetricExact == expected);
Expand Down

0 comments on commit a12b61a

Please sign in to comment.