Skip to content

Commit

Permalink
feat(avm): plumb execution hints from TS to AVM prover (#6806)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbanks12 authored Jun 4, 2024
1 parent 2a3a098 commit f3234f1
Show file tree
Hide file tree
Showing 21 changed files with 452 additions and 59 deletions.
29 changes: 19 additions & 10 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,27 +521,33 @@ void vk_as_fields(const std::string& vk_path, const std::string& output_path)
*
* @param bytecode_path Path to the file containing the serialised bytecode
* @param calldata_path Path to the file containing the serialised calldata (could be empty)
* @param crs_path Path to the file containing the CRS (ignition is suitable for now)
* @param public_inputs_path Path to the file containing the serialised avm public inputs
* @param hints_path Path to the file containing the serialised avm circuit hints
* @param output_path Path (directory) to write the output proof and verification keys
*/
void avm_prove(const std::filesystem::path& bytecode_path,
const std::filesystem::path& calldata_path,
const std::filesystem::path& public_inputs_path,
const std::filesystem::path& hints_path,
const std::filesystem::path& output_path)
{
// Get Bytecode
std::vector<uint8_t> const avm_bytecode =
std::vector<uint8_t> const bytecode =
bytecode_path.extension() == ".gz" ? gunzip(bytecode_path) : read_file(bytecode_path);
std::vector<uint8_t> call_data_bytes{};
if (std::filesystem::exists(calldata_path)) {
call_data_bytes = read_file(calldata_path);
std::vector<fr> const calldata = many_from_buffer<fr>(read_file(calldata_path));
std::vector<fr> const public_inputs_vec = many_from_buffer<fr>(read_file(public_inputs_path));
std::vector<uint8_t> avm_hints;
try {
avm_hints = read_file(hints_path);
} catch (std::runtime_error const& err) {
vinfo("No hints were provided for avm proving.... Might be fine!");
}
std::vector<fr> const call_data = many_from_buffer<fr>(call_data_bytes);

// Hardcoded circuit size for now, with enough to support 16-bit range checks
init_bn254_crs(1 << 17);

// Prove execution and return vk
auto const [verification_key, proof] = avm_trace::Execution::prove(avm_bytecode, call_data);
auto const [verification_key, proof] = avm_trace::Execution::prove(bytecode, calldata, public_inputs_vec);
// TODO(ilyas): <#4887>: Currently we only need these two parts of the vk, look into pcs_verification key reqs
std::vector<uint64_t> vk_vector = { verification_key.circuit_size, verification_key.num_public_inputs };
std::vector<fr> vk_as_fields = { verification_key.circuit_size, verification_key.num_public_inputs };
Expand Down Expand Up @@ -890,11 +896,14 @@ int main(int argc, char* argv[])
std::string output_path = get_option(args, "-o", vk_path + "_fields.json");
vk_as_fields(vk_path, output_path);
} else if (command == "avm_prove") {
std::filesystem::path avm_bytecode_path = get_option(args, "-b", "./target/avm_bytecode.bin");
std::filesystem::path calldata_path = get_option(args, "-d", "./target/call_data.bin");
std::filesystem::path avm_bytecode_path = get_option(args, "--avm-bytecode", "./target/avm_bytecode.bin");
std::filesystem::path avm_calldata_path = get_option(args, "--avm-calldata", "./target/avm_calldata.bin");
std::filesystem::path avm_public_inputs_path =
get_option(args, "--avm-public-inputs", "./target/avm_public_inputs.bin");
std::filesystem::path avm_hints_path = get_option(args, "--avm-hints", "./target/avm_hints.bin");
// This outputs both files: proof and vk, under the given directory.
std::filesystem::path output_path = get_option(args, "-o", "./proofs");
avm_prove(avm_bytecode_path, calldata_path, output_path);
avm_prove(avm_bytecode_path, avm_calldata_path, avm_public_inputs_path, avm_hints_path, output_path);
} else if (command == "avm_verify") {
return avm_verify(proof_path, vk_path) ? 0 : 1;
} else if (command == "prove_ultra_honk") {
Expand Down
35 changes: 23 additions & 12 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ namespace bb::avm_trace {
std::vector<FF> Execution::getDefaultPublicInputs()
{
std::vector<FF> public_inputs_vec(PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH);
public_inputs_vec.at(DA_GAS_LEFT_CONTEXT_INPUTS_OFFSET) = 1000000000;
public_inputs_vec.at(L2_GAS_LEFT_CONTEXT_INPUTS_OFFSET) = 1000000000;
public_inputs_vec.at(DA_START_GAS_LEFT_PCPI_OFFSET) = 1000000000;
public_inputs_vec.at(L2_START_GAS_LEFT_PCPI_OFFSET) = 1000000000;
return public_inputs_vec;
}

Expand All @@ -51,19 +51,25 @@ std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<u
std::vector<FF> const& public_inputs_vec,
ExecutionHints const& execution_hints)
{
// TODO: temp
info("logging to silence warning for now: ", public_inputs_vec.size());
if (public_inputs_vec.size() != PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH) {
throw_or_abort("Public inputs vector is not of PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH");
}

auto instructions = Deserialization::parse(bytecode);
std::vector<FF> returndata{};
auto trace = gen_trace(instructions, returndata, calldata, getDefaultPublicInputs(), execution_hints);
auto trace = gen_trace(instructions, returndata, calldata, public_inputs_vec, execution_hints);
auto circuit_builder = bb::AvmCircuitBuilder();
circuit_builder.set_trace(std::move(trace));

auto composer = AvmComposer();
auto prover = composer.create_prover(circuit_builder);
auto verifier = composer.create_verifier(circuit_builder);
auto proof = prover.construct_proof();

// The proof starts with the serialized public inputs
HonkProof proof(public_inputs_vec);
auto raw_proof = prover.construct_proof();
// append the raw proof after the public inputs
proof.insert(proof.end(), raw_proof.begin(), raw_proof.end());
// TODO(#4887): Might need to return PCS vk when full verify is supported
return std::make_tuple(*verifier.key, proof);
}
Expand Down Expand Up @@ -130,8 +136,8 @@ VmPublicInputs Execution::convert_public_inputs(std::vector<FF> const& public_in
// Transaction fee
kernel_inputs[TRANSACTION_FEE_SELECTOR] = public_inputs_vec[TRANSACTION_FEE_OFFSET];

kernel_inputs[DA_GAS_LEFT_CONTEXT_INPUTS_OFFSET] = public_inputs_vec[DA_GAS_LEFT_CONTEXT_INPUTS_OFFSET];
kernel_inputs[L2_GAS_LEFT_CONTEXT_INPUTS_OFFSET] = public_inputs_vec[L2_GAS_LEFT_CONTEXT_INPUTS_OFFSET];
kernel_inputs[DA_GAS_LEFT_CONTEXT_INPUTS_OFFSET] = public_inputs_vec[DA_START_GAS_LEFT_PCPI_OFFSET];
kernel_inputs[L2_GAS_LEFT_CONTEXT_INPUTS_OFFSET] = public_inputs_vec[L2_START_GAS_LEFT_PCPI_OFFSET];

return public_inputs;
}
Expand All @@ -146,10 +152,15 @@ bool Execution::verify(AvmFlavor::VerificationKey vk, HonkProof const& proof)
// crs_factory_);
// output_state.pcs_verification_key = std::move(pcs_verification_key);

// TODO: We hardcode public inputs for now
VmPublicInputs public_inputs = convert_public_inputs(getDefaultPublicInputs());
std::vector<std::vector<FF>> public_inputs_vec = copy_public_inputs_columns(public_inputs);
return verifier.verify_proof(proof, public_inputs_vec);
std::vector<FF> public_inputs_vec;
std::vector<FF> raw_proof;
std::copy(
proof.begin(), proof.begin() + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH, std::back_inserter(public_inputs_vec));
std::copy(proof.begin() + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH, proof.end(), std::back_inserter(raw_proof));

VmPublicInputs public_inputs = convert_public_inputs(public_inputs_vec);
std::vector<std::vector<FF>> public_inputs_columns = copy_public_inputs_columns(public_inputs);
return verifier.verify_proof(raw_proof, public_inputs_columns);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ class Execution {
std::vector<FF> const& public_inputs,
ExecutionHints const& execution_hints);

static std::tuple<AvmFlavor::VerificationKey, bb::HonkProof> prove(std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata = {},
std::vector<FF> const& public_inputs_vec = {},
ExecutionHints const& execution_hints = {});
static std::tuple<AvmFlavor::VerificationKey, bb::HonkProof> prove(
std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata = {},
std::vector<FF> const& public_inputs_vec = getDefaultPublicInputs(),
ExecutionHints const& execution_hints = {});
static bool verify(AvmFlavor::VerificationKey vk, HonkProof const& proof);
};

Expand Down
4 changes: 2 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ inline const uint32_t FEE_PER_DA_GAS_OFFSET = PCPI_GLOBALS_START + 6;
inline const uint32_t FEE_PER_L2_GAS_OFFSET = PCPI_GLOBALS_START + 7;

inline const uint32_t TRANSACTION_FEE_OFFSET = PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH - 1;
inline const uint32_t DA_GAS_LEFT_PCPI_OFFSET = PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH - 3 - GAS_LENGTH;
inline const uint32_t L2_GAS_LEFT_PCPI_OFFSET = PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH - 2 - GAS_LENGTH;
inline const uint32_t DA_START_GAS_LEFT_PCPI_OFFSET = PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH - 3 - GAS_LENGTH;
inline const uint32_t L2_START_GAS_LEFT_PCPI_OFFSET = PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH - 2 - GAS_LENGTH;

// Kernel output pil offset (Where update objects are inlined)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class AvmExecutionTests : public ::testing::Test {
void SetUp() override
{
srs::init_crs_factory("../srs_db/ignition");
public_inputs_vec.at(DA_GAS_LEFT_CONTEXT_INPUTS_OFFSET) = DEFAULT_INITIAL_DA_GAS;
public_inputs_vec.at(L2_GAS_LEFT_CONTEXT_INPUTS_OFFSET) = DEFAULT_INITIAL_L2_GAS;
public_inputs_vec.at(DA_START_GAS_LEFT_PCPI_OFFSET) = DEFAULT_INITIAL_DA_GAS;
public_inputs_vec.at(L2_START_GAS_LEFT_PCPI_OFFSET) = DEFAULT_INITIAL_L2_GAS;
};

/**
Expand Down
39 changes: 31 additions & 8 deletions yarn-project/bb-prover/src/avm_proving.test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { AvmCircuitInputs, Gas, PublicCircuitPublicInputs } from '@aztec/circuits.js';
import { Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';
import { AvmSimulator } from '@aztec/simulator';
Expand All @@ -7,6 +8,10 @@ import fs from 'node:fs/promises';
import { tmpdir } from 'node:os';
import path from 'path';

import {
convertAvmResultsToPxResult,
createPublicExecution,
} from '../../simulator/src/public/transitional_adaptors.js';
import { type BBSuccess, BB_RESULT, generateAvmProof, verifyAvmProof } from './bb/execute.js';
import { extractVkData } from './verification_key/verification_key_data.js';

Expand All @@ -16,8 +21,14 @@ describe('AVM WitGen, proof generation and verification', () => {
it(
'Should prove valid execution contract function that performs addition',
async () => {
const startSideEffectCounter = 0;
const calldata: Fr[] = [new Fr(1), new Fr(2)];
const context = initContext({ env: initExecutionEnvironment({ calldata }) });
const environment = initExecutionEnvironment({ calldata });
const context = initContext({ env: environment });

const startGas = new Gas(context.machineState.gasLeft.daGas, context.machineState.gasLeft.l2Gas);
const oldPublicExecution = createPublicExecution(startSideEffectCounter, environment, calldata);

const internalLogger = createDebugLogger('aztec:avm-proving-test');
const logger = (msg: string, _data?: any) => internalLogger.verbose(msg);
const bytecode = getAvmTestContractBytecode('add_args_return');
Expand All @@ -27,18 +38,30 @@ describe('AVM WitGen, proof generation and verification', () => {

// First we simulate (though it's not needed in this simple case).
const simulator = new AvmSimulator(context);
const results = await simulator.executeBytecode(bytecode);
expect(results.reverted).toBe(false);
const avmResult = await simulator.executeBytecode(bytecode);
expect(avmResult.reverted).toBe(false);

// Then we prove.
const pxResult = convertAvmResultsToPxResult(
avmResult,
startSideEffectCounter,
oldPublicExecution,
startGas,
context,
simulator.getBytecode(),
);
// TODO(dbanks12): public inputs should not be empty.... Need to construct them from AvmContext?
const uncompressedBytecode = simulator.getBytecode()!;
const proofRes = await generateAvmProof(
bbPath,
bbWorkingDirectory,
const publicInputs = PublicCircuitPublicInputs.empty();
publicInputs.startGasLeft = startGas;
const avmCircuitInputs = new AvmCircuitInputs(
uncompressedBytecode,
context.environment.calldata,
logger,
publicInputs,
pxResult.avmHints,
);

// Then we prove.
const proofRes = await generateAvmProof(bbPath, bbWorkingDirectory, avmCircuitInputs, logger);
expect(proofRes.status).toEqual(BB_RESULT.SUCCESS);

// Then we test VK extraction.
Expand Down
41 changes: 34 additions & 7 deletions yarn-project/bb-prover/src/bb/execute.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type Fr } from '@aztec/circuits.js';
import { type AvmCircuitInputs } from '@aztec/circuits.js';
import { sha256 } from '@aztec/foundation/crypto';
import { type LogFn } from '@aztec/foundation/log';
import { Timer } from '@aztec/foundation/timer';
Expand Down Expand Up @@ -265,8 +265,7 @@ export async function generateProof(
export async function generateAvmProof(
pathToBB: string,
workingDirectory: string,
bytecode: Buffer,
calldata: Fr[],
input: AvmCircuitInputs,
log: LogFn,
): Promise<BBFailure | BBSuccess> {
// Check that the working directory exists
Expand All @@ -277,8 +276,10 @@ export async function generateAvmProof(
}

// Paths for the inputs
const calldataPath = join(workingDirectory, 'calldata.bin');
const bytecodePath = join(workingDirectory, 'avm_bytecode.bin');
const calldataPath = join(workingDirectory, 'avm_calldata.bin');
const publicInputsPath = join(workingDirectory, 'avm_public_inputs.bin');
const avmHintsPath = join(workingDirectory, 'avm_hints.bin');

// The proof is written to e.g. /workingDirectory/proof
const outputPath = workingDirectory;
Expand All @@ -296,19 +297,45 @@ export async function generateAvmProof(

try {
// Write the inputs to the working directory.
await fs.writeFile(bytecodePath, bytecode);
await fs.writeFile(bytecodePath, input.bytecode);
if (!filePresent(bytecodePath)) {
return { status: BB_RESULT.FAILURE, reason: `Could not write bytecode at ${bytecodePath}` };
}
await fs.writeFile(
calldataPath,
calldata.map(fr => fr.toBuffer()),
input.calldata.map(fr => fr.toBuffer()),
);
if (!filePresent(calldataPath)) {
return { status: BB_RESULT.FAILURE, reason: `Could not write calldata at ${calldataPath}` };
}

const args = ['-b', bytecodePath, '-d', calldataPath, '-o', outputPath];
// public inputs are used directly as a vector of fields in C++,
// so we serialize them as such here instead of just using toBuffer
await fs.writeFile(
publicInputsPath,
input.publicInputs.toFields().map(fr => fr.toBuffer()),
);
if (!filePresent(publicInputsPath)) {
return { status: BB_RESULT.FAILURE, reason: `Could not write publicInputs at ${publicInputsPath}` };
}

await fs.writeFile(avmHintsPath, input.avmHints.toBuffer());
if (!filePresent(avmHintsPath)) {
return { status: BB_RESULT.FAILURE, reason: `Could not write avmHints at ${avmHintsPath}` };
}

const args = [
'--avm-bytecode',
bytecodePath,
'--avm-calldata',
calldataPath,
'--avm-public-inputs',
publicInputsPath,
'--avm-hints',
avmHintsPath,
'-o',
outputPath,
];
const timer = new Timer();
const logFunction = (message: string) => {
log(`AvmCircuit (prove) BB out - ${message}`);
Expand Down
8 changes: 1 addition & 7 deletions yarn-project/bb-prover/src/prover/bb_prover.ts
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,7 @@ export class BBNativeRollupProver implements ServerCircuitProver {
private async generateAvmProofWithBB(input: AvmCircuitInputs, workingDirectory: string): Promise<BBSuccess> {
logger.debug(`Proving avm-circuit...`);

const provingResult = await generateAvmProof(
this.config.bbBinaryPath,
workingDirectory,
input.bytecode,
input.calldata,
logger.debug,
);
const provingResult = await generateAvmProof(this.config.bbBinaryPath, workingDirectory, input, logger.debug);

if (provingResult.status === BB_RESULT.FAILURE) {
logger.error(`Failed to generate proof for avm-circuit: ${provingResult.reason}`);
Expand Down
2 changes: 2 additions & 0 deletions yarn-project/circuit-types/src/tx/processed_tx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
UnencryptedTxL2Logs,
} from '@aztec/circuit-types';
import {
type AvmExecutionHints,
Fr,
type Gas,
type GasFees,
Expand Down Expand Up @@ -55,6 +56,7 @@ export type AvmProvingRequest = {
type: typeof AVM_REQUEST;
bytecode: Buffer;
calldata: Fr[];
avmHints: AvmExecutionHints;
kernelRequest: PublicKernelNonTailRequest;
};

Expand Down
49 changes: 49 additions & 0 deletions yarn-project/circuits.js/src/structs/avm/avm.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { randomInt } from '@aztec/foundation/crypto';

import { makeAvmCircuitInputs, makeAvmExecutionHints, makeAvmHint } from '../../tests/factories.js';
import { AvmCircuitInputs, AvmExecutionHints, AvmHint } from './avm.js';

describe('Avm circuit inputs', () => {
describe('AvmHint', () => {
let avmHint: AvmHint;

beforeAll(() => {
avmHint = makeAvmHint(randomInt(1000));
});

it(`serializes to buffer and deserializes it back`, () => {
const buffer = avmHint.toBuffer();
const res = AvmHint.fromBuffer(buffer);
expect(res).toEqual(avmHint);
expect(res.isEmpty()).toBe(false);
});
});
describe('AvmExecutionHints', () => {
let avmExecutionHints: AvmExecutionHints;

beforeAll(() => {
avmExecutionHints = makeAvmExecutionHints(randomInt(1000));
});

it(`serializes to buffer and deserializes it back`, () => {
const buffer = avmExecutionHints.toBuffer();
const res = AvmExecutionHints.fromBuffer(buffer);
expect(res).toEqual(avmExecutionHints);
expect(res.isEmpty()).toBe(false);
});
});
describe('AvmCircuitInputs', () => {
let avmCircuitInputs: AvmCircuitInputs;

beforeAll(() => {
avmCircuitInputs = makeAvmCircuitInputs(randomInt(2000));
});

it(`serializes to buffer and deserializes it back`, () => {
const buffer = avmCircuitInputs.toBuffer();
const res = AvmCircuitInputs.fromBuffer(buffer);
expect(res).toEqual(avmCircuitInputs);
expect(res.isEmpty()).toBe(false);
});
});
});
Loading

0 comments on commit f3234f1

Please sign in to comment.