Skip to content

Commit

Permalink
reorganize tree building code for future options
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Dec 24, 2024
1 parent d9e1032 commit f246e66
Showing 1 changed file with 44 additions and 26 deletions.
70 changes: 44 additions & 26 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
auto* const pRootTreeNodeDebug = pBoosterShell->GetTreeNodesTemp<bHessian>();
size_t cSamplesExpectedDebug = static_cast<size_t>(pRootTreeNodeDebug->GetBin()->GetCountSamples());
size_t cSamplesTotalDebug = 0;
bool bLookingForMissing = nullptr != pMissingValueTreeNode;
#endif // NDEBUG

const Bin<FloatMain, UIntMain, true, true, bHessian>* pMissingBin = nullptr;

Tensor* const pInnerTermUpdate = pBoosterShell->GetInnerTermUpdate();

error = pInnerTermUpdate->SetCountSlices(iDimension, cSlices);
Expand All @@ -162,7 +163,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
UIntSplit* pSplit = pInnerTermUpdate->GetSplitPointer(iDimension);

FloatScore* const aUpdateScore = pInnerTermUpdate->GetTensorScoresPointer();
FloatScore* pMissingUpdateScore = aUpdateScore;
FloatScore* pUpdateScore;

const Bin<FloatMain, UIntMain, true, true, bHessian>* const* ppBinCur = nullptr;
Expand Down Expand Up @@ -205,7 +205,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
pParent = pTreeNode;
pTreeNode = pLeftChild;
} else {
bool bMissingNode = false;
const Bin<FloatMain, UIntMain, true, true, bHessian>* const* ppBinLast;
// if the pointer points to the space within the bins, then the TreeNode could not be split
// and this TreeNode never had children and we never wrote a pointer to the children in this memory
Expand All @@ -217,7 +216,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
IndexTreeNode(pRootTreeNode, pBoosterCore->GetCountBytesTreeNodes() - cBytesPerTreeNode));

if(pMissingValueTreeNode == GetLeftNode(pChildren)) {
bMissingNode = true;
EBM_ASSERT(nullptr == pMissingBin);
pMissingBin = pTreeNode->GetBin();
}

// the node was examined and a gain calculated, so it has left and right children.
Expand All @@ -226,12 +226,14 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
ppBinLast = pRightChild->BEFORE_GetBinLast();

if(pMissingValueTreeNode == pRightChild) {
bMissingNode = true;
EBM_ASSERT(nullptr == pMissingBin);
pMissingBin = pTreeNode->GetBin();
}
} else {
ppBinLast = pTreeNode->BEFORE_GetBinLast();
if(pMissingValueTreeNode == pTreeNode) {
bMissingNode = true;
EBM_ASSERT(nullptr == pMissingBin);
pMissingBin = pTreeNode->GetBin();
}
}

Expand All @@ -240,10 +242,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,

#ifndef NDEBUG
cSamplesTotalDebug += static_cast<size_t>(pTreeNode->GetBin()->GetCountSamples());
if(bMissingNode) {
EBM_ASSERT(bLookingForMissing);
bLookingForMissing = false;
}
#endif // NDEBUG

size_t iEdge;
Expand Down Expand Up @@ -282,11 +280,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
*pUpdateScore = static_cast<FloatScore>(updateScore);
++pUpdateScore;

if(bMissingNode) {
*pMissingUpdateScore = static_cast<FloatScore>(updateScore);
++pMissingUpdateScore;
}

++iScore;
} while(cScores != iScore);
if(nullptr == ppBinCur) {
Expand All @@ -297,7 +290,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
if(ppBinLast < ppBinCur) {
break;
}
bMissingNode = false;
determine_bin:;
const auto* const pBinCur = *ppBinCur;
const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin);
Expand All @@ -310,8 +302,28 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
if(nullptr == pTreeNode) {
EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug);

EBM_ASSERT(!bLookingForMissing);
EBM_ASSERT(nullptr == pMissingValueTreeNode || pMissingUpdateScore == aUpdateScore + cScores);
EBM_ASSERT(nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
if(nullptr != pMissingBin) {
FloatScore hess = static_cast<FloatCalc>(pMissingBin->GetWeight());
const auto* pGradientPair = pMissingBin->GetGradientPairs();
const auto* const pGradientPairEnd = pGradientPair + cScores;
FloatScore* pMissingUpdateScore = aUpdateScore;
do {
if(bUpdateWithHessian) {
hess = static_cast<FloatCalc>(pGradientPair->GetHess());
}
FloatCalc updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(pGradientPair->m_sumGradients),
hess,
regAlpha,
regLambda,
deltaStepMax);

*pMissingUpdateScore = updateScore;
++pMissingUpdateScore;

++pGradientPair;
} while(pGradientPairEnd != pGradientPair);
}

LOG_0(Trace_Verbose, "Exited Flatten");
return Error_None;
Expand Down Expand Up @@ -798,9 +810,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
bUnseen = false;
}

// Disable missing if there are only 2 bins, because we'll end up just combining the bins always then.
bMissing = bMissing && (0 == (TermBoostFlags_MissingLow & flags));

auto* const aBins =
pBoosterShell->GetBoostingMainBins()
->Specialize<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>();
Expand All @@ -815,11 +824,20 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo

size_t cBinsAdjusted = cBins;
const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
if(bMissing) {
pMissingValueTreeNode = pRootTreeNode;
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous region.
pBin = IndexBin(pBin, cBytesPerBin);
--cBinsAdjusted;
if(TermBoostFlags_MissingLow & flags) {
if(bMissing) {
*ppBin = pBin;
pBin = IndexBin(pBin, cBytesPerBin);
++ppBin;
}
} else {
if(bMissing) {
pMissingValueTreeNode = pRootTreeNode;
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
// region.
pBin = IndexBin(pBin, cBytesPerBin);
--cBinsAdjusted;
}
}

const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** ppBinsEnd =
Expand Down

0 comments on commit f246e66

Please sign in to comment.