diff --git a/shared/libebm/PartitionOneDimensionalBoosting.cpp b/shared/libebm/PartitionOneDimensionalBoosting.cpp index 56ca79c3c..3143b7766 100644 --- a/shared/libebm/PartitionOneDimensionalBoosting.cpp +++ b/shared/libebm/PartitionOneDimensionalBoosting.cpp @@ -138,9 +138,10 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, auto* const pRootTreeNodeDebug = pBoosterShell->GetTreeNodesTemp(); size_t cSamplesExpectedDebug = static_cast(pRootTreeNodeDebug->GetBin()->GetCountSamples()); size_t cSamplesTotalDebug = 0; - bool bLookingForMissing = nullptr != pMissingValueTreeNode; #endif // NDEBUG + const Bin* pMissingBin = nullptr; + Tensor* const pInnerTermUpdate = pBoosterShell->GetInnerTermUpdate(); error = pInnerTermUpdate->SetCountSlices(iDimension, cSlices); @@ -162,7 +163,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, UIntSplit* pSplit = pInnerTermUpdate->GetSplitPointer(iDimension); FloatScore* const aUpdateScore = pInnerTermUpdate->GetTensorScoresPointer(); - FloatScore* pMissingUpdateScore = aUpdateScore; FloatScore* pUpdateScore; const Bin* const* ppBinCur = nullptr; @@ -205,7 +205,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, pParent = pTreeNode; pTreeNode = pLeftChild; } else { - bool bMissingNode = false; const Bin* const* ppBinLast; // if the pointer points to the space within the bins, then the TreeNode could not be split // and this TreeNode never had children and we never wrote a pointer to the children in this memory @@ -217,7 +216,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, IndexTreeNode(pRootTreeNode, pBoosterCore->GetCountBytesTreeNodes() - cBytesPerTreeNode)); if(pMissingValueTreeNode == GetLeftNode(pChildren)) { - bMissingNode = true; + EBM_ASSERT(nullptr == pMissingBin); + pMissingBin = pTreeNode->GetBin(); } // the node was examined and a gain calculated, so it has left and right children. @@ -226,12 +226,14 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, ppBinLast = pRightChild->BEFORE_GetBinLast(); if(pMissingValueTreeNode == pRightChild) { - bMissingNode = true; + EBM_ASSERT(nullptr == pMissingBin); + pMissingBin = pTreeNode->GetBin(); } } else { ppBinLast = pTreeNode->BEFORE_GetBinLast(); if(pMissingValueTreeNode == pTreeNode) { - bMissingNode = true; + EBM_ASSERT(nullptr == pMissingBin); + pMissingBin = pTreeNode->GetBin(); } } @@ -240,10 +242,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, #ifndef NDEBUG cSamplesTotalDebug += static_cast(pTreeNode->GetBin()->GetCountSamples()); - if(bMissingNode) { - EBM_ASSERT(bLookingForMissing); - bLookingForMissing = false; - } #endif // NDEBUG size_t iEdge; @@ -282,11 +280,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, *pUpdateScore = static_cast(updateScore); ++pUpdateScore; - if(bMissingNode) { - *pMissingUpdateScore = static_cast(updateScore); - ++pMissingUpdateScore; - } - ++iScore; } while(cScores != iScore); if(nullptr == ppBinCur) { @@ -297,7 +290,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, if(ppBinLast < ppBinCur) { break; } - bMissingNode = false; determine_bin:; const auto* const pBinCur = *ppBinCur; const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin); @@ -310,8 +302,28 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, if(nullptr == pTreeNode) { EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug); - EBM_ASSERT(!bLookingForMissing); - EBM_ASSERT(nullptr == pMissingValueTreeNode || pMissingUpdateScore == aUpdateScore + cScores); + EBM_ASSERT(nullptr == pMissingValueTreeNode || nullptr != pMissingBin); + if(nullptr != pMissingBin) { + FloatScore hess = static_cast(pMissingBin->GetWeight()); + const auto* pGradientPair = pMissingBin->GetGradientPairs(); + const auto* const pGradientPairEnd = pGradientPair + cScores; + FloatScore* pMissingUpdateScore = aUpdateScore; + do { + if(bUpdateWithHessian) { + hess = static_cast(pGradientPair->GetHess()); + } + FloatCalc updateScore = -CalcNegUpdate(static_cast(pGradientPair->m_sumGradients), + hess, + regAlpha, + regLambda, + deltaStepMax); + + *pMissingUpdateScore = updateScore; + ++pMissingUpdateScore; + + ++pGradientPair; + } while(pGradientPairEnd != pGradientPair); + } LOG_0(Trace_Verbose, "Exited Flatten"); return Error_None; @@ -798,9 +810,6 @@ template class PartitionOneDimensionalBoo bUnseen = false; } - // Disable missing if there are only 2 bins, because we'll end up just combining the bins always then. - bMissing = bMissing && (0 == (TermBoostFlags_MissingLow & flags)); - auto* const aBins = pBoosterShell->GetBoostingMainBins() ->Specialize(); @@ -815,11 +824,20 @@ template class PartitionOneDimensionalBoo size_t cBinsAdjusted = cBins; const TreeNode* pMissingValueTreeNode = nullptr; - if(bMissing) { - pMissingValueTreeNode = pRootTreeNode; - // Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous region. - pBin = IndexBin(pBin, cBytesPerBin); - --cBinsAdjusted; + if(TermBoostFlags_MissingLow & flags) { + if(bMissing) { + *ppBin = pBin; + pBin = IndexBin(pBin, cBytesPerBin); + ++ppBin; + } + } else { + if(bMissing) { + pMissingValueTreeNode = pRootTreeNode; + // Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous + // region. + pBin = IndexBin(pBin, cBytesPerBin); + --cBinsAdjusted; + } } const Bin** ppBinsEnd =