diff --git a/shared/libebm/PartitionOneDimensionalBoosting.cpp b/shared/libebm/PartitionOneDimensionalBoosting.cpp index 3143b7766..a2d4a363c 100644 --- a/shared/libebm/PartitionOneDimensionalBoosting.cpp +++ b/shared/libebm/PartitionOneDimensionalBoosting.cpp @@ -108,6 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER // do not inline this. Not inlining it makes fewer versions that can be called from the more templated functions template static ErrorEbm Flatten(BoosterShell* const pBoosterShell, + bool bExtraMissingCut, const bool bNominal, const TermBoostFlags flags, const FloatCalc regAlpha, @@ -131,6 +132,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, EBM_ASSERT(2 <= cBins); EBM_ASSERT(cSlices <= cBins); EBM_ASSERT(!bNominal || cSlices == cBins); + EBM_ASSERT(!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere ErrorEbm error; @@ -176,7 +178,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, } else { pUpdateScore = aUpdateScore; - if(nullptr != pMissingValueTreeNode) { + if(nullptr != pMissingValueTreeNode || bExtraMissingCut) { // always put a split on the missing bin *pSplit = 1; ++pSplit; @@ -237,6 +239,18 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell, } } + if(bExtraMissingCut) { + EBM_ASSERT(!bNominal); // for Nominal we cut everywhere + if(TermBoostFlags_MissingLow & flags) { + if(nullptr == pMissingBin) { + pMissingBin = pTreeNode->GetBin(); + } + } else { + EBM_ASSERT(TermBoostFlags_MissingHigh & flags); + pMissingBin = pTreeNode->GetBin(); + } + } + EBM_ASSERT(apBins <= ppBinLast); EBM_ASSERT(ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{1} : size_t{0}))); @@ -365,6 +379,8 @@ static int FindBestSplitGain(RandomDeterministic* const pRng, const FloatCalc regLambda, const FloatCalc deltaStepMax, const MonotoneDirection monotoneDirection, + const Bin* const pMissingBin, + bool* pbMissingIsolated, const TreeNode** const ppMissingValueTreeNode) { LOG_N(Trace_Verbose, @@ -401,6 +417,9 @@ static int FindBestSplitGain(RandomDeterministic* const pRng, if(ppBinCur == ppBinLast) { // There is just one bin and therefore no splits pTreeNode->AFTER_RejectSplit(); + if(pMissingBin == *ppBinCur) { + *pbMissingIsolated = true; + } return 1; } @@ -822,10 +841,16 @@ template class PartitionOneDimensionalBoo const Bin** ppBin = apBins; const Bin* pBin = aBins; + const Bin* pMissingBin = nullptr; + bool bMissingIsolated = false; + size_t cBinsAdjusted = cBins; const TreeNode* pMissingValueTreeNode = nullptr; if(TermBoostFlags_MissingLow & flags) { if(bMissing) { + if(!bNominal) { + pMissingBin = pBin; + } *ppBin = pBin; pBin = IndexBin(pBin, cBytesPerBin); ++ppBin; @@ -879,6 +904,8 @@ template class PartitionOneDimensionalBoo regLambda, deltaStepMax, monotoneDirection, + pMissingBin, + &bMissingIsolated, &pMissingValueTreeNode); size_t cSplitsRemaining = cSplitsMax; FloatCalc totalGain = 0; @@ -952,6 +979,8 @@ template class PartitionOneDimensionalBoo regLambda, deltaStepMax, monotoneDirection, + pMissingBin, + &bMissingIsolated, &pMissingValueTreeNode); // if FindBestSplitGain returned -1 to indicate an // overflow ignore it here. We successfully made a root node split, so we might as well continue @@ -976,6 +1005,8 @@ template class PartitionOneDimensionalBoo regLambda, deltaStepMax, monotoneDirection, + pMissingBin, + &bMissingIsolated, &pMissingValueTreeNode); // if FindBestSplitGain returned -1 to indicate an // overflow ignore it here. We successfully made a root node split, so we might as well continue @@ -1007,9 +1038,23 @@ template class PartitionOneDimensionalBoo } } *pTotalGain = static_cast(totalGain); - size_t cSlices = - bNominal ? cBins : cSplitsMax - cSplitsRemaining + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0); - return Flatten(pBoosterShell, + + size_t cSlices = cSplitsMax - cSplitsRemaining + 1; + bool bExtraMissingCut = false; + if(nullptr != pMissingValueTreeNode) { + EBM_ASSERT(nullptr == pMissingBin); + ++cSlices; + } else { + if(nullptr != pMissingBin && !bMissingIsolated) { + bExtraMissingCut = true; + ++cSlices; + } + } + if(bNominal) { + cSlices = cBins; + } + const ErrorEbm error = Flatten(pBoosterShell, + bExtraMissingCut, bNominal, flags, regAlpha, @@ -1024,6 +1069,11 @@ template class PartitionOneDimensionalBoo cBins #endif // NDEBUG ); + + EBM_ASSERT(!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate()->GetCountSlices(iDimension)); + EBM_ASSERT(!bMissing || *pBoosterShell->GetInnerTermUpdate()->GetSplitPointer(iDimension) == 1); + + return error; } };