diff --git a/shared/libebm/PartitionOneDimensionalBoosting.cpp b/shared/libebm/PartitionOneDimensionalBoosting.cpp index 770bb98a4..f1ceb8840 100644 --- a/shared/libebm/PartitionOneDimensionalBoosting.cpp +++ b/shared/libebm/PartitionOneDimensionalBoosting.cpp @@ -108,7 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER // do not inline this. Not inlining it makes fewer versions that can be called from the more templated functions template static ErrorEbm Flatten(BoosterShell* const pBoosterShell, - bool bExtraMissingCut, + const bool bMissing, const bool bNominal, const TermBoostFlags flags, const FloatCalc regAlpha, @@ -132,7 +132,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, EBM_ASSERT(2 <= cBins); EBM_ASSERT(cSlices <= cBins); EBM_ASSERT(!bNominal || cSlices == cBins); - EBM_ASSERT(!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere ErrorEbm error; @@ -178,7 +177,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, } else { pUpdateScore = aUpdateScore; - if(nullptr != pMissingValueTreeNode || bExtraMissingCut) { + if(bMissing) { // always put a split on the missing bin *pSplit = 1; ++pSplit; @@ -239,18 +238,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, } } - if(bExtraMissingCut) { - EBM_ASSERT(!bNominal); // for Nominal we cut everywhere - if(TermBoostFlags_MissingLow & flags) { - if(nullptr == pMissingBin) { - pMissingBin = pTreeNode->GetBin(); - } - } else { - EBM_ASSERT(TermBoostFlags_MissingHigh & flags); - pMissingBin = pTreeNode->GetBin(); - } - } - EBM_ASSERT(apBins <= ppBinLast); EBM_ASSERT(ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{1} : size_t{0}))); @@ -266,6 +253,13 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, } EBM_ASSERT(!bNominal); + // for Nominal we cut everywhere + if(TermBoostFlags_MissingLow & flags) { + if(nullptr == pMissingBin) { + pMissingBin = pTreeNode->GetBin(); + } + } + // if !bNominal, check the bin above and below for order EBM_ASSERT(apBins == ppBinLast || *(ppBinLast - 1) < *ppBinLast); EBM_ASSERT(ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{2} : size_t{1})) || @@ -273,41 +267,43 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0); - while(true) { - iScore = 0; - do { - FloatCalc updateScore; - if(bUpdateWithHessian) { - updateScore = -CalcNegUpdate(static_cast(aGradientPair[iScore].m_sumGradients), - static_cast(aGradientPair[iScore].GetHess()), - regAlpha, - regLambda, - deltaStepMax); - } else { - updateScore = -CalcNegUpdate(static_cast(aGradientPair[iScore].m_sumGradients), - static_cast(pTreeNode->GetBin()->GetWeight()), - regAlpha, - regLambda, - deltaStepMax); - } + if(!bMissing || !(TermBoostFlags_MissingLow & flags) || ((TermBoostFlags_MissingLow & flags) && 1 != iEdge)) { + while(true) { + iScore = 0; + do { + FloatCalc updateScore; + if(bUpdateWithHessian) { + updateScore = -CalcNegUpdate(static_cast(aGradientPair[iScore].m_sumGradients), + static_cast(aGradientPair[iScore].GetHess()), + regAlpha, + regLambda, + deltaStepMax); + } else { + updateScore = -CalcNegUpdate(static_cast(aGradientPair[iScore].m_sumGradients), + static_cast(pTreeNode->GetBin()->GetWeight()), + regAlpha, + regLambda, + deltaStepMax); + } - *pUpdateScore = static_cast(updateScore); - ++pUpdateScore; + *pUpdateScore = static_cast(updateScore); + ++pUpdateScore; - ++iScore; - } while(cScores != iScore); - if(nullptr == ppBinCur) { - break; - } - EBM_ASSERT(bNominal); - ++ppBinCur; - if(ppBinLast < ppBinCur) { - break; + ++iScore; + } while(cScores != iScore); + if(nullptr == ppBinCur) { + break; + } + EBM_ASSERT(bNominal); + ++ppBinCur; + if(ppBinLast < ppBinCur) { + break; + } + determine_bin:; + const auto* const pBinCur = *ppBinCur; + const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin); + pUpdateScore = aUpdateScore + iBin * cScores; } - determine_bin:; - const auto* const pBinCur = *ppBinCur; - const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin); - pUpdateScore = aUpdateScore + iBin * cScores; } pTreeNode = pParent; @@ -345,9 +341,12 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, if(!pTreeNode->DECONSTRUCT_IsRightChildTraversal()) { // we checked earlier that countBins could be converted to a UIntSplit if(nullptr == ppBinCur) { - EBM_ASSERT(!IsConvertError(iEdge)); - *pSplit = static_cast(iEdge); - ++pSplit; + if(!bMissing || !(TermBoostFlags_MissingLow & flags) || + ((TermBoostFlags_MissingLow & flags) && 1 != iEdge)) { + EBM_ASSERT(!IsConvertError(iEdge)); + *pSplit = static_cast(iEdge); + ++pSplit; + } } pParent = pTreeNode; pTreeNode = pTreeNode->DECONSTRUCT_TraverseRightAndMark(cBytesPerTreeNode); @@ -1035,13 +1034,11 @@ template class PartitionOneDimensionalBoo *pTotalGain = static_cast(totalGain); size_t cSlices = cSplitsMax - cSplitsRemaining + 1; - bool bExtraMissingCut = false; if(nullptr != pMissingValueTreeNode) { EBM_ASSERT(nullptr == pMissingBin); ++cSlices; } else { if(nullptr != pMissingBin && !bMissingIsolated) { - bExtraMissingCut = true; ++cSlices; } } @@ -1049,7 +1046,7 @@ template class PartitionOneDimensionalBoo cSlices = cBins; } const ErrorEbm error = Flatten(pBoosterShell, - bExtraMissingCut, + bMissing, bNominal, flags, regAlpha,