Skip to content

Commit

Permalink
change handling of the cuts for the missing bin to the beginning of t…
Browse files Browse the repository at this point in the history
…he flattening process
  • Loading branch information
paulbkoch committed Dec 24, 2024
1 parent 7cea7a3 commit 42875bc
Showing 1 changed file with 74 additions and 65 deletions.
139 changes: 74 additions & 65 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER
// do not inline this. Not inlining it makes fewer versions that can be called from the more templated functions
template<bool bHessian>
static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
bool bExtraMissingCut,
const bool bMissing,
const bool bNominal,
const TermBoostFlags flags,
const FloatCalc regAlpha,
Expand All @@ -132,7 +132,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
EBM_ASSERT(2 <= cBins);
EBM_ASSERT(cSlices <= cBins);
EBM_ASSERT(!bNominal || cSlices == cBins);
EBM_ASSERT(!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere

ErrorEbm error;

Expand Down Expand Up @@ -178,7 +177,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
} else {
pUpdateScore = aUpdateScore;

if(nullptr != pMissingValueTreeNode || bExtraMissingCut) {
if(bMissing) {
// always put a split on the missing bin
*pSplit = 1;
++pSplit;
Expand Down Expand Up @@ -239,18 +238,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
}
}

if(bExtraMissingCut) {
EBM_ASSERT(!bNominal); // for Nominal we cut everywhere
if(TermBoostFlags_MissingLow & flags) {
if(nullptr == pMissingBin) {
pMissingBin = pTreeNode->GetBin();
}
} else {
EBM_ASSERT(TermBoostFlags_MissingHigh & flags);
pMissingBin = pTreeNode->GetBin();
}
}

EBM_ASSERT(apBins <= ppBinLast);
EBM_ASSERT(ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{1} : size_t{0})));

Expand All @@ -273,41 +260,56 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,

iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);

while(true) {
iScore = 0;
do {
FloatCalc updateScore;
if(bUpdateWithHessian) {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(aGradientPair[iScore].GetHess()),
regAlpha,
regLambda,
deltaStepMax);
} else {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight()),
regAlpha,
regLambda,
deltaStepMax);
while(true) { // not a real loop
if(bMissing) {
if(TermBoostFlags_MissingLow & flags) {
if(nullptr == pMissingBin) {
pMissingBin = pTreeNode->GetBin();
}
if(1 == iEdge) {
break;
}
}
}

*pUpdateScore = static_cast<FloatScore>(updateScore);
++pUpdateScore;
while(true) {
iScore = 0;
do {
FloatCalc updateScore;
if(bUpdateWithHessian) {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(aGradientPair[iScore].GetHess()),
regAlpha,
regLambda,
deltaStepMax);
} else {
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight()),
regAlpha,
regLambda,
deltaStepMax);
}

++iScore;
} while(cScores != iScore);
if(nullptr == ppBinCur) {
break;
}
EBM_ASSERT(bNominal);
++ppBinCur;
if(ppBinLast < ppBinCur) {
break;
*pUpdateScore = static_cast<FloatScore>(updateScore);
++pUpdateScore;

++iScore;
} while(cScores != iScore);
if(nullptr == ppBinCur) {
break;
}
EBM_ASSERT(bNominal);
++ppBinCur;
if(ppBinLast < ppBinCur) {
break;
}
determine_bin:;
const auto* const pBinCur = *ppBinCur;
const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin);
pUpdateScore = aUpdateScore + iBin * cScores;
}
determine_bin:;
const auto* const pBinCur = *ppBinCur;
const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin);
pUpdateScore = aUpdateScore + iBin * cScores;

break;
}

pTreeNode = pParent;
Expand Down Expand Up @@ -345,9 +347,23 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
if(!pTreeNode->DECONSTRUCT_IsRightChildTraversal()) {
// we checked earlier that countBins could be converted to a UIntSplit
if(nullptr == ppBinCur) {
EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
*pSplit = static_cast<UIntSplit>(iEdge);
++pSplit;
EBM_ASSERT(!bNominal);

while(true) { // not a real loop
if(bMissing) {
if(TermBoostFlags_MissingLow & flags) {
if(1 == iEdge) {
break;
}
}
}

EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
*pSplit = static_cast<UIntSplit>(iEdge);
++pSplit;

break;
}
}
pParent = pTreeNode;
pTreeNode = pTreeNode->DECONSTRUCT_TraverseRightAndMark(cBytesPerTreeNode);
Expand Down Expand Up @@ -832,7 +848,10 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
auto* const aBins =
pBoosterShell->GetBoostingMainBins()
->Specialize<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>();
auto* const pBinsEnd = IndexBin(aBins, cBytesPerBin * cBins);
auto* pBinsEnd = IndexBin(aBins, cBytesPerBin * cBins);

SumAllBins<bHessian, cCompilerScores>(
pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin());

const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** const apBins =
reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>**>(
Expand All @@ -844,7 +863,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>* pMissingBin = nullptr;
bool bMissingIsolated = false;

size_t cBinsAdjusted = cBins;
const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
if(TermBoostFlags_MissingLow & flags) {
if(bMissing) {
Expand All @@ -861,32 +879,25 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
// region.
pBin = IndexBin(pBin, cBytesPerBin);
--cBinsAdjusted;
}
}

const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** ppBinsEnd =
apBins + cBinsAdjusted;

do {
*ppBin = pBin;
pBin = IndexBin(pBin, cBytesPerBin);
++ppBin;
} while(ppBinsEnd != ppBin);
} while(pBinsEnd != pBin);

if(bNominal) {
std::sort(apBins,
ppBinsEnd,
ppBin,
CompareBin<bHessian, cCompilerScores>(
!(TermBoostFlags_DisableNewtonUpdate & flags), categoricalSmoothing));
}

pRootTreeNode->BEFORE_SetBinFirst(apBins);
pRootTreeNode->BEFORE_SetBinLast(ppBinsEnd - 1);
ASSERT_BIN_OK(cBytesPerBin, *(ppBinsEnd - 1), pBoosterShell->GetDebugMainBinsEnd());

SumAllBins<bHessian, cCompilerScores>(
pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin());
pRootTreeNode->BEFORE_SetBinLast(ppBin - 1);
ASSERT_BIN_OK(cBytesPerBin, *(ppBin - 1), pBoosterShell->GetDebugMainBinsEnd());

EBM_ASSERT(!IsOverflowTreeNodeSize(bHessian, cScores));
const size_t cBytesPerTreeNode = GetTreeNodeSize(bHessian, cScores);
Expand Down Expand Up @@ -1040,21 +1051,19 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
*pTotalGain = static_cast<double>(totalGain);

size_t cSlices = cSplitsMax - cSplitsRemaining + 1;
bool bExtraMissingCut = false;
if(nullptr != pMissingValueTreeNode) {
EBM_ASSERT(nullptr == pMissingBin);
++cSlices;
} else {
if(nullptr != pMissingBin && !bMissingIsolated) {
bExtraMissingCut = true;
++cSlices;
}
}
if(bNominal) {
cSlices = cBins;
}
const ErrorEbm error = Flatten<bHessian>(pBoosterShell,
bExtraMissingCut,
bMissing,
bNominal,
flags,
regAlpha,
Expand Down

0 comments on commit 42875bc

Please sign in to comment.