Skip to content

Commit

Permalink
change handling of the cuts for the missing bin to the beginning of the flattening process
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Dec 24, 2024
1 parent ebf134d commit 9b723a0
Showing 1 changed file with 50 additions and 53 deletions.
103 changes: 50 additions & 53 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER
// do not inline this. Not inlining it makes fewer versions that can be called from the more templated functions
template<bool bHessian>
static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
bool bExtraMissingCut,
const bool bMissing,
const bool bNominal,
const TermBoostFlags flags,
const FloatCalc regAlpha,
Expand All @@ -132,7 +132,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
EBM_ASSERT(2 <= cBins);
EBM_ASSERT(cSlices <= cBins);
EBM_ASSERT(!bNominal || cSlices == cBins);
EBM_ASSERT(!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere

ErrorEbm error;

Expand Down Expand Up @@ -178,7 +177,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
} else {
pUpdateScore = aUpdateScore;

if(nullptr != pMissingValueTreeNode || bExtraMissingCut) {
if(bMissing) {
// always put a split on the missing bin
*pSplit = 1;
++pSplit;
Expand Down Expand Up @@ -239,18 +238,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
}
}

if(bExtraMissingCut) {
EBM_ASSERT(!bNominal); // for Nominal we cut everywhere
if(TermBoostFlags_MissingLow & flags) {
if(nullptr == pMissingBin) {
pMissingBin = pTreeNode->GetBin();
}
} else {
EBM_ASSERT(TermBoostFlags_MissingHigh & flags);
pMissingBin = pTreeNode->GetBin();
}
}

EBM_ASSERT(apBins <= ppBinLast);
EBM_ASSERT(ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{1} : size_t{0})));

Expand All @@ -266,48 +253,57 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
}
EBM_ASSERT(!bNominal);

// for Nominal we cut everywhere
if(TermBoostFlags_MissingLow & flags) {
if(nullptr == pMissingBin) {
pMissingBin = pTreeNode->GetBin();
}
}

// if !bNominal, check the bin above and below for order
EBM_ASSERT(apBins == ppBinLast || *(ppBinLast - 1) < *ppBinLast);
EBM_ASSERT(ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{2} : size_t{1})) ||
*ppBinLast < *(ppBinLast + 1));

iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);

while(true) {
iScore = 0;
do {
FloatCalc updateScore;
if(bUpdateWithHessian) {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(aGradientPair[iScore].GetHess()),
regAlpha,
regLambda,
deltaStepMax);
} else {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight()),
regAlpha,
regLambda,
deltaStepMax);
}
if(!bMissing || !(TermBoostFlags_MissingLow & flags) || ((TermBoostFlags_MissingLow & flags) && 1 != iEdge)) {
while(true) {
iScore = 0;
do {
FloatCalc updateScore;
if(bUpdateWithHessian) {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(aGradientPair[iScore].GetHess()),
regAlpha,
regLambda,
deltaStepMax);
} else {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight()),
regAlpha,
regLambda,
deltaStepMax);
}

*pUpdateScore = static_cast<FloatScore>(updateScore);
++pUpdateScore;
*pUpdateScore = static_cast<FloatScore>(updateScore);
++pUpdateScore;

++iScore;
} while(cScores != iScore);
if(nullptr == ppBinCur) {
break;
}
EBM_ASSERT(bNominal);
++ppBinCur;
if(ppBinLast < ppBinCur) {
break;
++iScore;
} while(cScores != iScore);
if(nullptr == ppBinCur) {
break;
}
EBM_ASSERT(bNominal);
++ppBinCur;
if(ppBinLast < ppBinCur) {
break;
}
determine_bin:;
const auto* const pBinCur = *ppBinCur;
const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin);
pUpdateScore = aUpdateScore + iBin * cScores;
}
determine_bin:;
const auto* const pBinCur = *ppBinCur;
const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin);
pUpdateScore = aUpdateScore + iBin * cScores;
}

pTreeNode = pParent;
Expand Down Expand Up @@ -345,9 +341,12 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
if(!pTreeNode->DECONSTRUCT_IsRightChildTraversal()) {
// we checked earlier that countBins could be converted to a UIntSplit
if(nullptr == ppBinCur) {
EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
*pSplit = static_cast<UIntSplit>(iEdge);
++pSplit;
if(!bMissing || !(TermBoostFlags_MissingLow & flags) ||
((TermBoostFlags_MissingLow & flags) && 1 != iEdge)) {
EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
*pSplit = static_cast<UIntSplit>(iEdge);
++pSplit;
}
}
pParent = pTreeNode;
pTreeNode = pTreeNode->DECONSTRUCT_TraverseRightAndMark(cBytesPerTreeNode);
Expand Down Expand Up @@ -1035,21 +1034,19 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
*pTotalGain = static_cast<double>(totalGain);

size_t cSlices = cSplitsMax - cSplitsRemaining + 1;
bool bExtraMissingCut = false;
if(nullptr != pMissingValueTreeNode) {
EBM_ASSERT(nullptr == pMissingBin);
++cSlices;
} else {
if(nullptr != pMissingBin && !bMissingIsolated) {
bExtraMissingCut = true;
++cSlices;
}
}
if(bNominal) {
cSlices = cBins;
}
const ErrorEbm error = Flatten<bHessian>(pBoosterShell,
bExtraMissingCut,
bMissing,
bNominal,
flags,
regAlpha,
Expand Down

0 comments on commit 9b723a0

Please sign in to comment.