Skip to content

Commit

Permalink
simplify and improve some parts of the tree building
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Dec 24, 2024
1 parent 7cea7a3 commit ebf134d
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions shared/libebm/PartitionOneDimensionalBoosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,10 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
auto* const aBins =
pBoosterShell->GetBoostingMainBins()
->Specialize<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>();
auto* const pBinsEnd = IndexBin(aBins, cBytesPerBin * cBins);
auto* pBinsEnd = IndexBin(aBins, cBytesPerBin * cBins);

SumAllBins<bHessian, cCompilerScores>(
pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin());

const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** const apBins =
reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>**>(
Expand All @@ -844,7 +847,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>* pMissingBin = nullptr;
bool bMissingIsolated = false;

size_t cBinsAdjusted = cBins;
const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
if(TermBoostFlags_MissingLow & flags) {
if(bMissing) {
Expand All @@ -861,32 +863,25 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
// region.
pBin = IndexBin(pBin, cBytesPerBin);
--cBinsAdjusted;
}
}

const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** ppBinsEnd =
apBins + cBinsAdjusted;

do {
*ppBin = pBin;
pBin = IndexBin(pBin, cBytesPerBin);
++ppBin;
} while(ppBinsEnd != ppBin);
} while(pBinsEnd != pBin);

if(bNominal) {
std::sort(apBins,
ppBinsEnd,
ppBin,
CompareBin<bHessian, cCompilerScores>(
!(TermBoostFlags_DisableNewtonUpdate & flags), categoricalSmoothing));
}

pRootTreeNode->BEFORE_SetBinFirst(apBins);
pRootTreeNode->BEFORE_SetBinLast(ppBinsEnd - 1);
ASSERT_BIN_OK(cBytesPerBin, *(ppBinsEnd - 1), pBoosterShell->GetDebugMainBinsEnd());

SumAllBins<bHessian, cCompilerScores>(
pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin());
pRootTreeNode->BEFORE_SetBinLast(ppBin - 1);
ASSERT_BIN_OK(cBytesPerBin, *(ppBin - 1), pBoosterShell->GetDebugMainBinsEnd());

EBM_ASSERT(!IsOverflowTreeNodeSize(bHessian, cScores));
const size_t cBytesPerTreeNode = GetTreeNodeSize(bHessian, cScores);
Expand Down

0 comments on commit ebf134d

Please sign in to comment.