From d42812e3033e5d585341e9d858e454da4daef718 Mon Sep 17 00:00:00 2001 From: fcarreiro Date: Wed, 3 Jul 2024 09:22:51 +0000 Subject: [PATCH] feat(avm): use template engine for codegen --- .../barretenberg/vm/generated/avm_flavor.hpp | 314 ++++---- bb-pilcom/Cargo.lock | 102 ++- bb-pilcom/bb-pil-backend/Cargo.toml | 2 + .../bb-pil-backend/src/flavor_builder.rs | 681 +----------------- .../bb-pil-backend/templates/flavor.hpp.hbs | 349 +++++++++ 5 files changed, 636 insertions(+), 812 deletions(-) create mode 100644 bb-pilcom/bb-pil-backend/templates/flavor.hpp.hbs diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp b/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp index 10929d4f372d..bdd10e57f7ad 100644 --- a/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp @@ -1,4 +1,3 @@ - #pragma once #include "barretenberg/commitment_schemes/kzg/kzg.hpp" @@ -7,20 +6,31 @@ #include "barretenberg/polynomials/barycentric.hpp" #include "barretenberg/polynomials/univariate.hpp" -#include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp" - #include "barretenberg/flavor/flavor.hpp" #include "barretenberg/flavor/flavor_macros.hpp" #include "barretenberg/polynomials/evaluation_domain.hpp" #include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/transcript/transcript.hpp" + +#include "barretenberg/vm/avm_trace/stats.hpp" + +// Relations #include "barretenberg/relations/generated/avm/alu.hpp" #include "barretenberg/relations/generated/avm/binary.hpp" #include "barretenberg/relations/generated/avm/conversion.hpp" #include "barretenberg/relations/generated/avm/gas.hpp" -#include "barretenberg/relations/generated/avm/incl_main_tag_err.hpp" -#include "barretenberg/relations/generated/avm/incl_mem_tag_err.hpp" #include "barretenberg/relations/generated/avm/keccakf1600.hpp" #include "barretenberg/relations/generated/avm/kernel.hpp" +#include "barretenberg/relations/generated/avm/main.hpp" +#include "barretenberg/relations/generated/avm/mem.hpp" +#include "barretenberg/relations/generated/avm/pedersen.hpp" +#include "barretenberg/relations/generated/avm/poseidon2.hpp" +#include "barretenberg/relations/generated/avm/powers.hpp" +#include "barretenberg/relations/generated/avm/sha256.hpp" + +// Lookup relations +#include "barretenberg/relations/generated/avm/incl_main_tag_err.hpp" +#include "barretenberg/relations/generated/avm/incl_mem_tag_err.hpp" #include "barretenberg/relations/generated/avm/kernel_output_lookup.hpp" #include "barretenberg/relations/generated/avm/lookup_byte_lengths.hpp" #include "barretenberg/relations/generated/avm/lookup_byte_operations.hpp" @@ -56,9 +66,6 @@ #include "barretenberg/relations/generated/avm/lookup_u16_9.hpp" #include "barretenberg/relations/generated/avm/lookup_u8_0.hpp" #include "barretenberg/relations/generated/avm/lookup_u8_1.hpp" -#include "barretenberg/relations/generated/avm/main.hpp" -#include "barretenberg/relations/generated/avm/mem.hpp" -#include "barretenberg/relations/generated/avm/pedersen.hpp" #include "barretenberg/relations/generated/avm/perm_main_alu.hpp" #include "barretenberg/relations/generated/avm/perm_main_bin.hpp" #include "barretenberg/relations/generated/avm/perm_main_conv.hpp" @@ -72,15 +79,11 @@ #include "barretenberg/relations/generated/avm/perm_main_mem_ind_addr_d.hpp" #include "barretenberg/relations/generated/avm/perm_main_pedersen.hpp" #include "barretenberg/relations/generated/avm/perm_main_pos2_perm.hpp" -#include "barretenberg/relations/generated/avm/poseidon2.hpp" -#include "barretenberg/relations/generated/avm/powers.hpp" #include "barretenberg/relations/generated/avm/range_check_da_gas_hi.hpp" #include "barretenberg/relations/generated/avm/range_check_da_gas_lo.hpp" #include "barretenberg/relations/generated/avm/range_check_l2_gas_hi.hpp" #include "barretenberg/relations/generated/avm/range_check_l2_gas_lo.hpp" -#include "barretenberg/relations/generated/avm/sha256.hpp" -#include "barretenberg/transcript/transcript.hpp" -#include "barretenberg/vm/avm_trace/stats.hpp" +#include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp" namespace bb { @@ -107,72 +110,75 @@ class AvmFlavor { // the unshifted and one for the shifted static constexpr size_t NUM_ALL_ENTITIES = 452; - using Relations = std::tuple, - Avm_vm::binary, - Avm_vm::conversion, - Avm_vm::gas, - Avm_vm::keccakf1600, - Avm_vm::kernel, - Avm_vm::main, - Avm_vm::mem, - Avm_vm::pedersen, - Avm_vm::poseidon2, - Avm_vm::powers, - Avm_vm::sha256, - perm_main_alu_relation, - perm_main_bin_relation, - perm_main_conv_relation, - perm_main_pos2_perm_relation, - perm_main_pedersen_relation, - perm_main_mem_a_relation, - perm_main_mem_b_relation, - perm_main_mem_c_relation, - perm_main_mem_d_relation, - perm_main_mem_ind_addr_a_relation, - perm_main_mem_ind_addr_b_relation, - perm_main_mem_ind_addr_c_relation, - perm_main_mem_ind_addr_d_relation, - lookup_byte_lengths_relation, - lookup_byte_operations_relation, - lookup_opcode_gas_relation, - range_check_l2_gas_hi_relation, - range_check_l2_gas_lo_relation, - range_check_da_gas_hi_relation, - range_check_da_gas_lo_relation, - kernel_output_lookup_relation, - lookup_into_kernel_relation, - incl_main_tag_err_relation, - incl_mem_tag_err_relation, - lookup_mem_rng_chk_lo_relation, - lookup_mem_rng_chk_mid_relation, - lookup_mem_rng_chk_hi_relation, - lookup_pow_2_0_relation, - lookup_pow_2_1_relation, - lookup_u8_0_relation, - lookup_u8_1_relation, - lookup_u16_0_relation, - lookup_u16_1_relation, - lookup_u16_2_relation, - lookup_u16_3_relation, - lookup_u16_4_relation, - lookup_u16_5_relation, - lookup_u16_6_relation, - lookup_u16_7_relation, - lookup_u16_8_relation, - lookup_u16_9_relation, - lookup_u16_10_relation, - lookup_u16_11_relation, - lookup_u16_12_relation, - lookup_u16_13_relation, - lookup_u16_14_relation, - lookup_div_u16_0_relation, - lookup_div_u16_1_relation, - lookup_div_u16_2_relation, - lookup_div_u16_3_relation, - lookup_div_u16_4_relation, - lookup_div_u16_5_relation, - lookup_div_u16_6_relation, - lookup_div_u16_7_relation>; + using Relations = std::tuple< + // Relations + Avm_vm::alu, + Avm_vm::binary, + Avm_vm::conversion, + Avm_vm::gas, + Avm_vm::keccakf1600, + Avm_vm::kernel, + Avm_vm::main, + Avm_vm::mem, + Avm_vm::pedersen, + Avm_vm::poseidon2, + Avm_vm::powers, + Avm_vm::sha256, + // Lookups + perm_main_alu_relation, + perm_main_bin_relation, + perm_main_conv_relation, + perm_main_pos2_perm_relation, + perm_main_pedersen_relation, + perm_main_mem_a_relation, + perm_main_mem_b_relation, + perm_main_mem_c_relation, + perm_main_mem_d_relation, + perm_main_mem_ind_addr_a_relation, + perm_main_mem_ind_addr_b_relation, + perm_main_mem_ind_addr_c_relation, + perm_main_mem_ind_addr_d_relation, + lookup_byte_lengths_relation, + lookup_byte_operations_relation, + lookup_opcode_gas_relation, + range_check_l2_gas_hi_relation, + range_check_l2_gas_lo_relation, + range_check_da_gas_hi_relation, + range_check_da_gas_lo_relation, + kernel_output_lookup_relation, + lookup_into_kernel_relation, + incl_main_tag_err_relation, + incl_mem_tag_err_relation, + lookup_mem_rng_chk_lo_relation, + lookup_mem_rng_chk_mid_relation, + lookup_mem_rng_chk_hi_relation, + lookup_pow_2_0_relation, + lookup_pow_2_1_relation, + lookup_u8_0_relation, + lookup_u8_1_relation, + lookup_u16_0_relation, + lookup_u16_1_relation, + lookup_u16_2_relation, + lookup_u16_3_relation, + lookup_u16_4_relation, + lookup_u16_5_relation, + lookup_u16_6_relation, + lookup_u16_7_relation, + lookup_u16_8_relation, + lookup_u16_9_relation, + lookup_u16_10_relation, + lookup_u16_11_relation, + lookup_u16_12_relation, + lookup_u16_13_relation, + lookup_u16_14_relation, + lookup_div_u16_0_relation, + lookup_div_u16_1_relation, + lookup_div_u16_2_relation, + lookup_div_u16_3_relation, + lookup_div_u16_4_relation, + lookup_div_u16_5_relation, + lookup_div_u16_6_relation, + lookup_div_u16_7_relation>; static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); @@ -197,10 +203,10 @@ class AvmFlavor { DEFINE_FLAVOR_MEMBERS(DataType, main_clk, main_sel_first) - RefVector get_selectors() { return { main_clk, main_sel_first }; }; - RefVector get_sigma_polynomials() { return {}; }; - RefVector get_id_polynomials() { return {}; }; - RefVector get_table_polynomials() { return {}; }; + RefVector get_selectors() { return { main_clk, main_sel_first }; } + RefVector get_sigma_polynomials() { return {}; } + RefVector get_id_polynomials() { return {}; } + RefVector get_table_polynomials() { return {}; } }; template class WireEntities { @@ -539,7 +545,8 @@ class AvmFlavor { lookup_div_u16_7_counts) }; - template struct DerivedWitnessEntities { + template class DerivedWitnessEntities { + public: DEFINE_FLAVOR_MEMBERS(DataType, perm_main_alu, perm_main_bin, @@ -670,74 +677,71 @@ class AvmFlavor { template static auto get_to_be_shifted(PrecomputedAndWitnessEntitiesSuperset& entities) { - return RefArray{ - - entities.alu_a_hi, - entities.alu_a_lo, - entities.alu_b_hi, - entities.alu_b_lo, - entities.alu_cmp_rng_ctr, - entities.alu_div_u16_r0, - entities.alu_div_u16_r1, - entities.alu_div_u16_r2, - entities.alu_div_u16_r3, - entities.alu_div_u16_r4, - entities.alu_div_u16_r5, - entities.alu_div_u16_r6, - entities.alu_div_u16_r7, - entities.alu_op_add, - entities.alu_op_cast_prev, - entities.alu_op_cast, - entities.alu_op_div, - entities.alu_op_mul, - entities.alu_op_shl, - entities.alu_op_shr, - entities.alu_op_sub, - entities.alu_p_sub_a_hi, - entities.alu_p_sub_a_lo, - entities.alu_p_sub_b_hi, - entities.alu_p_sub_b_lo, - entities.alu_sel_alu, - entities.alu_sel_cmp, - entities.alu_sel_div_rng_chk, - entities.alu_sel_rng_chk_lookup, - entities.alu_sel_rng_chk, - entities.alu_u16_r0, - entities.alu_u16_r1, - entities.alu_u16_r2, - entities.alu_u16_r3, - entities.alu_u16_r4, - entities.alu_u16_r5, - entities.alu_u16_r6, - entities.alu_u8_r0, - entities.alu_u8_r1, - entities.binary_acc_ia, - entities.binary_acc_ib, - entities.binary_acc_ic, - entities.binary_mem_tag_ctr, - entities.binary_op_id, - entities.kernel_emit_l2_to_l1_msg_write_offset, - entities.kernel_emit_note_hash_write_offset, - entities.kernel_emit_nullifier_write_offset, - entities.kernel_emit_unencrypted_log_write_offset, - entities.kernel_l1_to_l2_msg_exists_write_offset, - entities.kernel_note_hash_exist_write_offset, - entities.kernel_nullifier_exists_write_offset, - entities.kernel_nullifier_non_exists_write_offset, - entities.kernel_side_effect_counter, - entities.kernel_sload_write_offset, - entities.kernel_sstore_write_offset, - entities.main_da_gas_remaining, - entities.main_internal_return_ptr, - entities.main_l2_gas_remaining, - entities.main_pc, - entities.mem_glob_addr, - entities.mem_rw, - entities.mem_sel_mem, - entities.mem_tag, - entities.mem_tsp, - entities.mem_val, - }; + return RefArray{ entities.alu_a_hi, + entities.alu_a_lo, + entities.alu_b_hi, + entities.alu_b_lo, + entities.alu_cmp_rng_ctr, + entities.alu_div_u16_r0, + entities.alu_div_u16_r1, + entities.alu_div_u16_r2, + entities.alu_div_u16_r3, + entities.alu_div_u16_r4, + entities.alu_div_u16_r5, + entities.alu_div_u16_r6, + entities.alu_div_u16_r7, + entities.alu_op_add, + entities.alu_op_cast_prev, + entities.alu_op_cast, + entities.alu_op_div, + entities.alu_op_mul, + entities.alu_op_shl, + entities.alu_op_shr, + entities.alu_op_sub, + entities.alu_p_sub_a_hi, + entities.alu_p_sub_a_lo, + entities.alu_p_sub_b_hi, + entities.alu_p_sub_b_lo, + entities.alu_sel_alu, + entities.alu_sel_cmp, + entities.alu_sel_div_rng_chk, + entities.alu_sel_rng_chk_lookup, + entities.alu_sel_rng_chk, + entities.alu_u16_r0, + entities.alu_u16_r1, + entities.alu_u16_r2, + entities.alu_u16_r3, + entities.alu_u16_r4, + entities.alu_u16_r5, + entities.alu_u16_r6, + entities.alu_u8_r0, + entities.alu_u8_r1, + entities.binary_acc_ia, + entities.binary_acc_ib, + entities.binary_acc_ic, + entities.binary_mem_tag_ctr, + entities.binary_op_id, + entities.kernel_emit_l2_to_l1_msg_write_offset, + entities.kernel_emit_note_hash_write_offset, + entities.kernel_emit_nullifier_write_offset, + entities.kernel_emit_unencrypted_log_write_offset, + entities.kernel_l1_to_l2_msg_exists_write_offset, + entities.kernel_note_hash_exist_write_offset, + entities.kernel_nullifier_exists_write_offset, + entities.kernel_nullifier_non_exists_write_offset, + entities.kernel_side_effect_counter, + entities.kernel_sload_write_offset, + entities.kernel_sstore_write_offset, + entities.main_da_gas_remaining, + entities.main_internal_return_ptr, + entities.main_l2_gas_remaining, + entities.main_pc, + entities.mem_glob_addr, + entities.mem_rw, + entities.mem_sel_mem, + entities.mem_tag, + entities.mem_tsp, + entities.mem_val }; } template @@ -844,7 +848,7 @@ class AvmFlavor { mem_tag, mem_tsp, mem_val }; - }; + } void compute_logderivative_inverses(const RelationParameters& relation_parameters) { @@ -1362,7 +1366,7 @@ class AvmFlavor { Base::mem_diff_mid = "MEM_DIFF_MID"; Base::mem_glob_addr = "MEM_GLOB_ADDR"; Base::mem_last = "MEM_LAST"; - Base::mem_lastAccess = "MEM_LASTACCESS"; + Base::mem_lastAccess = "MEM_LAST_ACCESS"; Base::mem_one_min_inv = "MEM_ONE_MIN_INV"; Base::mem_r_in_tag = "MEM_R_IN_TAG"; Base::mem_rw = "MEM_RW"; @@ -2740,4 +2744,4 @@ class AvmFlavor { }; }; -} // namespace bb +} // namespace bb \ No newline at end of file diff --git a/bb-pilcom/Cargo.lock b/bb-pilcom/Cargo.lock index d1179b177b9f..f3555dd6fa56 100644 --- a/bb-pilcom/Cargo.lock +++ b/bb-pilcom/Cargo.lock @@ -241,6 +241,7 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" name = "bb-pil-backend" version = "0.1.0" dependencies = [ + "handlebars", "itertools 0.10.5", "log", "num-bigint", @@ -249,6 +250,7 @@ dependencies = [ "powdr-ast", "powdr-number", "rand", + "serde_json", ] [[package]] @@ -272,6 +274,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -388,6 +399,15 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -517,6 +537,7 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ + "block-buffer", "crypto-common", ] @@ -641,6 +662,21 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" +[[package]] +name = "handlebars" +version = "5.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d08485b96a0e6393e9e4d1b8d48cf74ad6c063cd905eb33f42c1ce3f0377539b" +dependencies = [ + "heck", + "log", + "pest", + "pest_derive", + "serde", + "serde_json", + "thiserror", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -976,6 +1012,51 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pest" +version = "2.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd53dff83f26735fdc1ca837098ccf133605d794cdae66acfc2bfac3ec809d95" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a548d2beca6773b1c244554d36fcf8548a8a58e74156968211567250e48e49a" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c93a82e8d145725dcbaf44e5ea887c8a869efdcc28706df2d08c69e17077183" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "pest_meta" +version = "2.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a941429fea7e08bedec25e4f6785b6ffaacc6b755da98df5ef3e7dcf4a124c4f" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + [[package]] name = "petgraph" version = "0.6.5" @@ -1335,9 +1416,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -1374,6 +1455,17 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1614,6 +1706,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "ucd-trie" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/bb-pilcom/bb-pil-backend/Cargo.toml b/bb-pilcom/bb-pil-backend/Cargo.toml index 0904cd7f4454..7d1a3cef9b44 100644 --- a/bb-pilcom/bb-pil-backend/Cargo.toml +++ b/bb-pilcom/bb-pil-backend/Cargo.toml @@ -14,3 +14,5 @@ itertools = "^0.10" log = "0.4.17" rand = "0.8.5" powdr-ast = { path = "../powdr/ast" } +handlebars = { version = "5.1.2", features = ["string_helpers"] } +serde_json = "1.0.120" diff --git a/bb-pilcom/bb-pil-backend/src/flavor_builder.rs b/bb-pilcom/bb-pil-backend/src/flavor_builder.rs index f4522b603e53..f2e794970b75 100644 --- a/bb-pilcom/bb-pil-backend/src/flavor_builder.rs +++ b/bb-pilcom/bb-pil-backend/src/flavor_builder.rs @@ -1,7 +1,6 @@ -use crate::{ - file_writer::BBFiles, - utils::{get_relations_imports, map_with_newline, snake_case}, -}; +use crate::{file_writer::BBFiles, utils::snake_case}; +use handlebars::Handlebars; +use serde_json::json; pub trait FlavorBuilder { #[allow(clippy::too_many_arguments)] @@ -35,75 +34,29 @@ impl FlavorBuilder for BBFiles { shifted: &[String], all_cols_and_shifts: &[String], ) { - let first_poly = &witness[0]; - let includes = flavor_includes(&snake_case(name), relation_file_names, lookups); - let num_precomputed = fixed.len(); - let num_witness = witness.len(); - let num_all = all_cols_and_shifts.len(); - - // Top of file boilerplate - let class_aliases = create_class_aliases(); - let relation_definitions = create_relation_definitions(name, relation_file_names, lookups); - let container_size_definitions = - container_size_definitions(num_precomputed, num_witness, num_all); - - // Entities classes - let precomputed_entities = create_precomputed_entities(fixed); - let witness_entities = - create_witness_entities(witness_without_inverses, lookups, shifted, to_be_shifted); - let all_entities = create_all_entities(); - - let proving_and_verification_key = - create_proving_and_verification_key(name, lookups, to_be_shifted); - let polynomial_views = create_polynomial_views(first_poly); - - let commitment_labels_class = create_commitment_labels(all_cols); - - let verification_commitments = create_verifier_commitments(fixed); - - let transcript = generate_transcript(witness); - - let flavor_hpp = format!( - " -{includes} - -namespace bb {{ - -class {name}Flavor {{ - public: - {class_aliases} - - {container_size_definitions} - - {relation_definitions} - - static constexpr bool has_zero_row = true; - - private: - {precomputed_entities} - - {witness_entities} - - {all_entities} - - - {proving_and_verification_key} - - - {polynomial_views} - - {commitment_labels_class} - - {verification_commitments} - - {transcript} -}}; - -}} // namespace bb - - - " - ); + let mut handlebars = Handlebars::new(); + + let data = &json!({ + "name": name, + "relation_file_names": relation_file_names, + "lookups": lookups, + "fixed": fixed, + "witness": witness, + "all_cols": all_cols, + "to_be_shifted": to_be_shifted, + "shifted": shifted, + "all_cols_and_shifts": all_cols_and_shifts, + "witness_without_inverses": witness_without_inverses, + }); + + handlebars + .register_template_string( + "flavor.hpp", + std::str::from_utf8(include_bytes!("../templates/flavor.hpp.hbs")).unwrap(), + ) + .unwrap(); + + let flavor_hpp = handlebars.render("flavor.hpp", data).unwrap(); self.write_file( &self.flavor, @@ -112,585 +65,3 @@ class {name}Flavor {{ ); } } - -/// Imports located at the top of the flavor files -fn flavor_includes(name: &str, relation_file_names: &[String], lookups: &[String]) -> String { - let relation_imports = get_relations_imports(name, relation_file_names, lookups); - - format!( - "#pragma once - -#include \"barretenberg/commitment_schemes/kzg/kzg.hpp\" -#include \"barretenberg/ecc/curves/bn254/g1.hpp\" -#include \"barretenberg/flavor/relation_definitions.hpp\" -#include \"barretenberg/polynomials/barycentric.hpp\" -#include \"barretenberg/polynomials/univariate.hpp\" - -#include \"barretenberg/relations/generic_permutation/generic_permutation_relation.hpp\" - -#include \"barretenberg/vm/avm_trace/stats.hpp\" -#include \"barretenberg/flavor/flavor_macros.hpp\" -#include \"barretenberg/transcript/transcript.hpp\" -#include \"barretenberg/polynomials/evaluation_domain.hpp\" -#include \"barretenberg/polynomials/polynomial.hpp\" -#include \"barretenberg/flavor/flavor.hpp\" -{relation_imports} -" - ) -} - -/// Creates comma separated relations tuple file -fn create_relations_tuple(master_name: &str, relation_file_names: &[String]) -> String { - relation_file_names - .iter() - .map(|name| format!("{master_name}_vm::{name}")) - .collect::>() - .join(", ") -} - -/// Creates comma separated relations tuple file -fn create_lookups_tuple(lookups: &[String]) -> Option { - if lookups.is_empty() { - return None; - } - Some( - lookups - .iter() - .map(|lookup| format!("{}_relation", lookup.clone())) - .collect::>() - .join(", "), - ) -} - -/// Create Class Aliases -/// -/// Contains boilerplate defining key characteristics of the flavor class -fn create_class_aliases() -> &'static str { - r#" - using Curve = curve::BN254; - using G1 = Curve::Group; - using PCS = KZG; - - using FF = G1::subgroup_field; - using Polynomial = bb::Polynomial; - using PolynomialHandle = std::span; - using GroupElement = G1::element; - using Commitment = G1::affine_element; - using CommitmentHandle = G1::affine_element; - using CommitmentKey = bb::CommitmentKey; - using VerifierCommitmentKey = bb::VerifierCommitmentKey; - using RelationSeparator = FF; - "# -} - -/// Create relation definitions -/// -/// Contains all of the boilerplate code required to generate relation definitions. -/// We instantiate the Relations container, which contains a tuple of all of the separate relation file -/// definitions. -/// -/// We then also define some constants, making use of the preprocessor. -fn create_relation_definitions( - name: &str, - relation_file_names: &[String], - lookups: &[String], -) -> String { - // Relations tuple = ns::relation_name_0, ns::relation_name_1, ... ns::relation_name_n (comma speratated) - let comma_sep_relations = create_relations_tuple(name, relation_file_names); - let comma_sep_lookups: Option = create_lookups_tuple(lookups); - - // We only include the grand product relations if we are given lookups - let mut all_relations = comma_sep_relations.to_string(); - if let Some(lookups) = comma_sep_lookups { - all_relations = all_relations + &format!(", {lookups}"); - } - - format!(" - using Relations = std::tuple<{all_relations}>; - - static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); - - // BATCHED_RELATION_PARTIAL_LENGTH = algebraic degree of sumcheck relation *after* multiplying by the `pow_zeta` - // random polynomial e.g. For \\sum(x) [A(x) * B(x) + C(x)] * PowZeta(X), relation length = 2 and random relation - // length = 3 - static constexpr size_t BATCHED_RELATION_PARTIAL_LENGTH = MAX_PARTIAL_RELATION_LENGTH + 1; - static constexpr size_t NUM_RELATIONS = std::tuple_size_v; - - template - using ProtogalaxyTupleOfTuplesOfUnivariates = - decltype(create_protogalaxy_tuple_of_tuples_of_univariates()); - using SumcheckTupleOfTuplesOfUnivariates = decltype(create_sumcheck_tuple_of_tuples_of_univariates()); - using TupleOfArraysOfValues = decltype(create_tuple_of_arrays_of_values()); - ") -} - -/// Create the number of columns boilerplate for the flavor file -fn container_size_definitions( - num_precomputed: usize, - num_witness: usize, - num_all: usize, -) -> String { - format!(" - static constexpr size_t NUM_PRECOMPUTED_ENTITIES = {num_precomputed}; - static constexpr size_t NUM_WITNESS_ENTITIES = {num_witness}; - static constexpr size_t NUM_WIRES = NUM_WITNESS_ENTITIES + NUM_PRECOMPUTED_ENTITIES; - // We have two copies of the witness entities, so we subtract the number of fixed ones (they have no shift), one for the unshifted and one for the shifted - static constexpr size_t NUM_ALL_ENTITIES = {num_all}; - - ") -} - -/// Returns a Ref Vector with the given name, -/// -/// The vector returned will reference the columns names given -/// Used in all entities declarations -fn return_ref_vector(name: &str, columns: &[String]) -> String { - let comma_sep = create_comma_separated(columns); - - format!("RefVector {name}() {{ return {{ {comma_sep} }}; }};") -} - -/// list -> "list[0], list[1], ... list[n-1]" -fn create_comma_separated(list: &[String]) -> String { - list.join(", ") -} - -/// Create Precomputed Entities -/// -/// Precomputed first contains a pointer view defining all of the precomputed columns -/// As-well as any polys conforming to tables / ids / permutations -fn create_precomputed_entities(fixed: &[String]) -> String { - let pointer_view = create_flavor_members(fixed); - - let selectors = return_ref_vector("get_selectors", fixed); - let sigma_polys = return_ref_vector("get_sigma_polynomials", &[]); - let id_polys = return_ref_vector("get_id_polynomials", &[]); - let table_polys = return_ref_vector("get_table_polynomials", &[]); - - format!( - " - template - class PrecomputedEntities : public PrecomputedEntitiesBase {{ - public: - using DataType = DataType_; - - {pointer_view} - - {selectors} - {sigma_polys} - {id_polys} - {table_polys} - }}; - " - ) -} - -// Note(md): this is witnesses WITHOUT inverses or shifts -fn create_wire_entities(witness: &[String]) -> String { - let flavor_members = create_flavor_members(witness); - - format!( - " - template - class WireEntities {{ - public: - {flavor_members} - }}; - " - ) -} - -// Note(md): this is witnesses and in future grand products -fn create_derived_witnesses(inverses: &[String]) -> String { - let flavor_members = create_flavor_members(inverses); - - format!( - " - template - struct DerivedWitnessEntities {{ - {flavor_members} - }}; - " - ) -} - -fn create_shifted_entities(shifted: &[String]) -> String { - let flavor_members = create_flavor_members(shifted); - - format!( - " - template - class ShiftedEntities {{ - public: - {flavor_members} - }}; - " - ) -} - -fn create_to_be_shifted(to_be_shifted: &[String]) -> String { - let entities_transformation = |name: &String| format!("entities.{name},"); - let entities_list = map_with_newline(to_be_shifted, entities_transformation); - - format!( - " - template - static auto get_to_be_shifted(PrecomputedAndWitnessEntitiesSuperset& entities) {{ - return RefArray{{ - - {entities_list} - }}; - }} - " - ) -} - -fn create_witness_entities( - witness: &[String], - inverses: &[String], - shifted: &[String], - to_be_shifted: &[String], -) -> String { - let wire_entities = create_wire_entities(witness); - let derived_witnesses = create_derived_witnesses(inverses); - let shifted_entities = create_shifted_entities(shifted); - let to_be_shifted = create_to_be_shifted(to_be_shifted); - - format!( - " - {wire_entities} - - {derived_witnesses} - - {shifted_entities} - - {to_be_shifted} - - template - class WitnessEntities: public WireEntities, public DerivedWitnessEntities {{ - public: - DEFINE_COMPOUND_GET_ALL(WireEntities, DerivedWitnessEntities) - auto get_wires() {{ return WireEntities::get_all(); }}; - }}; - " - ) -} - -/// Creates container of all witness entities and shifts -fn create_all_entities() -> String { - format!( - " - template - class AllEntities: public PrecomputedEntities, - public WitnessEntities, - public ShiftedEntities {{ - public: - AllEntities() - : PrecomputedEntities{{}} - , WitnessEntities{{}} - , ShiftedEntities{{}} - {{}} - - DEFINE_COMPOUND_GET_ALL(PrecomputedEntities, WitnessEntities, ShiftedEntities) - - auto get_unshifted(){{ - return concatenate(PrecomputedEntities::get_all(), WitnessEntities::get_all()); - }} - auto get_to_be_shifted(){{ return AvmFlavor::get_to_be_shifted(*this); }} - auto get_shifted() {{ return ShiftedEntities::get_all(); }} - auto get_precomputed() {{ return PrecomputedEntities::get_all(); }} - }}; - " - ) -} - -fn create_proving_and_verification_key( - flavor_name: &str, - lookups: &[String], - to_be_shifted: &[String], -) -> String { - let get_to_be_shifted = return_ref_vector("get_to_be_shifted", to_be_shifted); - let compute_logderivative_inverses = - create_compute_logderivative_inverses(flavor_name, lookups); - - format!(" - public: - class ProvingKey : public ProvingKeyAvm_, WitnessEntities, CommitmentKey> {{ - public: - // Expose constructors on the base class - using Base = ProvingKeyAvm_, WitnessEntities, CommitmentKey>; - using Base::Base; - - {get_to_be_shifted} - - {compute_logderivative_inverses} - }}; - - using VerificationKey = VerificationKey_, VerifierCommitmentKey>; - ") -} - -fn create_polynomial_views(first_poly: &String) -> String { - format!(" - - class AllValues : public AllEntities {{ - public: - using Base = AllEntities; - using Base::Base; - }}; - - /** - * @brief A container for the prover polynomials handles. - */ - class ProverPolynomials : public AllEntities {{ - public: - // Define all operations as default, except copy construction/assignment - ProverPolynomials() = default; - ProverPolynomials& operator=(const ProverPolynomials&) = delete; - ProverPolynomials(const ProverPolynomials& o) = delete; - ProverPolynomials(ProverPolynomials&& o) noexcept = default; - ProverPolynomials& operator=(ProverPolynomials&& o) noexcept = default; - ~ProverPolynomials() = default; - - ProverPolynomials(ProvingKey& proving_key) - {{ - for (auto [prover_poly, key_poly] : zip_view(this->get_unshifted(), proving_key.get_all())) {{ - ASSERT(flavor_get_label(*this, prover_poly) == flavor_get_label(proving_key, key_poly)); - prover_poly = key_poly.share(); - }} - for (auto [prover_poly, key_poly] : zip_view(this->get_shifted(), proving_key.get_to_be_shifted())) {{ - ASSERT(flavor_get_label(*this, prover_poly) == (flavor_get_label(proving_key, key_poly) + \"_shift\")); - prover_poly = key_poly.shifted(); - }} - }} - - [[nodiscard]] size_t get_polynomial_size() const {{ return {first_poly}.size(); }} - /** - * @brief Returns the evaluations of all prover polynomials at one point on the boolean hypercube, which - * represents one row in the execution trace. - */ - [[nodiscard]] AllValues get_row(size_t row_idx) const - {{ - AllValues result; - for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {{ - result_field = polynomial[row_idx]; - }} - return result; - }} - }}; - - class PartiallyEvaluatedMultivariates : public AllEntities {{ - public: - PartiallyEvaluatedMultivariates() = default; - PartiallyEvaluatedMultivariates(const size_t circuit_size) - {{ - // Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2) - for (auto& poly : get_all()) {{ - poly = Polynomial(circuit_size / 2); - }} - }} - }}; - - /** - * @brief A container for univariates used during Protogalaxy folding and sumcheck. - * @details During folding and sumcheck, the prover evaluates the relations on these univariates. - */ - template - using ProverUnivariates = AllEntities>; - - /** - * @brief A container for univariates used during Protogalaxy folding and sumcheck with some of the computation - * optimistically ignored - * @details During folding and sumcheck, the prover evaluates the relations on these univariates. - */ - template - using OptimisedProverUnivariates = AllEntities>; - - /** - * @brief A container for univariates produced during the hot loop in sumcheck. - */ - using ExtendedEdges = ProverUnivariates; - - /** - * @brief A container for the witness commitments. - * - */ - using WitnessCommitments = WitnessEntities; - - ") -} - -fn create_flavor_members(entities: &[String]) -> String { - let pointer_list = create_comma_separated(entities); - - format!( - "DEFINE_FLAVOR_MEMBERS(DataType, {pointer_list})", - pointer_list = pointer_list - ) -} - -fn create_labels(all_ents: &[String]) -> String { - let mut labels = String::new(); - for name in all_ents { - labels.push_str(&format!( - "Base::{name} = \"{}\"; - ", - name.to_uppercase() - )); - } - labels -} - -fn create_commitment_labels(all_ents: &[String]) -> String { - let labels = create_labels(all_ents); - - format!( - " - class CommitmentLabels: public AllEntities {{ - private: - using Base = AllEntities; - - public: - CommitmentLabels() : AllEntities() - {{ - {labels} - }}; - }}; - " - ) -} - -/// Create the compute_logderivative_inverses function -/// -/// If we do not have any lookups, we do not need to include this round -fn create_compute_logderivative_inverses(flavor_name: &str, lookups: &[String]) -> String { - if lookups.is_empty() { - return "".to_string(); - } - - let compute_inverse_transformation = |lookup_name: &String| { - format!("AVM_TRACK_TIME(\"compute_logderivative_inverse/{lookup_name}_ms\", (bb::compute_logderivative_inverse<{flavor_name}Flavor, {lookup_name}_relation>(prover_polynomials, relation_parameters, this->circuit_size)));") - }; - - let compute_inverses = map_with_newline(lookups, compute_inverse_transformation); - - format!( - " - void compute_logderivative_inverses(const RelationParameters& relation_parameters) - {{ - ProverPolynomials prover_polynomials = ProverPolynomials(*this); - - {compute_inverses} - }} - " - ) -} - -fn create_key_dereference(fixed: &[String]) -> String { - let deref_transformation = |name: &String| format!("{name} = verification_key->{name};"); - - map_with_newline(fixed, deref_transformation) -} - -fn create_verifier_commitments(fixed: &[String]) -> String { - let key_dereference = create_key_dereference(fixed); - - format!( - " - class VerifierCommitments : public AllEntities {{ - private: - using Base = AllEntities; - - public: - VerifierCommitments(const std::shared_ptr& verification_key) - {{ - {key_dereference} - }} - }}; -" - ) -} - -fn generate_transcript(witness: &[String]) -> String { - // Transformations - let declaration_transform = |c: &_| format!("Commitment {c};"); - let deserialize_transform = |name: &_| { - format!( - "{name} = deserialize_from_buffer(Transcript::proof_data, num_frs_read);", - ) - }; - let serialize_transform = - |name: &_| format!("serialize_to_buffer({name}, Transcript::proof_data);"); - - // Perform Transformations - let declarations = map_with_newline(witness, declaration_transform); - let deserialize_wires = map_with_newline(witness, deserialize_transform); - let serialize_wires = map_with_newline(witness, serialize_transform); - - format!(" - class Transcript : public NativeTranscript {{ - public: - uint32_t circuit_size; - - {declarations} - - std::vector> sumcheck_univariates; - std::array sumcheck_evaluations; - std::vector zm_cq_comms; - Commitment zm_cq_comm; - Commitment zm_pi_comm; - - Transcript() = default; - - Transcript(const std::vector& proof) - : NativeTranscript(proof) - {{}} - - void deserialize_full_transcript() - {{ - size_t num_frs_read = 0; - circuit_size = deserialize_from_buffer(proof_data, num_frs_read); - size_t log_n = numeric::get_msb(circuit_size); - - {deserialize_wires} - - for (size_t i = 0; i < log_n; ++i) {{ - sumcheck_univariates.emplace_back( - deserialize_from_buffer>( - Transcript::proof_data, num_frs_read)); - }} - sumcheck_evaluations = deserialize_from_buffer>( - Transcript::proof_data, num_frs_read); - for (size_t i = 0; i < log_n; ++i) {{ - zm_cq_comms.push_back(deserialize_from_buffer(proof_data, num_frs_read)); - }} - zm_cq_comm = deserialize_from_buffer(proof_data, num_frs_read); - zm_pi_comm = deserialize_from_buffer(proof_data, num_frs_read); - }} - - void serialize_full_transcript() - {{ - size_t old_proof_length = proof_data.size(); - Transcript::proof_data.clear(); - size_t log_n = numeric::get_msb(circuit_size); - - serialize_to_buffer(circuit_size, Transcript::proof_data); - - {serialize_wires} - - for (size_t i = 0; i < log_n; ++i) {{ - serialize_to_buffer(sumcheck_univariates[i], Transcript::proof_data); - }} - serialize_to_buffer(sumcheck_evaluations, Transcript::proof_data); - for (size_t i = 0; i < log_n; ++i) {{ - serialize_to_buffer(zm_cq_comms[i], proof_data); - }} - serialize_to_buffer(zm_cq_comm, proof_data); - serialize_to_buffer(zm_pi_comm, proof_data); - - // sanity check to make sure we generate the same length of proof as before. - ASSERT(proof_data.size() == old_proof_length); - }} - }}; - ") -} diff --git a/bb-pilcom/bb-pil-backend/templates/flavor.hpp.hbs b/bb-pilcom/bb-pil-backend/templates/flavor.hpp.hbs new file mode 100644 index 000000000000..88b8386f8911 --- /dev/null +++ b/bb-pilcom/bb-pil-backend/templates/flavor.hpp.hbs @@ -0,0 +1,349 @@ +#pragma once + +#include "barretenberg/commitment_schemes/kzg/kzg.hpp" +#include "barretenberg/ecc/curves/bn254/g1.hpp" +#include "barretenberg/flavor/relation_definitions.hpp" +#include "barretenberg/polynomials/barycentric.hpp" +#include "barretenberg/polynomials/univariate.hpp" + +#include "barretenberg/flavor/flavor_macros.hpp" +#include "barretenberg/transcript/transcript.hpp" +#include "barretenberg/polynomials/evaluation_domain.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/flavor/flavor.hpp" + +#include "barretenberg/vm/{{snakeCase name}}_trace/stats.hpp" + +// Relations +{{#each relation_file_names as |r|}} +#include "barretenberg/relations/generated/{{snakeCase ../name}}/{{r}}.hpp" +{{/each}} + +// Lookup relations +#include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp" +{{#each lookups as |r|}} +#include "barretenberg/relations/generated/{{snakeCase ../name}}/{{r}}.hpp" +{{/each}} + +namespace bb { + +class {{name}}Flavor { + public: + using Curve = curve::BN254; + using G1 = Curve::Group; + using PCS = KZG; + + using FF = G1::subgroup_field; + using Polynomial = bb::Polynomial; + using PolynomialHandle = std::span; + using GroupElement = G1::element; + using Commitment = G1::affine_element; + using CommitmentHandle = G1::affine_element; + using CommitmentKey = bb::CommitmentKey; + using VerifierCommitmentKey = bb::VerifierCommitmentKey; + using RelationSeparator = FF; + + static constexpr size_t NUM_PRECOMPUTED_ENTITIES = {{len fixed}}; + static constexpr size_t NUM_WITNESS_ENTITIES = {{len witness}}; + static constexpr size_t NUM_WIRES = NUM_WITNESS_ENTITIES + NUM_PRECOMPUTED_ENTITIES; + // We have two copies of the witness entities, so we subtract the number of fixed ones (they have no shift), one for the unshifted and one for the shifted + static constexpr size_t NUM_ALL_ENTITIES = {{len all_cols_and_shifts}}; + + using Relations = std::tuple< + // Relations + {{#each relation_file_names as |item|}}{{#if @index}},{{/if}}{{../name}}_vm::{{item}}{{/each}}, + // Lookups + {{#each lookups as |item|}}{{#if @index}},{{/if}}{{item}}_relation{{/each}} + >; + + static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); + + // BATCHED_RELATION_PARTIAL_LENGTH = algebraic degree of sumcheck relation *after* multiplying by the `pow_zeta` + // random polynomial e.g. For \sum(x) [A(x) * B(x) + C(x)] * PowZeta(X), relation length = 2 and random relation + // length = 3 + static constexpr size_t BATCHED_RELATION_PARTIAL_LENGTH = MAX_PARTIAL_RELATION_LENGTH + 1; + static constexpr size_t NUM_RELATIONS = std::tuple_size_v; + + template + using ProtogalaxyTupleOfTuplesOfUnivariates = + decltype(create_protogalaxy_tuple_of_tuples_of_univariates()); + using SumcheckTupleOfTuplesOfUnivariates = decltype(create_sumcheck_tuple_of_tuples_of_univariates()); + using TupleOfArraysOfValues = decltype(create_tuple_of_arrays_of_values()); + + static constexpr bool has_zero_row = true; + + private: + template + class PrecomputedEntities : public PrecomputedEntitiesBase { + public: + using DataType = DataType_; + + DEFINE_FLAVOR_MEMBERS(DataType, {{#each fixed as |item|}}{{#if @index}},{{/if}}{{item}}{{/each}}) + + RefVector get_selectors() { return { + {{#each fixed as |item|}}{{#if @index}},{{/if}}{{item}}{{/each}} + }; } + RefVector get_sigma_polynomials() { return {}; } + RefVector get_id_polynomials() { return {}; } + RefVector get_table_polynomials() { return {}; } + }; + + template + class WireEntities { + public: + DEFINE_FLAVOR_MEMBERS(DataType, {{#each witness_without_inverses as |item|}}{{#if @index}},{{/if}}{{item}}{{/each}}) + }; + + template + class DerivedWitnessEntities { + public: + DEFINE_FLAVOR_MEMBERS(DataType, {{#each lookups as |item|}}{{#if @index}},{{/if}}{{item}}{{/each}}) + }; + + template + class ShiftedEntities { + public: + DEFINE_FLAVOR_MEMBERS(DataType, {{#each shifted as |item|}}{{#if @index}},{{/if}}{{item}}{{/each}}) + }; + + template + static auto get_to_be_shifted(PrecomputedAndWitnessEntitiesSuperset& entities) { + return RefArray{ + {{#each to_be_shifted as |item|}}{{#if @index}},{{/if}}entities.{{item}}{{/each}} + }; + } + + template + class WitnessEntities: public WireEntities, public DerivedWitnessEntities { + public: + DEFINE_COMPOUND_GET_ALL(WireEntities, DerivedWitnessEntities) + auto get_wires() { return WireEntities::get_all(); }; + }; + + template + class AllEntities : public PrecomputedEntities + , public WitnessEntities + , public ShiftedEntities { + public: + AllEntities() + : PrecomputedEntities{} + , WitnessEntities{} + , ShiftedEntities{} + {} + + DEFINE_COMPOUND_GET_ALL(PrecomputedEntities, WitnessEntities, ShiftedEntities) + + auto get_unshifted() { + return concatenate(PrecomputedEntities::get_all(), WitnessEntities::get_all()); + } + auto get_to_be_shifted() { return {{name}}Flavor::get_to_be_shifted(*this); } + auto get_shifted() { return ShiftedEntities::get_all(); } + auto get_precomputed() { return PrecomputedEntities::get_all(); } + }; + + public: + class ProvingKey : public ProvingKey{{name}}_, WitnessEntities, CommitmentKey> { + public: + // Expose constructors on the base class + using Base = ProvingKey{{name}}_, WitnessEntities, CommitmentKey>; + using Base::Base; + + RefVector get_to_be_shifted() { return { + {{#each to_be_shifted as |item|}}{{#if @index}},{{/if}}{{item}}{{/each}} + }; } + + void compute_logderivative_inverses(const RelationParameters& relation_parameters) + { + ProverPolynomials prover_polynomials = ProverPolynomials(*this); + + {{#each lookups as |item|}} + AVM_TRACK_TIME("compute_logderivative_inverse/{{item}}_ms", (bb::compute_logderivative_inverse<{{../name}}Flavor, {{item}}_relation>(prover_polynomials, relation_parameters, this->circuit_size))); + {{/each}} + } + }; + + using VerificationKey = VerificationKey_, VerifierCommitmentKey>; + + class AllValues : public AllEntities { + public: + using Base = AllEntities; + using Base::Base; + }; + + /** + * @brief A container for the prover polynomials handles. + */ + class ProverPolynomials : public AllEntities { + public: + // Define all operations as default, except copy construction/assignment + ProverPolynomials() = default; + ProverPolynomials& operator=(const ProverPolynomials&) = delete; + ProverPolynomials(const ProverPolynomials& o) = delete; + ProverPolynomials(ProverPolynomials&& o) noexcept = default; + ProverPolynomials& operator=(ProverPolynomials&& o) noexcept = default; + ~ProverPolynomials() = default; + + ProverPolynomials(ProvingKey& proving_key) + { + for (auto [prover_poly, key_poly] : zip_view(this->get_unshifted(), proving_key.get_all())) { + ASSERT(flavor_get_label(*this, prover_poly) == flavor_get_label(proving_key, key_poly)); + prover_poly = key_poly.share(); + } + for (auto [prover_poly, key_poly] : zip_view(this->get_shifted(), proving_key.get_to_be_shifted())) { + ASSERT(flavor_get_label(*this, prover_poly) == (flavor_get_label(proving_key, key_poly) + "_shift")); + prover_poly = key_poly.shifted(); + } + } + + [[nodiscard]] size_t get_polynomial_size() const { return {{witness.0}}.size(); } + /** + * @brief Returns the evaluations of all prover polynomials at one point on the boolean hypercube, which + * represents one row in the execution trace. + */ + [[nodiscard]] AllValues get_row(size_t row_idx) const + { + AllValues result; + for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) { + result_field = polynomial[row_idx]; + } + return result; + } + }; + + class PartiallyEvaluatedMultivariates : public AllEntities { + public: + PartiallyEvaluatedMultivariates() = default; + PartiallyEvaluatedMultivariates(const size_t circuit_size) + { + // Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2) + for (auto& poly : get_all()) { + poly = Polynomial(circuit_size / 2); + } + } + }; + + /** + * @brief A container for univariates used during Protogalaxy folding and sumcheck. + * @details During folding and sumcheck, the prover evaluates the relations on these univariates. + */ + template + using ProverUnivariates = AllEntities>; + + /** + * @brief A container for univariates used during Protogalaxy folding and sumcheck with some of the computation + * optimistically ignored + * @details During folding and sumcheck, the prover evaluates the relations on these univariates. + */ + template + using OptimisedProverUnivariates = AllEntities>; + + /** + * @brief A container for univariates produced during the hot loop in sumcheck. + */ + using ExtendedEdges = ProverUnivariates; + + /** + * @brief A container for the witness commitments. + * + */ + using WitnessCommitments = WitnessEntities; + + class CommitmentLabels: public AllEntities { + private: + using Base = AllEntities; + + public: + CommitmentLabels() : AllEntities() + { + {{#each all_cols as |item|}} + Base::{{item}} = "{{shoutySnakeCase item}}"; + {{/each}} + }; + }; + + class VerifierCommitments : public AllEntities { + private: + using Base = AllEntities; + + public: + VerifierCommitments(const std::shared_ptr& verification_key) + { + {{#each fixed as |item|}} + {{item}} = verification_key->{{item}}; + {{/each}} + } + }; + + class Transcript : public NativeTranscript { + public: + uint32_t circuit_size; + + {{#each witness as |w|}} + Commitment {{w}}; + {{/each}} + + std::vector> sumcheck_univariates; + std::array sumcheck_evaluations; + std::vector zm_cq_comms; + Commitment zm_cq_comm; + Commitment zm_pi_comm; + + Transcript() = default; + + Transcript(const std::vector& proof) + : NativeTranscript(proof) + {} + + void deserialize_full_transcript() + { + size_t num_frs_read = 0; + circuit_size = deserialize_from_buffer(proof_data, num_frs_read); + size_t log_n = numeric::get_msb(circuit_size); + + {{#each witness as |w|}} + {{w}} = deserialize_from_buffer(Transcript::proof_data, num_frs_read); + {{/each}} + + for (size_t i = 0; i < log_n; ++i) { + sumcheck_univariates.emplace_back( + deserialize_from_buffer>( + Transcript::proof_data, num_frs_read)); + } + sumcheck_evaluations = deserialize_from_buffer>( + Transcript::proof_data, num_frs_read); + for (size_t i = 0; i < log_n; ++i) { + zm_cq_comms.push_back(deserialize_from_buffer(proof_data, num_frs_read)); + } + zm_cq_comm = deserialize_from_buffer(proof_data, num_frs_read); + zm_pi_comm = deserialize_from_buffer(proof_data, num_frs_read); + } + + void serialize_full_transcript() + { + size_t old_proof_length = proof_data.size(); + Transcript::proof_data.clear(); + size_t log_n = numeric::get_msb(circuit_size); + + serialize_to_buffer(circuit_size, Transcript::proof_data); + + {{#each witness as |w|}} + serialize_to_buffer({{w}}, Transcript::proof_data); + {{/each}} + + for (size_t i = 0; i < log_n; ++i) { + serialize_to_buffer(sumcheck_univariates[i], Transcript::proof_data); + } + serialize_to_buffer(sumcheck_evaluations, Transcript::proof_data); + for (size_t i = 0; i < log_n; ++i) { + serialize_to_buffer(zm_cq_comms[i], proof_data); + } + serialize_to_buffer(zm_cq_comm, proof_data); + serialize_to_buffer(zm_pi_comm, proof_data); + + // sanity check to make sure we generate the same length of proof as before. + ASSERT(proof_data.size() == old_proof_length); + } + }; +}; + +} // namespace bb \ No newline at end of file