Skip to content

Commit

Permalink
feat: Origin Tags Part 2 (#8936)
Browse files Browse the repository at this point in the history
This PR extends the Origin Tags ( a taint mechanism for stdlib
primitives) to:
1) bigfield
2) byte_array
3) safe_uint

And extends tests with checks that the tags are preserved or merged
correctly
  • Loading branch information
Rumata888 authored Oct 7, 2024
1 parent 9c9c385 commit 77c05f5
Show file tree
Hide file tree
Showing 12 changed files with 414 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,23 @@ template <typename Builder, typename T> class bigfield {

Builder* get_context() const { return context; }

void set_origin_tag(const bb::OriginTag& tag) const
{
for (size_t i = 0; i < 4; i++) {
binary_basis_limbs[i].element.set_origin_tag(tag);
}
prime_basis_limb.set_origin_tag(tag);
}

bb::OriginTag get_origin_tag() const
{
return bb::OriginTag(binary_basis_limbs[0].element.tag,
binary_basis_limbs[1].element.tag,
binary_basis_limbs[2].element.tag,
binary_basis_limbs[3].element.tag,
prime_basis_limb.tag);
}

static constexpr uint512_t get_maximum_unreduced_value(const size_t num_products = 1)
{
// return (uint512_t(1) << 256);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "barretenberg/circuit_checker/circuit_checker.hpp"
#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders.hpp"
#include "barretenberg/stdlib/primitives/curves/bn254.hpp"
#include "barretenberg/transcript/origin_tag.hpp"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
Expand All @@ -29,6 +30,7 @@ namespace {
auto& engine = numeric::get_debug_randomness();
}

STANDARD_TESTING_TAGS
template <typename Builder> class stdlib_bigfield : public testing::Test {

typedef stdlib::bn254<Builder> bn254;
Expand Down Expand Up @@ -84,6 +86,30 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
}

public:
static void test_basic_tag_logic()
{
auto builder = Builder();
auto input = fq::random_element();
fq_ct a(witness_ct(&builder, fr(uint256_t(input).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder, fr(uint256_t(input).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
a.binary_basis_limbs[0].element.set_origin_tag(submitted_value_origin_tag);
a.binary_basis_limbs[1].element.set_origin_tag(challenge_origin_tag);
a.prime_basis_limb.set_origin_tag(next_challenge_tag);

EXPECT_EQ(a.get_origin_tag(), first_second_third_merged_tag);

a.set_origin_tag(clear_tag);
EXPECT_EQ(a.binary_basis_limbs[0].element.get_origin_tag(), clear_tag);
EXPECT_EQ(a.binary_basis_limbs[1].element.get_origin_tag(), clear_tag);
EXPECT_EQ(a.binary_basis_limbs[2].element.get_origin_tag(), clear_tag);
EXPECT_EQ(a.binary_basis_limbs[3].element.get_origin_tag(), clear_tag);
EXPECT_EQ(a.prime_basis_limb.get_origin_tag(), clear_tag);

#ifndef NDEBUG
a.set_origin_tag(instant_death_tag);
EXPECT_THROW(a + a, std::runtime_error);
#endif
}
static void test_mul()
{
auto builder = Builder();
Expand All @@ -93,11 +119,15 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct a(witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
a.set_origin_tag(submitted_value_origin_tag);
fq_ct b(witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
b.set_origin_tag(challenge_origin_tag);

uint64_t before = builder.get_num_gates();
fq_ct c = a * b;
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);
uint64_t after = builder.get_num_gates();
// Don't profile 1st repetition. It sets up a lookup table, cost is not representative of a typical mul
if (i == num_repetitions - 2) {
Expand Down Expand Up @@ -135,8 +165,13 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct a(witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
a.set_origin_tag(next_challenge_tag);

uint64_t before = builder.get_num_gates();
fq_ct c = a.sqr();

// Squaring preserves tags
EXPECT_EQ(a.get_origin_tag(), next_challenge_tag);
uint64_t after = builder.get_num_gates();
if (i == num_repetitions - 1) {
std::cerr << "num gates per mul = " << after - before << std::endl;
Expand Down Expand Up @@ -180,9 +215,14 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct c(witness_ct(&builder, fr(uint256_t(inputs[2]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[2]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

a.set_origin_tag(challenge_origin_tag);
b.set_origin_tag(submitted_value_origin_tag);
c.set_origin_tag(next_challenge_tag);
uint64_t before = builder.get_num_gates();
fq_ct d = a.madd(b, { c });

// Madd merges tags
EXPECT_EQ(d.get_origin_tag(), first_second_third_merged_tag);
uint64_t after = builder.get_num_gates();
if (i == num_repetitions - 1) {
std::cerr << "num gates per mul = " << after - before << std::endl;
Expand Down Expand Up @@ -241,8 +281,13 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
to_add.emplace_back(
fq_ct::create_from_u512_as_witness(&builder, uint512_t(uint256_t(to_add_values[j]))));
}
mul_left[number_of_madds - 1].set_origin_tag(submitted_value_origin_tag);
mul_right[number_of_madds - 1].set_origin_tag(challenge_origin_tag);
to_add[number_of_madds - 1].set_origin_tag(next_challenge_tag);
uint64_t before = builder.get_num_gates();
fq_ct f = fq_ct::mult_madd(mul_left, mul_right, to_add);
// mult_madd merges tags
EXPECT_EQ(f.get_origin_tag(), first_second_third_merged_tag);
uint64_t after = builder.get_num_gates();
if (i == num_repetitions - 1) {
std::cerr << "num gates with mult_madd = " << after - before << std::endl;
Expand Down Expand Up @@ -303,8 +348,13 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
witness_ct(&builder,
fr(uint256_t(inputs[4]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

a.set_origin_tag(submitted_value_origin_tag);
d.set_origin_tag(challenge_origin_tag);
e.set_origin_tag(next_challenge_tag);
uint64_t before = builder.get_num_gates();
fq_ct f = fq_ct::dual_madd(a, b, c, d, { e });
// dual_madd merges tags
EXPECT_EQ(f.get_origin_tag(), first_second_third_merged_tag);
uint64_t after = builder.get_num_gates();
if (i == num_repetitions - 1) {
std::cerr << "num gates per mul = " << after - before << std::endl;
Expand Down Expand Up @@ -348,8 +398,11 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct b(witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
a.set_origin_tag(submitted_value_origin_tag);
b.set_origin_tag(challenge_origin_tag);
uint64_t before = builder.get_num_gates();
fq_ct c = a / b;
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);
uint64_t after = builder.get_num_gates();
if (i == num_repetitions - 1) {
std::cout << "num gates per div = " << after - before << std::endl;
Expand Down Expand Up @@ -396,7 +449,11 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct d(witness_ct(&builder, fr(uint256_t(inputs[3]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[3]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
b.set_origin_tag(submitted_value_origin_tag);
c.set_origin_tag(challenge_origin_tag);
d.set_origin_tag(next_challenge_tag);
fq_ct e = (a + b) / (c + d);
EXPECT_EQ(e.get_origin_tag(), first_second_third_merged_tag);
// uint256_t modulus{ Bn254FqParams::modulus_0,
// Bn254FqParams::modulus_1,
// Bn254FqParams::modulus_2,
Expand Down Expand Up @@ -439,7 +496,14 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct d(witness_ct(&builder, fr(uint256_t(inputs[3]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[3]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

b.set_origin_tag(submitted_value_origin_tag);
c.set_origin_tag(challenge_origin_tag);
d.set_origin_tag(next_challenge_tag);

fq_ct e = (a + b) * (c + d);

EXPECT_EQ(e.get_origin_tag(), first_second_third_merged_tag);
fq expected = (inputs[0] + inputs[1]) * (inputs[2] + inputs[3]);
expected = expected.from_montgomery_form();
uint512_t result = e.get_value();
Expand Down Expand Up @@ -473,7 +537,13 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
witness_ct(&builder,
fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));
fq_ct d(&builder, constants[1]);
b.set_origin_tag(submitted_value_origin_tag);
c.set_origin_tag(challenge_origin_tag);
d.set_origin_tag(next_challenge_tag);

fq_ct e = (a + b) * (c + d);

EXPECT_EQ(e.get_origin_tag(), first_second_third_merged_tag);
fq expected = (inputs[0] + constants[0]) * (inputs[1] + constants[1]);
expected = expected.from_montgomery_form();
uint512_t result = e.get_value();
Expand Down Expand Up @@ -510,7 +580,14 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
fq_ct d(witness_ct(&builder, fr(uint256_t(inputs[3]).slice(0, fq_ct::NUM_LIMB_BITS * 2))),
witness_ct(&builder,
fr(uint256_t(inputs[3]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4))));

b.set_origin_tag(submitted_value_origin_tag);
c.set_origin_tag(challenge_origin_tag);
d.set_origin_tag(next_challenge_tag);

fq_ct e = (a - b) * (c - d);

EXPECT_EQ(e.get_origin_tag(), first_second_third_merged_tag);
fq expected = (inputs[0] - inputs[1]) * (inputs[2] - inputs[3]);

expected = expected.from_montgomery_form();
Expand Down Expand Up @@ -542,8 +619,13 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
auto [to_sub1, to_sub1_ct] = get_random_element(&builder);
auto [to_sub2, to_sub2_ct] = get_random_element(&builder);

mul_l_ct.set_origin_tag(submitted_value_origin_tag);
mul_r1_ct.set_origin_tag(challenge_origin_tag);
divisor1_ct.set_origin_tag(next_submitted_value_origin_tag);
to_sub1_ct.set_origin_tag(next_challenge_tag);
fq_ct result_ct = fq_ct::msub_div(
{ mul_l_ct }, { mul_r1_ct - mul_r2_ct }, divisor1_ct - divisor2_ct, { to_sub1_ct, to_sub2_ct });
EXPECT_EQ(result_ct.get_origin_tag(), first_to_fourth_merged_tag);
fq expected = (-(mul_l * (mul_r1 - mul_r2) + to_sub1 + to_sub2)) / (divisor1 - divisor2);
EXPECT_EQ(result_ct.get_value().lo, uint256_t(expected));
EXPECT_EQ(result_ct.get_value().hi, uint256_t(0));
Expand All @@ -570,10 +652,18 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
// fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2,
// fq_ct::NUM_LIMB_BITS * 4))));

a.set_origin_tag(submitted_value_origin_tag);

typename bn254::bool_ct predicate_a(witness_ct(&builder, true));

predicate_a.set_origin_tag(challenge_origin_tag);
// bool_ct predicate_b(witness_ct(&builder, false));

fq_ct c = a.conditional_negate(predicate_a);

// Conditional negate merges tags
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);

fq_ct d = a.conditional_negate(!predicate_a);
fq_ct e = c + d;
uint512_t c_out = c.get_value();
Expand Down Expand Up @@ -663,8 +753,12 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
c = b * b + c;
expected = inputs[1] * inputs[1] + expected;
}

c.set_origin_tag(challenge_origin_tag);
// fq_ct c = a + a + a + a - b - b - b - b;
c.self_reduce();
// self_reduce preserves tags
EXPECT_EQ(c.get_origin_tag(), challenge_origin_tag);
fq result = fq(c.get_value().lo);
EXPECT_EQ(result, expected);
EXPECT_EQ(c.get_value().get_msb() < 254, true);
Expand Down Expand Up @@ -766,11 +860,16 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
stdlib::byte_array<Builder> input_arr_a(&builder, input_a);
stdlib::byte_array<Builder> input_arr_b(&builder, input_b);

input_arr_a.set_origin_tag(submitted_value_origin_tag);
input_arr_b.set_origin_tag(challenge_origin_tag);

fq_ct a(input_arr_a);
fq_ct b(input_arr_b);

fq_ct c = a * b;

EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);

fq expected = inputs[0] * inputs[1];
uint256_t result = (c.get_value().lo);
EXPECT_EQ(result, uint256_t(expected));
Expand Down Expand Up @@ -906,9 +1005,13 @@ template <typename Builder> class stdlib_bigfield : public testing::Test {
EXPECT_EQ(fq(result.get_value()), expected);

fr_ct exponent = witness_ct(&builder, exponent_val);
base.set_origin_tag(submitted_value_origin_tag);
exponent.set_origin_tag(challenge_origin_tag);
result = base.pow(exponent);
EXPECT_EQ(fq(result.get_value()), expected);

EXPECT_EQ(result.get_origin_tag(), first_two_merged_tag);

info("num gates = ", builder.get_num_gates());
bool check_result = CircuitChecker::check(builder);
EXPECT_EQ(check_result, true);
Expand All @@ -923,6 +1026,10 @@ TYPED_TEST(stdlib_bigfield, badmul)
{
TestFixture::test_division_formula_bug();
}
TYPED_TEST(stdlib_bigfield, basic_tag_logic)
{
TestFixture::test_basic_tag_logic();
}
TYPED_TEST(stdlib_bigfield, mul)
{
TestFixture::test_mul();
Expand Down
Loading

0 comments on commit 77c05f5

Please sign in to comment.