Skip to content

Commit

Permalink
5542: unit tests for SET opcode
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanmon committed Apr 4, 2024
1 parent 52ea38a commit 59fb213
Showing 1 changed file with 125 additions and 16 deletions.
141 changes: 125 additions & 16 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_mem_opcodes.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,28 @@ class AvmMemOpcodeTests : public ::testing::Test {
trace = trace_builder.finalize();
}

static std::function<bool(Row)> gen_matcher(FF clk, uint32_t sub_clk)
{
return [clk, sub_clk](Row r) { return r.avm_mem_clk == clk && r.avm_mem_sub_clk == sub_clk; };
};

void compute_common_indices(FF clk, bool indirect)
{
// Find the memory trace position corresponding to the write sub-operation of register ic.
auto row =
std::ranges::find_if(trace.begin(), trace.end(), gen_matcher(clk, AvmMemTraceBuilder::SUB_CLK_STORE_C));
ASSERT_TRUE(row != trace.end());
mem_c_idx = static_cast<size_t>(row - trace.begin());

// Find the memory trace position of the indirect load for register ic.
if (indirect) {
row = std::ranges::find_if(
trace.begin(), trace.end(), gen_matcher(clk, AvmMemTraceBuilder::SUB_CLK_IND_LOAD_C));
ASSERT_TRUE(row != trace.end());
mem_ind_c_idx = static_cast<size_t>(row - trace.begin());
}
}

void compute_mov_indices(bool indirect)
{
// Find the first row enabling the MOV selector
Expand All @@ -55,30 +77,20 @@ class AvmMemOpcodeTests : public ::testing::Test {

auto clk = row->avm_main_clk;

auto gen_matcher = [clk](uint32_t sub_clk) {
return [clk, sub_clk](Row r) { return r.avm_mem_clk == clk && r.avm_mem_sub_clk == sub_clk; };
};

// Find the memory trace position corresponding to the load sub-operation of register ia.
row = std::ranges::find_if(trace.begin(), trace.end(), gen_matcher(AvmMemTraceBuilder::SUB_CLK_LOAD_A));
row = std::ranges::find_if(trace.begin(), trace.end(), gen_matcher(clk, AvmMemTraceBuilder::SUB_CLK_LOAD_A));
ASSERT_TRUE(row != trace.end());
mem_a_idx = static_cast<size_t>(row - trace.begin());

// Find the memory trace position corresponding to the write sub-operation of register ic.
row = std::ranges::find_if(trace.begin(), trace.end(), gen_matcher(AvmMemTraceBuilder::SUB_CLK_STORE_C));
ASSERT_TRUE(row != trace.end());
mem_c_idx = static_cast<size_t>(row - trace.begin());

// Find the memory trace position of the indirect loads.
// Find the memory trace position of the indirect load for register ia.
if (indirect) {
row = std::ranges::find_if(trace.begin(), trace.end(), gen_matcher(AvmMemTraceBuilder::SUB_CLK_IND_LOAD_A));
row = std::ranges::find_if(
trace.begin(), trace.end(), gen_matcher(clk, AvmMemTraceBuilder::SUB_CLK_IND_LOAD_A));
ASSERT_TRUE(row != trace.end());
mem_ind_a_idx = static_cast<size_t>(row - trace.begin());

row = std::ranges::find_if(trace.begin(), trace.end(), gen_matcher(AvmMemTraceBuilder::SUB_CLK_IND_LOAD_C));
ASSERT_TRUE(row != trace.end());
mem_ind_c_idx = static_cast<size_t>(row - trace.begin());
}

compute_common_indices(clk, indirect);
}

void validate_mov_trace(bool indirect,
Expand Down Expand Up @@ -221,6 +233,103 @@ TEST_F(AvmMemOpcodeTests, indirectMovInvalidAddressTag)
validate_trace_proof(std::move(trace));
}

TEST_F(AvmMemOpcodeTests, directSet)
{
trace_builder.op_set(0, 5683, 99, AvmMemoryTag::U128);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();

compute_common_indices(0, false);
auto const& row = trace.at(1);

EXPECT_THAT(row,
AllOf(Field(&Row::avm_main_tag_err, 0),
Field(&Row::avm_main_ic, 5683),
Field(&Row::avm_main_mem_idx_c, 99),
Field(&Row::avm_main_mem_op_c, 1),
Field(&Row::avm_main_rwc, 1),
Field(&Row::avm_main_ind_op_c, 0)));

EXPECT_THAT(trace.at(mem_c_idx),
AllOf(Field(&Row::avm_mem_val, 5683),
Field(&Row::avm_mem_addr, 99),
Field(&Row::avm_mem_op_c, 1),
Field(&Row::avm_mem_rw, 1),
Field(&Row::avm_mem_ind_op_c, 0)));

validate_trace_proof(std::move(trace));
}

TEST_F(AvmMemOpcodeTests, indirectSet)
{
trace_builder.op_set(0, 100, 10, AvmMemoryTag::U32);
trace_builder.op_set(1, 1979, 10, AvmMemoryTag::U64); // Set 1979 at memory index 100
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();

compute_common_indices(1, true);
auto const& row = trace.at(2);

EXPECT_THAT(row,
AllOf(Field(&Row::avm_main_tag_err, 0),
Field(&Row::avm_main_ic, 1979),
Field(&Row::avm_main_mem_idx_c, 100),
Field(&Row::avm_main_mem_op_c, 1),
Field(&Row::avm_main_rwc, 1),
Field(&Row::avm_main_ind_op_c, 1),
Field(&Row::avm_main_ind_c, 10)));

EXPECT_THAT(trace.at(mem_c_idx),
AllOf(Field(&Row::avm_mem_val, 1979),
Field(&Row::avm_mem_addr, 100),
Field(&Row::avm_mem_op_c, 1),
Field(&Row::avm_mem_rw, 1),
Field(&Row::avm_mem_ind_op_c, 0),
Field(&Row::avm_mem_w_in_tag, static_cast<uint32_t>(AvmMemoryTag::U64)),
Field(&Row::avm_mem_tag, static_cast<uint32_t>(AvmMemoryTag::U64))));

EXPECT_THAT(trace.at(mem_ind_c_idx),
AllOf(Field(&Row::avm_mem_val, 100),
Field(&Row::avm_mem_addr, 10),
Field(&Row::avm_mem_op_c, 0),
Field(&Row::avm_mem_rw, 0),
Field(&Row::avm_mem_ind_op_c, 1),
Field(&Row::avm_mem_r_in_tag, static_cast<uint32_t>(AvmMemoryTag::U32)),
Field(&Row::avm_mem_tag, static_cast<uint32_t>(AvmMemoryTag::U32))));

validate_trace_proof(std::move(trace));
}

TEST_F(AvmMemOpcodeTests, indirectSetWrongTag)
{
trace_builder.op_set(0, 100, 10, AvmMemoryTag::U8); // The address 100 has incorrect tag U8.
trace_builder.op_set(1, 1979, 10, AvmMemoryTag::U64); // Set 1979 at memory index 100
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();

compute_common_indices(1, true);
auto const& row = trace.at(2);

EXPECT_THAT(row,
AllOf(Field(&Row::avm_main_tag_err, 1),
Field(&Row::avm_main_mem_op_c, 1),
Field(&Row::avm_main_rwc, 1),
Field(&Row::avm_main_ind_op_c, 1),
Field(&Row::avm_main_ind_c, 10)));

EXPECT_THAT(trace.at(mem_ind_c_idx),
AllOf(Field(&Row::avm_mem_val, 100),
Field(&Row::avm_mem_addr, 10),
Field(&Row::avm_mem_op_c, 0),
Field(&Row::avm_mem_rw, 0),
Field(&Row::avm_mem_ind_op_c, 1),
Field(&Row::avm_mem_r_in_tag, static_cast<uint32_t>(AvmMemoryTag::U32)),
Field(&Row::avm_mem_tag, static_cast<uint32_t>(AvmMemoryTag::U8)),
Field(&Row::avm_mem_tag_err, 1)));

validate_trace_proof(std::move(trace));
}

/******************************************************************************
*
* MEMORY OPCODE NEGATIVE TESTS
Expand Down

0 comments on commit 59fb213

Please sign in to comment.