Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(avm): calldata gadget preliminaries #7227

Merged
merged 4 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions barretenberg/cpp/pil/avm/main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace main(256);
pol constant sel_first = [1] + [0]*; // Used mostly to toggle off the first row consisting
// only in first element of shifted polynomials.

//===== PUBLIC COLUMNS=========================================================
pol public calldata;

//===== KERNEL INPUTS =========================================================
// Kernel lookup selector opcodes
pol commit sel_q_kernel_lookup;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
[[maybe_unused]] auto kernel_kernel_value_out = View(new_term.kernel_kernel_value_out); \
[[maybe_unused]] auto kernel_kernel_side_effect_out = View(new_term.kernel_kernel_side_effect_out); \
[[maybe_unused]] auto kernel_kernel_metadata_out = View(new_term.kernel_kernel_metadata_out); \
[[maybe_unused]] auto main_calldata = View(new_term.main_calldata); \
[[maybe_unused]] auto alu_a_hi = View(new_term.alu_a_hi); \
[[maybe_unused]] auto alu_a_lo = View(new_term.alu_a_lo); \
[[maybe_unused]] auto alu_b_hi = View(new_term.alu_b_hi); \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "barretenberg/vm/avm_trace/avm_execution.hpp"
#include "barretenberg/bb/log.hpp"
#include "barretenberg/common/serialize.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/vm/avm_trace/avm_common.hpp"
#include "barretenberg/vm/avm_trace/avm_deserialization.hpp"
#include "barretenberg/vm/avm_trace/avm_helper.hpp"
Expand Down Expand Up @@ -78,10 +79,11 @@ std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<u
auto prover = composer.create_prover(circuit_builder);
auto verifier = composer.create_verifier(circuit_builder);

// The proof starts with the serialized public inputs
// Proof structure: public_inputs | calldata_size | calldata | raw proof
HonkProof proof(public_inputs_vec);
proof.emplace_back(calldata.size());
proof.insert(proof.end(), calldata.begin(), calldata.end());
auto raw_proof = prover.construct_proof();
// append the raw proof after the public inputs
proof.insert(proof.end(), raw_proof.begin(), raw_proof.end());
// TODO(#4887): Might need to return PCS vk when full verify is supported
return std::make_tuple(*verifier.key, proof);
Expand Down Expand Up @@ -261,14 +263,23 @@ bool Execution::verify(AvmFlavor::VerificationKey vk, HonkProof const& proof)
// crs_factory_);
// output_state.pcs_verification_key = std::move(pcs_verification_key);

// Proof structure: public_inputs | calldata_size | calldata | raw proof
std::vector<FF> public_inputs_vec;
std::vector<FF> calldata;
std::vector<FF> raw_proof;
std::copy(
proof.begin(), proof.begin() + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH, std::back_inserter(public_inputs_vec));
std::copy(proof.begin() + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH, proof.end(), std::back_inserter(raw_proof));

// This can be made nicer using BB's serialize::read, probably.
const auto public_inputs_offset = proof.begin();
const auto calldata_size_offset = public_inputs_offset + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH;
const auto calldata_offset = calldata_size_offset + 1;
const auto raw_proof_offset = calldata_offset + static_cast<int64_t>(uint64_t(*calldata_size_offset));

std::copy(public_inputs_offset, calldata_size_offset, std::back_inserter(public_inputs_vec));
std::copy(calldata_offset, raw_proof_offset, std::back_inserter(calldata));
std::copy(raw_proof_offset, proof.end(), std::back_inserter(raw_proof));

VmPublicInputs public_inputs = convert_public_inputs(public_inputs_vec);
std::vector<std::vector<FF>> public_inputs_columns = copy_public_inputs_columns(public_inputs);
std::vector<std::vector<FF>> public_inputs_columns = copy_public_inputs_columns(public_inputs, calldata);
return verifier.verify_proof(raw_proof, public_inputs_columns);
}

Expand Down Expand Up @@ -309,7 +320,7 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
uint32_t start_side_effect_counter =
!public_inputs_vec.empty() ? static_cast<uint32_t>(public_inputs_vec[PCPI_START_SIDE_EFFECT_COUNTER_OFFSET])
: 0;
AvmTraceBuilder trace_builder(public_inputs, execution_hints, start_side_effect_counter);
AvmTraceBuilder trace_builder(public_inputs, execution_hints, start_side_effect_counter, calldata);

// Copied version of pc maintained in trace builder. The value of pc is evolving based
// on opcode logic and therefore is not maintained here. However, the next opcode in the execution
Expand Down Expand Up @@ -436,8 +447,7 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
trace_builder.op_calldata_copy(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(1)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
calldata);
std::get<uint32_t>(inst.operands.at(3)));
break;
// Machine State - Gas
case OpCode::L2GASLEFT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ bool is_operand_indirect(uint8_t ind_value, uint8_t operand_idx)
return static_cast<bool>((ind_value & (1 << operand_idx)) >> operand_idx);
}

std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs)
std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs,
std::vector<FF> const& calldata)
{
// We convert to a vector as the pil generated verifier is generic and unaware of the KERNEL_INPUTS_LENGTH
// For each of the public input vectors
Expand All @@ -158,7 +159,8 @@ std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& pu
return { std::move(public_inputs_kernel_inputs),
std::move(public_inputs_kernel_value_outputs),
std::move(public_inputs_kernel_side_effect_outputs),
std::move(public_inputs_kernel_metadata_outputs) };
std::move(public_inputs_kernel_metadata_outputs),
calldata };
}

} // namespace bb::avm_trace
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ bool is_operand_indirect(uint8_t ind_value, uint8_t operand_idx);
// There are 4 public input columns, one for inputs, and 3 for the kernel outputs {value, side effect counter, metadata}
// The verifier is generic, and so accepts vectors of these values rather than the fixed length arrays that are used
// during circuit building. This method copies each array into a vector to be used by the verifier.
std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs);
std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs,
std::vector<FF> const& calldata);

} // namespace bb::avm_trace
51 changes: 39 additions & 12 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ namespace bb::avm_trace {
*/
AvmTraceBuilder::AvmTraceBuilder(VmPublicInputs public_inputs,
ExecutionHints execution_hints,
uint32_t side_effect_counter)
uint32_t side_effect_counter,
std::vector<FF> calldata)
// NOTE: we initialise the environment builder here as it requires public inputs
: kernel_trace_builder(std::move(public_inputs))
, calldata(std::move(calldata))
, side_effect_counter(side_effect_counter)
, initial_side_effect_counter(side_effect_counter)
, execution_hints(std::move(execution_hints))
Expand Down Expand Up @@ -1886,10 +1888,8 @@ void AvmTraceBuilder::op_div(
* @param cd_offset The starting index of the region in calldata to be copied.
* @param copy_size The number of finite field elements to be copied into memory.
* @param dst_offset The starting index of memory where calldata will be copied to.
* @param call_data_mem The vector containing calldata.
*/
void AvmTraceBuilder::op_calldata_copy(
uint8_t indirect, uint32_t cd_offset, uint32_t copy_size, uint32_t dst_offset, std::vector<FF> const& call_data_mem)
void AvmTraceBuilder::op_calldata_copy(uint8_t indirect, uint32_t cd_offset, uint32_t copy_size, uint32_t dst_offset)
{
// We parallelize storing memory operations in chunk of 3, i.e., 1 per intermediate register.
// The variable pos is an index pointing to the first storing operation (pertaining to intermediate
Expand All @@ -1912,7 +1912,7 @@ void AvmTraceBuilder::op_calldata_copy(
uint32_t rwc(0);
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

FF ia = call_data_mem.at(cd_offset + pos);
FF ia = calldata.at(cd_offset + pos);
uint32_t mem_op_a(1);
uint32_t rwa = 1;

Expand All @@ -1934,7 +1934,7 @@ void AvmTraceBuilder::op_calldata_copy(
call_ptr, clk, IntermRegister::IA, mem_addr_a, ia, AvmMemoryTag::U0, AvmMemoryTag::FF);

if (copy_size - pos > 1) {
ib = call_data_mem.at(cd_offset + pos + 1);
ib = calldata.at(cd_offset + pos + 1);
mem_op_b = 1;
mem_addr_b = direct_dst_offset + pos + 1;
rwb = 1;
Expand All @@ -1945,7 +1945,7 @@ void AvmTraceBuilder::op_calldata_copy(
}

if (copy_size - pos > 2) {
ic = call_data_mem.at(cd_offset + pos + 2);
ic = calldata.at(cd_offset + pos + 2);
mem_op_c = 1;
mem_addr_c = direct_dst_offset + pos + 2;
rwc = 1;
Expand Down Expand Up @@ -3762,7 +3762,9 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c

main_trace.at(*trace_size - 1).main_sel_last = FF(1);

// Memory trace inclusion
/**********************************************************************************************
* MEMORY TRACE INCLUSION
**********************************************************************************************/

// We compute in the main loop the timestamp and global address for next row.
// Perform initialization for index 0 outside of the loop provided that mem trace exists.
Expand Down Expand Up @@ -3866,7 +3868,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

// Alu trace inclusion
/**********************************************************************************************
* ALU TRACE INCLUSION
**********************************************************************************************/

for (size_t i = 0; i < alu_trace_size; i++) {
auto const& src = alu_trace.at(i);
auto& dest = main_trace.at(i);
Expand Down Expand Up @@ -4013,6 +4018,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

/**********************************************************************************************
* GADGET TABLES INCLUSION
**********************************************************************************************/

// Add Conversion Gadget table
for (size_t i = 0; i < conv_trace_size; i++) {
auto const& src = conv_trace.at(i);
Expand Down Expand Up @@ -4067,6 +4076,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
dest.pedersen_sel_pedersen = FF(1);
}

/**********************************************************************************************
* BINARY TRACE INCLUSION
**********************************************************************************************/

// Add Binary Trace table
for (size_t i = 0; i < bin_trace_size; i++) {
auto const& src = bin_trace.at(i);
Expand Down Expand Up @@ -4132,7 +4145,9 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

/////////// GAS ACCOUNTING //////////////////////////
/**********************************************************************************************
* GAS TRACE INCLUSION
**********************************************************************************************/

// Add the gas cost table to the main trace
// TODO: do i need a way to produce an interupt that will stop the execution of the trace when the gas left
Expand Down Expand Up @@ -4222,11 +4237,14 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
dest.main_da_gas_remaining = current_da_gas_remaining;
}

/////////// END OF GAS ACCOUNTING //////////////////////////

// Adding extra row for the shifted values at the top of the execution trace.
Row first_row = Row{ .main_sel_first = FF(1), .mem_lastAccess = FF(1) };
main_trace.insert(main_trace.begin(), first_row);

/**********************************************************************************************
* RANGE CHECKS AND SELECTORS INCLUSION
**********************************************************************************************/

auto const old_trace_size = main_trace.size();

auto new_trace_size = range_check_required ? old_trace_size
Expand Down Expand Up @@ -4316,6 +4334,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

/**********************************************************************************************
* KERNEL TRACE INCLUSION
**********************************************************************************************/

// Write the kernel trace into the main trace
// 1. The write offsets are constrained to be non changing over the entire trace, so we fill in the values
// until we
Expand Down Expand Up @@ -4494,6 +4516,11 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
std::get<KERNEL_OUTPUTS_METADATA>(kernel_trace_builder.public_inputs).at(i);
}

// calldata column inclusion
for (size_t i = 0; i < calldata.size(); i++) {
main_trace.at(i).main_calldata = calldata.at(i);
}

// Get tag_err counts from the mem_trace_builder
if (range_check_required) {
finalise_mem_trace_lookup_counts();
Expand Down
11 changes: 5 additions & 6 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class AvmTraceBuilder {
public:
AvmTraceBuilder(VmPublicInputs public_inputs = {},
ExecutionHints execution_hints = {},
uint32_t side_effect_counter = 0);
uint32_t side_effect_counter = 0,
std::vector<FF> calldata = {});

std::vector<Row> finalize(uint32_t min_trace_size = 0, bool range_check_required = ENABLE_PROVING);
void reset();
Expand Down Expand Up @@ -158,11 +159,7 @@ class AvmTraceBuilder {
// CALLDATACOPY opcode with direct/indirect memory access, i.e.,
// direct: M[dst_offset:dst_offset+copy_size] = calldata[cd_offset:cd_offset+copy_size]
// indirect: M[M[dst_offset]:M[dst_offset]+copy_size] = calldata[cd_offset:cd_offset+copy_size]
void op_calldata_copy(uint8_t indirect,
uint32_t cd_offset,
uint32_t copy_size,
uint32_t dst_offset,
std::vector<FF> const& call_data_mem);
void op_calldata_copy(uint8_t indirect, uint32_t cd_offset, uint32_t copy_size, uint32_t dst_offset);

// REVERT Opcode (that just call return under the hood for now)
std::vector<FF> op_revert(uint8_t indirect, uint32_t ret_offset, uint32_t ret_size);
Expand Down Expand Up @@ -241,6 +238,8 @@ class AvmTraceBuilder {
AvmPedersenTraceBuilder pedersen_trace_builder;
AvmEccTraceBuilder ecc_trace_builder;

std::vector<FF> calldata{};

/**
* @brief Create a kernel lookup opcode object
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ template <typename FF> std::vector<std::string> AvmFullRow<FF>::names()
"kernel_kernel_value_out",
"kernel_kernel_side_effect_out",
"kernel_kernel_metadata_out",
"main_calldata",
"alu_a_hi",
"alu_a_lo",
"alu_b_hi",
Expand Down Expand Up @@ -412,18 +413,19 @@ template <typename FF> std::ostream& operator<<(std::ostream& os, AvmFullRow<FF>
<< field_to_string(row.main_clk) << "," << field_to_string(row.main_sel_first) << ","
<< field_to_string(row.kernel_kernel_inputs) << "," << field_to_string(row.kernel_kernel_value_out) << ","
<< field_to_string(row.kernel_kernel_side_effect_out) << ","
<< field_to_string(row.kernel_kernel_metadata_out) << "," << field_to_string(row.alu_a_hi) << ","
<< field_to_string(row.alu_a_lo) << "," << field_to_string(row.alu_b_hi) << ","
<< field_to_string(row.alu_b_lo) << "," << field_to_string(row.alu_borrow) << ","
<< field_to_string(row.alu_cf) << "," << field_to_string(row.alu_clk) << ","
<< field_to_string(row.alu_cmp_rng_ctr) << "," << field_to_string(row.alu_div_u16_r0) << ","
<< field_to_string(row.alu_div_u16_r1) << "," << field_to_string(row.alu_div_u16_r2) << ","
<< field_to_string(row.alu_div_u16_r3) << "," << field_to_string(row.alu_div_u16_r4) << ","
<< field_to_string(row.alu_div_u16_r5) << "," << field_to_string(row.alu_div_u16_r6) << ","
<< field_to_string(row.alu_div_u16_r7) << "," << field_to_string(row.alu_divisor_hi) << ","
<< field_to_string(row.alu_divisor_lo) << "," << field_to_string(row.alu_ff_tag) << ","
<< field_to_string(row.alu_ia) << "," << field_to_string(row.alu_ib) << "," << field_to_string(row.alu_ic)
<< "," << field_to_string(row.alu_in_tag) << "," << field_to_string(row.alu_op_add) << ","
<< field_to_string(row.kernel_kernel_metadata_out) << "," << field_to_string(row.main_calldata) << ","
<< field_to_string(row.alu_a_hi) << "," << field_to_string(row.alu_a_lo) << ","
<< field_to_string(row.alu_b_hi) << "," << field_to_string(row.alu_b_lo) << ","
<< field_to_string(row.alu_borrow) << "," << field_to_string(row.alu_cf) << ","
<< field_to_string(row.alu_clk) << "," << field_to_string(row.alu_cmp_rng_ctr) << ","
<< field_to_string(row.alu_div_u16_r0) << "," << field_to_string(row.alu_div_u16_r1) << ","
<< field_to_string(row.alu_div_u16_r2) << "," << field_to_string(row.alu_div_u16_r3) << ","
<< field_to_string(row.alu_div_u16_r4) << "," << field_to_string(row.alu_div_u16_r5) << ","
<< field_to_string(row.alu_div_u16_r6) << "," << field_to_string(row.alu_div_u16_r7) << ","
<< field_to_string(row.alu_divisor_hi) << "," << field_to_string(row.alu_divisor_lo) << ","
<< field_to_string(row.alu_ff_tag) << "," << field_to_string(row.alu_ia) << ","
<< field_to_string(row.alu_ib) << "," << field_to_string(row.alu_ic) << ","
<< field_to_string(row.alu_in_tag) << "," << field_to_string(row.alu_op_add) << ","
<< field_to_string(row.alu_op_cast) << "," << field_to_string(row.alu_op_cast_prev) << ","
<< field_to_string(row.alu_op_div) << "," << field_to_string(row.alu_op_div_a_lt_b) << ","
<< field_to_string(row.alu_op_div_std) << "," << field_to_string(row.alu_op_eq) << ","
Expand Down
Loading
Loading