Skip to content

Commit

Permalink
optimize the multiclass collapsed specialization of the BinSumsBoosti…
Browse files Browse the repository at this point in the history
…ng function
  • Loading branch information
paulbkoch committed Apr 1, 2024
1 parent 6d13fb2 commit e218796
Showing 1 changed file with 30 additions and 35 deletions.
65 changes: 30 additions & 35 deletions shared/libebm/compute/BinSumsBoosting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,47 +153,42 @@ GPU_DEVICE NEVER_INLINE static void BinSumsBoostingInternal(BinSumsBoostingBridg
pWeight += TFloat::k_cSIMDPack;
}

// TODO: we probably want a templated version of this function for Bins with only 1 cScore so that
// we can pre-fetch the weight, count, gradient and hessian before writing them

size_t iScore = 0;
do {
TFloat gradient = TFloat::Load(&pGradientAndHessian[iScore << (TFloat::k_cSIMDShift + 1)]);
TFloat hessian;
if(bHessian) {
TFloat gradient = TFloat::Load(&pGradientAndHessian[iScore << (TFloat::k_cSIMDShift + 1)]);
TFloat hessian =
TFloat::Load(&pGradientAndHessian[(iScore << (TFloat::k_cSIMDShift + 1)) + TFloat::k_cSIMDPack]);
if(bWeight) {
gradient *= weight;
hessian = TFloat::Load(&pGradientAndHessian[(iScore << (TFloat::k_cSIMDShift + 1)) + TFloat::k_cSIMDPack]);
}
if(bWeight) {
gradient *= weight;
if(bHessian) {
hessian *= weight;
}
TFloat::Execute(
[aBins, iScore](int, const typename TFloat::T grad, const typename TFloat::T hess) {
auto* const pBin = aBins;
auto* const aGradientPair = pBin->GetGradientPairs();
auto* const pGradientPair = &aGradientPair[iScore];
typename TFloat::T binGrad = pGradientPair->m_sumGradients;
typename TFloat::T binHess = pGradientPair->GetHess();
binGrad += grad;
binHess += hess;
pGradientPair->m_sumGradients = binGrad;
pGradientPair->SetHess(binHess);
},
gradient,
hessian);
} else {
TFloat gradient = TFloat::Load(&pGradientAndHessian[iScore << TFloat::k_cSIMDShift]);
if(bWeight) {
gradient *= weight;
}
TFloat::Execute(
[aBins, iScore](int, const typename TFloat::T grad) {
auto* const pBin = aBins;
auto* const aGradientPair = pBin->GetGradientPairs();
auto* const pGradientPair = &aGradientPair[iScore];
pGradientPair->m_sumGradients += grad;
},
gradient);
}

const typename TFloat::T gradientSum = Sum(gradient);
typename TFloat::T hessianSum;
if(bHessian) {
hessianSum = Sum(hessian);
}

auto* const aGradientPair = aBins->GetGradientPairs();
auto* const pGradientPair = &aGradientPair[iScore];
typename TFloat::T binGrad = pGradientPair->m_sumGradients;
typename TFloat::T binHess;
if(bHessian) {
binHess = pGradientPair->GetHess();
}
binGrad += gradientSum;
if(bHessian) {
binHess += hessianSum;
}
pGradientPair->m_sumGradients = binGrad;
if(bHessian) {
pGradientPair->SetHess(binHess);
}

++iScore;
} while(cScores != iScore);

Expand Down

0 comments on commit e218796

Please sign in to comment.