diff --git a/shared/libebm/compute/BinSumsBoosting.hpp b/shared/libebm/compute/BinSumsBoosting.hpp index 939dd9bcb..fe37a7817 100644 --- a/shared/libebm/compute/BinSumsBoosting.hpp +++ b/shared/libebm/compute/BinSumsBoosting.hpp @@ -153,47 +153,42 @@ GPU_DEVICE NEVER_INLINE static void BinSumsBoostingInternal(BinSumsBoostingBridg pWeight += TFloat::k_cSIMDPack; } - // TODO: we probably want a templated version of this function for Bins with only 1 cScore so that - // we can pre-fetch the weight, count, gradient and hessian before writing them - size_t iScore = 0; do { + TFloat gradient = TFloat::Load(&pGradientAndHessian[iScore << (TFloat::k_cSIMDShift + 1)]); + TFloat hessian; if(bHessian) { - TFloat gradient = TFloat::Load(&pGradientAndHessian[iScore << (TFloat::k_cSIMDShift + 1)]); - TFloat hessian = - TFloat::Load(&pGradientAndHessian[(iScore << (TFloat::k_cSIMDShift + 1)) + TFloat::k_cSIMDPack]); - if(bWeight) { - gradient *= weight; + hessian = TFloat::Load(&pGradientAndHessian[(iScore << (TFloat::k_cSIMDShift + 1)) + TFloat::k_cSIMDPack]); + } + if(bWeight) { + gradient *= weight; + if(bHessian) { hessian *= weight; } - TFloat::Execute( - [aBins, iScore](int, const typename TFloat::T grad, const typename TFloat::T hess) { - auto* const pBin = aBins; - auto* const aGradientPair = pBin->GetGradientPairs(); - auto* const pGradientPair = &aGradientPair[iScore]; - typename TFloat::T binGrad = pGradientPair->m_sumGradients; - typename TFloat::T binHess = pGradientPair->GetHess(); - binGrad += grad; - binHess += hess; - pGradientPair->m_sumGradients = binGrad; - pGradientPair->SetHess(binHess); - }, - gradient, - hessian); - } else { - TFloat gradient = TFloat::Load(&pGradientAndHessian[iScore << TFloat::k_cSIMDShift]); - if(bWeight) { - gradient *= weight; - } - TFloat::Execute( - [aBins, iScore](int, const typename TFloat::T grad) { - auto* const pBin = aBins; - auto* const aGradientPair = pBin->GetGradientPairs(); - auto* const pGradientPair = &aGradientPair[iScore]; - pGradientPair->m_sumGradients += grad; - }, - gradient); } + + const typename TFloat::T gradientSum = Sum(gradient); + typename TFloat::T hessianSum; + if(bHessian) { + hessianSum = Sum(hessian); + } + + auto* const aGradientPair = aBins->GetGradientPairs(); + auto* const pGradientPair = &aGradientPair[iScore]; + typename TFloat::T binGrad = pGradientPair->m_sumGradients; + typename TFloat::T binHess; + if(bHessian) { + binHess = pGradientPair->GetHess(); + } + binGrad += gradientSum; + if(bHessian) { + binHess += hessianSum; + } + pGradientPair->m_sumGradients = binGrad; + if(bHessian) { + pGradientPair->SetHess(binHess); + } + ++iScore; } while(cScores != iScore);