From e61c40e9c3e71f50c2d6a6c8a1688b6a8ddd4ba8 Mon Sep 17 00:00:00 2001 From: Cody Gunton Date: Wed, 5 Jun 2024 13:05:19 +0100 Subject: [PATCH] fix: Biggroup batch mul handles collisions (#6780) The following PR adds an edgecase handling mode to biggroup batch multiplication. In this mode the points are randomised in such a way as avoid weird interactions (doublings and point at infinity cases). It enables using tables for RecurisveMergeVerifier and Recursive Verifier for Protogalaxy on ultra (it also halves the gatecount from 4.5 mln to 2.2). For a batch multiplication of 5 points it increases the gate count in ultra from ~72k to ~78k --------- Co-authored-by: Rumata888 --- .../verifier/merge_recursive_verifier.cpp | 2 +- .../verifier/merge_verifier.test.cpp | 3 +- .../protogalaxy_recursive_verifier.hpp | 54 ++-------- .../stdlib/primitives/biggroup/biggroup.hpp | 9 +- .../primitives/biggroup/biggroup.test.cpp | 97 ++++++++++++++--- .../biggroup/biggroup_batch_mul.hpp | 2 +- .../biggroup/biggroup_edgecase_handling.hpp | 100 ++++++++++++++++++ .../primitives/biggroup/biggroup_impl.hpp | 29 +++-- .../biggroup/handle_points_at_infinity.hpp | 42 -------- .../stdlib/primitives/field/field.hpp | 1 + 10 files changed, 221 insertions(+), 118 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_edgecase_handling.hpp delete mode 100644 barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp diff --git a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/merge_recursive_verifier.cpp b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/merge_recursive_verifier.cpp index 5b2eefbbe72..0e0517aa541 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/merge_recursive_verifier.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/merge_recursive_verifier.cpp @@ -75,7 +75,7 @@ std::array::Element, 2> MergeRecursiveVerifier_ class RecursiveMergeVerifierTest : public test // Run the recursive verifier tests with Ultra and Mega builders // TODO(https://github.com/AztecProtocol/barretenberg/issues/1024): Ultra fails, possibly due to repeated points in // batch mul? -// using Builders = testing::Types; -using Builders = testing::Types; +using Builders = testing::Types; TYPED_TEST_SUITE(RecursiveMergeVerifierTest, Builders); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp index 613528f5272..4d6f68d8ed6 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/honk_recursion/verifier/protogalaxy_recursive_verifier.hpp @@ -125,51 +125,6 @@ template class ProtoGalaxyRecursiveVerifier_ { return result; }; - /** - * @brief Hack method to fold the witness commitments and verification key without the batch_mul in the case where - * the recursive folding verifier is instantiated as a vanilla ultra circuit. - * - * @details In the folding recursive verifier we might hit the scenerio where we do a batch_mul(commitments, - * lagranges) where the commitments are equal. That is because when we add gates to ensure no zero commitments, - * these will be the same for all circuits, hitting an edge case in batch_mul that creates a failing constraint. - * Specifically, at some point in the algorithm we compute the difference between the points which, if they are - * equal, would be zero, case that is not supported. See https://github.com/AztecProtocol/barretenberg/issues/971. - */ - void fold_commitments(std::vector lagranges, - VerifierInstances& instances, - std::shared_ptr& accumulator) - requires IsUltraBuilder - { - using ElementNative = typename Flavor::Curve::ElementNative; - using AffineElementNative = typename Flavor::Curve::AffineElementNative; - - auto offset_generator = Commitment::from_witness(builder, AffineElementNative(ElementNative::random_element())); - - size_t vk_idx = 0; - for (auto& expected_vk : accumulator->verification_key->get_all()) { - expected_vk = offset_generator; - size_t inst = 0; - for (auto& instance : instances) { - expected_vk += instance->verification_key->get_all()[vk_idx] * lagranges[inst]; - inst++; - } - expected_vk -= offset_generator; - vk_idx++; - } - - size_t comm_idx = 0; - for (auto& comm : accumulator->witness_commitments.get_all()) { - comm = offset_generator; - size_t inst = 0; - for (auto& instance : instances) { - comm += instance->witness_commitments.get_all()[comm_idx] * lagranges[inst]; - inst++; - } - comm -= offset_generator; - comm_idx++; - } - } - /** * @brief Folds the witness commitments and verification key (part of ϕ) and stores the values in the accumulator. * @@ -179,7 +134,6 @@ template class ProtoGalaxyRecursiveVerifier_ { void fold_commitments(std::vector lagranges, VerifierInstances& instances, std::shared_ptr& accumulator) - requires(!IsUltraBuilder) { size_t vk_idx = 0; for (auto& expected_vk : accumulator->verification_key->get_all()) { @@ -187,7 +141,9 @@ template class ProtoGalaxyRecursiveVerifier_ { for (auto& instance : instances) { commitments.emplace_back(instance->verification_key->get_all()[vk_idx]); } - expected_vk = Commitment::batch_mul(commitments, lagranges); + // For ultra we need to enable edgecase prevention + expected_vk = Commitment::batch_mul( + commitments, lagranges, /*max_num_bits=*/0, /*with_edgecases=*/IsUltraBuilder); vk_idx++; } @@ -197,7 +153,9 @@ template class ProtoGalaxyRecursiveVerifier_ { for (auto& instance : instances) { commitments.emplace_back(instance->witness_commitments.get_all()[comm_idx]); } - comm = Commitment::batch_mul(commitments, lagranges); + // For ultra we need to enable edgecase prevention + comm = Commitment::batch_mul( + commitments, lagranges, /*max_num_bits=*/0, /*with_edgecases=*/IsUltraBuilder); comm_idx++; } } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp index 0e0ab416ce8..9388f59ae10 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp @@ -119,7 +119,7 @@ template class element { *this = *this - other; return *this; } - std::array checked_unconditional_add_sub(const element& other) const; + std::array checked_unconditional_add_sub(const element&) const; element operator*(const Fr& other) const; @@ -204,6 +204,9 @@ template class element { return result; } + static std::pair, std::vector> mask_points(const std::vector& _points, + const std::vector& _scalars); + static std::pair, std::vector> handle_points_at_infinity( const std::vector& _points, const std::vector& _scalars); @@ -215,7 +218,8 @@ template class element { static element wnaf_batch_mul(const std::vector& points, const std::vector& scalars); static element batch_mul(const std::vector& points, const std::vector& scalars, - const size_t max_num_bits = 0); + const size_t max_num_bits = 0, + const bool with_edgecases = false); // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) max_num_bits is unused; could implement and use // this to optimize other operations. @@ -310,6 +314,7 @@ template class element { const std::array& limb_max); static std::pair compute_offset_generators(const size_t num_rounds); + static typename NativeGroup::affine_element compute_table_offset_generator(); template >> struct four_bit_table_plookup { four_bit_table_plookup(){}; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp index 1069e722b47..f01cc3c1849 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp @@ -449,20 +449,58 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } - static void test_batch_mul_edge_cases() + static void test_batch_mul_edgecase_equivalence() { - { - // batch P + P = 2P + const size_t num_points = 5; + Builder builder; + std::vector points; + std::vector scalars; + for (size_t i = 0; i < num_points; ++i) { + points.push_back(affine_element(element::random_element())); + scalars.push_back(fr::random_element()); + } + + std::vector circuit_points; + std::vector circuit_scalars; + for (size_t i = 0; i < num_points; ++i) { + circuit_points.push_back(element_ct::from_witness(&builder, points[i])); + circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); + } + + element_ct result_point2 = + element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true); + + element expected_point = g1::one; + expected_point.self_set_infinity(); + for (size_t i = 0; i < num_points; ++i) { + expected_point += (element(points[i]) * scalars[i]); + } + + expected_point = expected_point.normalize(); + + fq result2_x(result_point2.x.get_value().lo); + fq result2_y(result_point2.y.get_value().lo); + + EXPECT_EQ(result2_x, expected_point.x); + EXPECT_EQ(result2_y, expected_point.y); + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + + static void test_batch_mul_edge_case_set1() + { + const auto test_repeated_points = [](const uint32_t num_points) { + // batch P + ... + P = m*P + info("num points: ", num_points); std::vector points; - points.push_back(affine_element::one()); - points.push_back(affine_element::one()); std::vector scalars; - scalars.push_back(1); - scalars.push_back(1); + for (size_t idx = 0; idx < num_points; idx++) { + points.push_back(affine_element::one()); + scalars.push_back(1); + } Builder builder; ASSERT(points.size() == scalars.size()); - const size_t num_points = points.size(); std::vector circuit_points; std::vector circuit_scalars; @@ -470,9 +508,13 @@ template class stdlib_biggroup : public testing::Test { circuit_points.push_back(element_ct::from_witness(&builder, points[i])); circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); } - element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars); + element_ct result_point = + element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true); - element expected_point = points[0] + points[1]; + auto expected_point = element::infinity(); + for (const auto& point : points) { + expected_point += point; + } expected_point = expected_point.normalize(); fq result_x(result_point.x.get_value().lo); @@ -482,7 +524,16 @@ template class stdlib_biggroup : public testing::Test { EXPECT_EQ(result_y, expected_point.y); EXPECT_CIRCUIT_CORRECTNESS(builder); - } + }; + test_repeated_points(2); + test_repeated_points(3); + test_repeated_points(4); + test_repeated_points(5); + test_repeated_points(6); + test_repeated_points(7); + } + static void test_batch_mul_edge_case_set2() + { { // batch oo + P = P std::vector points; @@ -502,7 +553,8 @@ template class stdlib_biggroup : public testing::Test { circuit_points.push_back(element_ct::from_witness(&builder, points[i])); circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); } - element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars); + element_ct result_point = + element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true); element expected_point = points[1]; expected_point = expected_point.normalize(); @@ -535,7 +587,8 @@ template class stdlib_biggroup : public testing::Test { circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); } - element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars); + element_ct result_point = + element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true); element expected_point = points[1]; expected_point = expected_point.normalize(); @@ -1177,10 +1230,24 @@ HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul) { TestFixture::test_batch_mul(); } -HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edge_cases) + +HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edgecase_equivalence) +{ + if constexpr (HasGoblinBuilder) { + GTEST_SKIP(); + } else { + TestFixture::test_batch_mul_edgecase_equivalence(); + } +} +HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edge_case_set1) +{ + TestFixture::test_batch_mul_edge_case_set1(); +} + +HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edge_case_set2) { if constexpr (HasGoblinBuilder) { - TestFixture::test_batch_mul_edge_cases(); + TestFixture::test_batch_mul_edge_case_set2(); } else { GTEST_SKIP() << "https://github.com/AztecProtocol/barretenberg/issues/1000"; }; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp index e931e1c6374..f9b0598854d 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp @@ -1,6 +1,6 @@ #pragma once -#include "barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp" +#include "barretenberg/stdlib/primitives/biggroup/biggroup_edgecase_handling.hpp" #include namespace bb::stdlib { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_edgecase_handling.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_edgecase_handling.hpp new file mode 100644 index 00000000000..bf992e7e6dd --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_edgecase_handling.hpp @@ -0,0 +1,100 @@ +#pragma once +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" + +namespace bb::stdlib { + +/** + * @brief Compute an offset generator for use in biggroup tables + * + *@details Sometimes the points from which we construct the tables are going to be dependent in such a way that + *combining them for constructing the table is not possible without handling the edgecases such as the point at infinity + *and doubling. To avoid handling those we add multiples of this offset generator to the points. + * + * @param num_rounds + */ +template +typename G::affine_element element::compute_table_offset_generator() +{ + constexpr typename G::affine_element offset_generator = + G::derive_generators("biggroup table offset generator", 1)[0]; + + return offset_generator; +} + +/** + * @brief Given two lists of points that need to be multiplied by scalars, create a new list of length +1 with original + * points masked, but the same scalar product sum + * @details Add +1G, +2G, +4G etc to the original points and adds a new point 2ⁿ⋅G and scalar x to the lists. By + * doubling the point every time, we ensure that no +-1 combination of 6 sequential elements run into edgecases, unless + * the points are deliberately constructed to trigger it. + */ +template +std::pair>, std::vector> element::mask_points( + const std::vector& _points, const std::vector& _scalars) +{ + std::vector points; + std::vector scalars; + ASSERT(_points.size() == _scalars.size()); + using NativeFr = typename Fr::native; + auto running_scalar = NativeFr::one(); + // Get the offset generator G_offset in native and in-circuit form + auto native_offset_generator = element::compute_table_offset_generator(); + Fr last_scalar = Fr(0); + NativeFr generator_coefficient = NativeFr(2).pow(_points.size()); + auto generator_coefficient_inverse = generator_coefficient.invert(); + // For each point and scalar + for (size_t i = 0; i < _points.size(); i++) { + scalars.push_back(_scalars[i]); + // Convert point into point + 2ⁱ⋅G_offset + points.push_back(_points[i] + (native_offset_generator * running_scalar)); + // Add \frac{2ⁱ⋅scalar}{2ⁿ} to the last scalar + last_scalar += _scalars[i] * (running_scalar * generator_coefficient_inverse); + // Double the running scalar + running_scalar += running_scalar; + } + + // Add a scalar -(<(1,2,4,...,2ⁿ⁻¹ ),(scalar₀,...,scalarₙ₋₁)> / 2ⁿ) + scalars.push_back(-last_scalar); + // Add in-circuit G_offset to points + points.push_back(element(native_offset_generator * generator_coefficient)); + + return { points, scalars }; +} + +/** + * @brief Replace all pairs (∞, scalar) by the pair (one, 0) where one is a fixed generator of the curve + * @details This is a step in enabling our our multiscalar multiplication algorithms to hande points at infinity. + */ +template +std::pair>, std::vector> element::handle_points_at_infinity( + const std::vector& _points, const std::vector& _scalars) +{ + auto builder = _points[0].get_context(); + std::vector points; + std::vector scalars; + element one = element::one(builder); + + for (auto [_point, _scalar] : zip_view(_points, _scalars)) { + bool_ct is_point_at_infinity = _point.is_point_at_infinity(); + if (is_point_at_infinity.get_value() && static_cast(is_point_at_infinity.is_constant())) { + // if point is at infinity and a circuit constant we can just skip. + continue; + } + if (_scalar.get_value() == 0 && _scalar.is_constant()) { + // if scalar multiplier is 0 and also a constant, we can skip + continue; + } + Fq updated_x = Fq::conditional_assign(is_point_at_infinity, one.x, _point.x); + Fq updated_y = Fq::conditional_assign(is_point_at_infinity, one.y, _point.y); + element point(updated_x, updated_y); + Fr scalar = Fr::conditional_assign(is_point_at_infinity, 0, _scalar); + + points.push_back(point); + scalars.push_back(scalar); + // TODO(https://github.com/AztecProtocol/barretenberg/issues/1002): if both point and scalar are constant, don't + // bother adding constraints + } + + return { points, scalars }; +} +} // namespace bb::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp index 3d2752aa22e..6aa04a80223 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp @@ -748,17 +748,28 @@ std::pair, element> element::c } /** - * Generic batch multiplication that works for all elliptic curve types. - * - * Implementation is identical to `bn254_endo_batch_mul` but WITHOUT the endomorphism transforms OR support for short - * scalars See `bn254_endo_batch_mul` for description of algorithm - **/ + * @brief Generic batch multiplication that works for all elliptic curve types. + * + * @details Implementation is identical to `bn254_endo_batch_mul` but WITHOUT the endomorphism transforms OR support for + * short scalars See `bn254_endo_batch_mul` for description of algorithm. + * + * @tparam C The circuit builder type. + * @tparam Fq The field of definition of the points in `_points`. + * @tparam Fr The field of scalars acting on `_points`. + * @tparam G The group whose arithmetic is emulated by `element`. + * @param _points + * @param _scalars + * @param max_num_bits The max of the bit lengths of the scalars. + * @param with_edgecases Use when points are linearly dependent. Randomises them. + * @return element + */ template element element::batch_mul(const std::vector& _points, const std::vector& _scalars, - const size_t max_num_bits) + const size_t max_num_bits, + const bool with_edgecases) { - const auto [points, scalars] = handle_points_at_infinity(_points, _scalars); + auto [points, scalars] = handle_points_at_infinity(_points, _scalars); if constexpr (IsSimulator) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/663) @@ -776,8 +787,12 @@ element element::batch_mul(const std::vector && std::same_as) { return goblin_batch_mul(points, scalars); } else { + if (with_edgecases) { + std::tie(points, scalars) = mask_points(points, scalars); + } const size_t num_points = points.size(); ASSERT(scalars.size() == num_points); + batch_lookup_table point_table(points); const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp deleted file mode 100644 index b211e08d622..00000000000 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once -#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" - -namespace bb::stdlib { - -/** - * @brief Replace all pairs (∞, scalar) by the pair (one, 0) where one is a fixed generator of the curve - * @details This is a step in enabling our our multiscalar multiplication algorithms to hande points at infinity. - */ -template -std::pair>, std::vector> element::handle_points_at_infinity( - const std::vector& _points, const std::vector& _scalars) -{ - auto builder = _points[0].get_context(); - std::vector points; - std::vector scalars; - element one = element::one(builder); - - for (auto [_point, _scalar] : zip_view(_points, _scalars)) { - bool_ct is_point_at_infinity = _point.is_point_at_infinity(); - if (is_point_at_infinity.get_value() && static_cast(is_point_at_infinity.is_constant())) { - // if point is at infinity and a circuit constant we can just skip. - continue; - } - if (_scalar.get_value() == 0 && _scalar.is_constant()) { - // if scalar multiplier is 0 and also a constant, we can skip - continue; - } - Fq updated_x = Fq::conditional_assign(is_point_at_infinity, one.x, _point.x); - Fq updated_y = Fq::conditional_assign(is_point_at_infinity, one.y, _point.y); - element point(updated_x, updated_y); - Fr scalar = Fr::conditional_assign(is_point_at_infinity, 0, _scalar); - - points.push_back(point); - scalars.push_back(scalar); - // TODO(https://github.com/AztecProtocol/barretenberg/issues/1002): if both point and scalar are constant, don't - // bother adding constraints - } - - return { points, scalars }; -} -} // namespace bb::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp index f8ffa3d12e6..c3c4a19c140 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp @@ -11,6 +11,7 @@ template class field_t { public: using View = field_t; + using native = bb::fr; field_t(Builder* parent_context = nullptr); field_t(Builder* parent_context, const bb::fr& value);