From 95c7edc7a5e31255c35126aec4829f782966b155 Mon Sep 17 00:00:00 2001 From: kzrnm Date: Sat, 3 Feb 2024 05:58:18 +0900 Subject: [PATCH] Simplify BigInteger.Multiply --- .../Numerics/BigIntegerCalculator.SquMul.cs | 332 ++++++------------ .../tests/BigInteger/multiply.cs | 8 +- .../tests/BigInteger/op_multiply.cs | 4 +- 3 files changed, 106 insertions(+), 238 deletions(-) diff --git a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs index a31153ceec8348..54d0bc9f69f6ae 100644 --- a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs +++ b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.SquMul.cs @@ -155,29 +155,13 @@ internal static #else internal const #endif - int MultiplyThreshold = 32; - + int MultiplyKaratsubaThreshold = 32; public static void Multiply(ReadOnlySpan left, ReadOnlySpan right, Span bits) { Debug.Assert(left.Length >= right.Length); Debug.Assert(bits.Length >= left.Length + right.Length); - Debug.Assert(!bits.ContainsAnyExcept(0u)); - - if (left.Length - right.Length < 3) - { - MultiplyNearLength(left, right, bits); - } - else - { - MultiplyFarLength(left, right, bits); - } - } - - private static void MultiplyFarLength(ReadOnlySpan left, ReadOnlySpan right, Span bits) - { - Debug.Assert(left.Length - right.Length >= 3); - Debug.Assert(bits.Length >= left.Length + right.Length); - Debug.Assert(!bits.ContainsAnyExcept(0u)); + Debug.Assert(bits.Trim(0u).IsEmpty); + Debug.Assert(MultiplyKaratsubaThreshold >= 2); // Executes different algorithms for computing z = a * b // based on the actual length of b. If b is "small" enough @@ -188,120 +172,91 @@ private static void MultiplyFarLength(ReadOnlySpan left, ReadOnlySpan> 32; - } - Unsafe.Add(ref resultPtr, i + left.Length) = (uint)carry; - } - } - else - { - // Based on the Toom-Cook multiplication we split left/right - // into two smaller values, doing recursive multiplication. - // The special form of this multiplication, where we - // split both operands into two operands, is also known - // as the Karatsuba algorithm... + // Result + // z0= | | | a0 * b0 | + // z1= | | a1 * b0 + a0 * b1 | | + // z2= | a1 * b1 | | | - // https://en.wikipedia.org/wiki/Toom-Cook_multiplication - // https://en.wikipedia.org/wiki/Karatsuba_algorithm + // z1 = a1 * b0 + a0 * b1 + // = (a0 + a1) * (b0 + b1) - a0 * b0 - a1 * b1 + // = (a0 + a1) * (b0 + b1) - z0 - z2 - // Say we want to compute z = a * b ... - // ... we need to determine our new length (just the half) - int n = left.Length >> 1; - if (right.Length <= n + 1) - { - // ... split left like a = (a_1 << n) + a_0 - ReadOnlySpan leftLow = left.Slice(0, n); - ReadOnlySpan leftHigh = left.Slice(n); + // Based on the Toom-Cook multiplication we split left/right + // into two smaller values, doing recursive multiplication. + // The special form of this multiplication, where we + // split both operands into two operands, is also known + // as the Karatsuba algorithm... - // ... split right like b = (b_1 << n) + b_0 - ReadOnlySpan rightLow; - uint rightHigh; - if (n < right.Length) - { - Debug.Assert(right.Length == n + 1); - rightLow = right.Slice(0, n); - rightHigh = right[n]; - } - else - { - rightLow = right; - rightHigh = 0; - } + // https://en.wikipedia.org/wiki/Toom-Cook_multiplication + // https://en.wikipedia.org/wiki/Karatsuba_algorithm - // ... prepare our result array (to reuse its memory) - Span bitsLow = bits.Slice(0, n + rightLow.Length); - Span bitsHigh = bits.Slice(n); + // Say we want to compute z = a * b ... - int carryLength = rightLow.Length; - uint[]? carryFromPool = null; - Span carry = ((uint)carryLength <= StackAllocThreshold ? - stackalloc uint[StackAllocThreshold] - : carryFromPool = ArrayPool.Shared.Rent(carryLength)).Slice(0, carryLength); + // ... we need to determine our new length (just the half) + int n = (left.Length + 1) >> 1; - // ... compute low - Multiply(leftLow, rightLow, bitsLow); - Span carryOrig = bits.Slice(n, rightLow.Length); - carryOrig.CopyTo(carry); - carryOrig.Clear(); + if (right.Length <= n) + { + // ... split left like a = (a_1 << n) + a_0 + ReadOnlySpan leftLow = left.Slice(0, n); + ReadOnlySpan leftHigh = left.Slice(n); + Debug.Assert(leftLow.Length >= leftHigh.Length); - if (rightHigh != 0) - { - // ... compute high - MultiplyNearLength(leftHigh, rightLow, bitsHigh.Slice(0, leftHigh.Length + n)); + // ... prepare our result array (to reuse its memory) + Span bitsLow = bits.Slice(0, n + right.Length); + Span bitsHigh = bits.Slice(n); - int upperRightLength = left.Length + 1; - uint[]? upperRightFromPool = null; - Span upperRight = ((uint)upperRightLength <= StackAllocThreshold ? - stackalloc uint[StackAllocThreshold] - : upperRightFromPool = ArrayPool.Shared.Rent(upperRightLength)).Slice(0, upperRightLength); - upperRight.Clear(); + // ... compute low + Multiply(leftLow, right, bitsLow); - Multiply(left, rightHigh, upperRight); + int carryLength = right.Length; + uint[]? carryFromPool = null; + Span carry = ((uint)carryLength <= StackAllocThreshold ? + stackalloc uint[StackAllocThreshold] + : carryFromPool = ArrayPool.Shared.Rent(carryLength)).Slice(0, carryLength); - AddSelf(bitsHigh, upperRight); + Span carryOrig = bits.Slice(n, right.Length); + carryOrig.CopyTo(carry); + carryOrig.Clear(); - if (upperRightFromPool != null) - ArrayPool.Shared.Return(upperRightFromPool); - } - else - { - // ... compute high - Multiply(leftHigh, rightLow, bitsHigh); - } + // ... compute high + if (leftHigh.Length < right.Length) + MultiplyKaratsuba(right, leftHigh, bitsHigh.Slice(0, leftHigh.Length + right.Length), (right.Length + 1) >> 1); + else + Multiply(leftHigh, right, bitsHigh.Slice(0, leftHigh.Length + right.Length)); + + AddSelf(bitsHigh, carry); - AddSelf(bitsHigh, carry); + if (carryFromPool != null) + ArrayPool.Shared.Return(carryFromPool); + } + else + MultiplyKaratsuba(left, right, bits, n); - if (carryFromPool != null) - ArrayPool.Shared.Return(carryFromPool); + static void MultiplyKaratsuba(ReadOnlySpan left, ReadOnlySpan right, Span bits, int n) + { + Debug.Assert(left.Length >= right.Length); + Debug.Assert(2 * n - left.Length is 0 or 1); + Debug.Assert(right.Length > n); + Debug.Assert(bits.Length >= left.Length + right.Length); + + if (right.Length < MultiplyKaratsubaThreshold) + { + Naive(left, right, bits); } else { - int n2 = n << 1; - - Debug.Assert(left.Length > right.Length); - // ... split left like a = (a_1 << n) + a_0 ReadOnlySpan leftLow = left.Slice(0, n); ReadOnlySpan leftHigh = left.Slice(n); @@ -311,47 +266,47 @@ stackalloc uint[StackAllocThreshold] ReadOnlySpan rightHigh = right.Slice(n); // ... prepare our result array (to reuse its memory) - Span bitsLow = bits.Slice(0, n2); - Span bitsHigh = bits.Slice(n2); + Span bitsLow = bits.Slice(0, n + n); + Span bitsHigh = bits.Slice(n + n); + + Debug.Assert(leftLow.Length >= leftHigh.Length); + Debug.Assert(rightLow.Length >= rightHigh.Length); + Debug.Assert(bitsLow.Length >= bitsHigh.Length); // ... compute z_0 = a_0 * b_0 (multiply again) - MultiplyNearLength(rightLow, leftLow, bitsLow); + MultiplyKaratsuba(leftLow, rightLow, bitsLow, (leftLow.Length + 1) >> 1); // ... compute z_2 = a_1 * b_1 (multiply again) - MultiplyFarLength(leftHigh, rightHigh, bitsHigh); + Multiply(leftHigh, rightHigh, bitsHigh); - int leftFoldLength = leftHigh.Length + 1; + int foldLength = n + 1; uint[]? leftFoldFromPool = null; - Span leftFold = ((uint)leftFoldLength <= StackAllocThreshold ? + Span leftFold = ((uint)foldLength <= StackAllocThreshold ? stackalloc uint[StackAllocThreshold] - : leftFoldFromPool = ArrayPool.Shared.Rent(leftFoldLength)).Slice(0, leftFoldLength); + : leftFoldFromPool = ArrayPool.Shared.Rent(foldLength)).Slice(0, foldLength); leftFold.Clear(); - int rightFoldLength = n + 1; uint[]? rightFoldFromPool = null; - Span rightFold = ((uint)rightFoldLength <= StackAllocThreshold ? + Span rightFold = ((uint)foldLength <= StackAllocThreshold ? stackalloc uint[StackAllocThreshold] - : rightFoldFromPool = ArrayPool.Shared.Rent(rightFoldLength)).Slice(0, rightFoldLength); + : rightFoldFromPool = ArrayPool.Shared.Rent(foldLength)).Slice(0, foldLength); rightFold.Clear(); - int coreLength = leftFoldLength + rightFoldLength; + // ... compute z_a = a_1 + a_0 (call it fold...) + Add(leftLow, leftHigh, leftFold); + + // ... compute z_b = b_1 + b_0 (call it fold...) + Add(rightLow, rightHigh, rightFold); + + int coreLength = foldLength + foldLength; uint[]? coreFromPool = null; Span core = ((uint)coreLength <= StackAllocThreshold ? stackalloc uint[StackAllocThreshold] : coreFromPool = ArrayPool.Shared.Rent(coreLength)).Slice(0, coreLength); core.Clear(); - Debug.Assert(bits.Length - n >= core.Length); - Debug.Assert(rightLow.Length >= rightHigh.Length); - - // ... compute z_a = a_1 + a_0 (call it fold...) - Add(leftHigh, leftLow, leftFold); - - // ... compute z_b = b_1 + b_0 (call it fold...) - Add(rightLow, rightHigh, rightFold); - - // ... compute z_1 = z_a * z_b - z_0 - z_2 - MultiplyNearLength(leftFold, rightFold, core); + // ... compute z_ab = z_a * z_b + MultiplyKaratsuba(leftFold, rightFold, core, (leftFold.Length + 1) >> 1); if (leftFoldFromPool != null) ArrayPool.Shared.Return(leftFoldFromPool); @@ -359,33 +314,24 @@ stackalloc uint[StackAllocThreshold] if (rightFoldFromPool != null) ArrayPool.Shared.Return(rightFoldFromPool); + // ... compute z_1 = z_a * z_b - z_0 - z_2 = a_0 * b_1 + a_1 * b_0 SubtractCore(bitsLow, bitsHigh, core); + Debug.Assert(ActualLength(core) <= left.Length + 1); + // ... and finally merge the result! :-) - AddSelf(bits.Slice(n), core); + AddSelf(bits.Slice(n), core.Slice(0, ActualLength(core))); if (coreFromPool != null) ArrayPool.Shared.Return(coreFromPool); } } - } - private static void MultiplyNearLength(ReadOnlySpan left, ReadOnlySpan right, Span bits) - { - Debug.Assert(left.Length - right.Length < 3); - Debug.Assert(bits.Length >= left.Length + right.Length); - Debug.Assert(!bits.ContainsAnyExcept(0u)); - // Executes different algorithms for computing z = a * b - // based on the actual length of b. If b is "small" enough - // we stick to the classic "grammar-school" method; for the - // rest we switch to implementations with less complexity - // albeit more overhead (which needs to pay off!). - - // NOTE: useful thresholds needs some "empirical" testing, - // which are smaller in DEBUG mode for testing purpose. - - if (right.Length < MultiplyThreshold) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Naive(ReadOnlySpan left, ReadOnlySpan right, Span bits) { + Debug.Assert(right.Length < MultiplyKaratsubaThreshold); + // Switching to managed references helps eliminating // index bounds check... ref uint resultPtr = ref MemoryMarshal.GetReference(bits); @@ -399,96 +345,18 @@ private static void MultiplyNearLength(ReadOnlySpan left, ReadOnlySpan> 32; } Unsafe.Add(ref resultPtr, i + left.Length) = (uint)carry; } } - else - { - // Based on the Toom-Cook multiplication we split left/right - // into two smaller values, doing recursive multiplication. - // The special form of this multiplication, where we - // split both operands into two operands, is also known - // as the Karatsuba algorithm... - - // https://en.wikipedia.org/wiki/Toom-Cook_multiplication - // https://en.wikipedia.org/wiki/Karatsuba_algorithm - - // Say we want to compute z = a * b ... - - // ... we need to determine our new length (just the half) - int n = right.Length >> 1; - int n2 = n << 1; - - // ... split left like a = (a_1 << n) + a_0 - ReadOnlySpan leftLow = left.Slice(0, n); - ReadOnlySpan leftHigh = left.Slice(n); - - // ... split right like b = (b_1 << n) + b_0 - ReadOnlySpan rightLow = right.Slice(0, n); - ReadOnlySpan rightHigh = right.Slice(n); - - // ... prepare our result array (to reuse its memory) - Span bitsLow = bits.Slice(0, n2); - Span bitsHigh = bits.Slice(n2); - - // ... compute z_0 = a_0 * b_0 (multiply again) - MultiplyNearLength(leftLow, rightLow, bitsLow); - - // ... compute z_2 = a_1 * b_1 (multiply again) - MultiplyNearLength(leftHigh, rightHigh, bitsHigh); - - int leftFoldLength = leftHigh.Length + 1; - uint[]? leftFoldFromPool = null; - Span leftFold = ((uint)leftFoldLength <= StackAllocThreshold ? - stackalloc uint[StackAllocThreshold] - : leftFoldFromPool = ArrayPool.Shared.Rent(leftFoldLength)).Slice(0, leftFoldLength); - leftFold.Clear(); - - int rightFoldLength = rightHigh.Length + 1; - uint[]? rightFoldFromPool = null; - Span rightFold = ((uint)rightFoldLength <= StackAllocThreshold ? - stackalloc uint[StackAllocThreshold] - : rightFoldFromPool = ArrayPool.Shared.Rent(rightFoldLength)).Slice(0, rightFoldLength); - rightFold.Clear(); - - int coreLength = leftFoldLength + rightFoldLength; - uint[]? coreFromPool = null; - Span core = ((uint)coreLength <= StackAllocThreshold ? - stackalloc uint[StackAllocThreshold] - : coreFromPool = ArrayPool.Shared.Rent(coreLength)).Slice(0, coreLength); - core.Clear(); - - // ... compute z_a = a_1 + a_0 (call it fold...) - Add(leftHigh, leftLow, leftFold); - - // ... compute z_b = b_1 + b_0 (call it fold...) - Add(rightHigh, rightLow, rightFold); - - // ... compute z_1 = z_a * z_b - z_0 - z_2 - MultiplyNearLength(leftFold, rightFold, core); - - if (leftFoldFromPool != null) - ArrayPool.Shared.Return(leftFoldFromPool); - - if (rightFoldFromPool != null) - ArrayPool.Shared.Return(rightFoldFromPool); - - SubtractCore(bitsHigh, bitsLow, core); - - // ... and finally merge the result! :-) - AddSelf(bits.Slice(n), core); - - if (coreFromPool != null) - ArrayPool.Shared.Return(coreFromPool); - } } private static void SubtractCore(ReadOnlySpan left, ReadOnlySpan right, Span core) diff --git a/src/libraries/System.Runtime.Numerics/tests/BigInteger/multiply.cs b/src/libraries/System.Runtime.Numerics/tests/BigInteger/multiply.cs index d1abae863ea102..d94c235c65b628 100644 --- a/src/libraries/System.Runtime.Numerics/tests/BigInteger/multiply.cs +++ b/src/libraries/System.Runtime.Numerics/tests/BigInteger/multiply.cs @@ -38,12 +38,12 @@ public static void RunMultiply_TwoLargeBigIntegers_Threshold() { // Again, with lower threshold BigIntTools.Utils.RunWithFakeThreshold(BigIntegerCalculator.SquareThreshold, 8, () => - BigIntTools.Utils.RunWithFakeThreshold(BigIntegerCalculator.MultiplyThreshold, 8, RunMultiply_TwoLargeBigIntegers) + BigIntTools.Utils.RunWithFakeThreshold(BigIntegerCalculator.MultiplyKaratsubaThreshold, 8, RunMultiply_TwoLargeBigIntegers) ); // Again, with lower threshold BigIntTools.Utils.RunWithFakeThreshold(BigIntegerCalculator.SquareThreshold, 8, () => - BigIntTools.Utils.RunWithFakeThreshold(BigIntegerCalculator.MultiplyThreshold, 8, () => + BigIntTools.Utils.RunWithFakeThreshold(BigIntegerCalculator.MultiplyKaratsubaThreshold, 8, () => BigIntTools.Utils.RunWithFakeThreshold(BigIntegerCalculator.StackAllocThreshold, 8, RunMultiply_TwoLargeBigIntegers) ) ); @@ -216,10 +216,10 @@ public static void RunMultiplyKaratsubaBoundary() { for (int d1 = -2; d1 <= 2; d1++) { - tempByteArray1 = GetRandomByteArray(random, BigIntegerCalculator.MultiplyThreshold + d1); + tempByteArray1 = GetRandomByteArray(random, BigIntegerCalculator.MultiplyKaratsubaThreshold + d1); for (int d2 = -4; d2 <= 4; d2++) { - tempByteArray2 = GetRandomByteArray(random, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + d2); + tempByteArray2 = GetRandomByteArray(random, (BigIntegerCalculator.MultiplyKaratsubaThreshold + 1) * 2 + d2); VerifyMultiplyString(Print(tempByteArray1) + Print(tempByteArray2) + "bMultiply"); } } diff --git a/src/libraries/System.Runtime.Numerics/tests/BigInteger/op_multiply.cs b/src/libraries/System.Runtime.Numerics/tests/BigInteger/op_multiply.cs index 08f73afe18198d..9b09a402df5953 100644 --- a/src/libraries/System.Runtime.Numerics/tests/BigInteger/op_multiply.cs +++ b/src/libraries/System.Runtime.Numerics/tests/BigInteger/op_multiply.cs @@ -169,10 +169,10 @@ public static void RunMultiplyKaratsubaBoundary() { for (int d1 = -2; d1 <= 2; d1++) { - tempByteArray1 = GetRandomByteArray(s_random, BigIntegerCalculator.MultiplyThreshold + d1); + tempByteArray1 = GetRandomByteArray(s_random, BigIntegerCalculator.MultiplyKaratsubaThreshold + d1); for (int d2 = -4; d2 <= 4; d2++) { - tempByteArray2 = GetRandomByteArray(s_random, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + d2); + tempByteArray2 = GetRandomByteArray(s_random, (BigIntegerCalculator.MultiplyKaratsubaThreshold + 1) * 2 + d2); VerifyMultiplyString(Print(tempByteArray1) + Print(tempByteArray2) + "b*"); } }