Skip to content

Commit

Permalink
refactor: ProvingKey has ProverPolynomials (#5940)
Browse files Browse the repository at this point in the history
Updates all proving systems (except AVM) to have a single storage
container for polynomials: a `ProverPolynomials` object belonging to the
`ProvingKey`.

The old model: Two containers for polynomials: `ProvingKey` (via
inheritance from `PrecomputedEntities` and `WitnessEntities`) and a
`ProverPolynomials` owned by the `ProverInstance`. Both of these
containers store genuine Polynomials, but the memory is "shared" so
neither was any more the owner of the polys than the other. This led to
lots of boilerplate for sharing polys between the two containers and
confusion about what needed to be shared when. There were also several
places where we would construct a temporary ProverPolynomials for use in
one round (e.g. grand product computation which needs shifts) only to
destroy it immediately afterwards.

Note: it used to make sense to have both ProvingKey and
ProverPolynomials when ProvingKey was not a Flavor-style object. I.e. it
used to be a collection of different vectors of polys, so we needed to
construct ProverPolynomials as the nice container for sumcheck to
operate on. But as soon as ProvingKey became a thing inheriting from
Precomputed Entities and WitnessEntities, this no longer made sense.

Closes AztecProtocol/barretenberg#962

Branch:
`ClientIVCBench/Full/6      21403 ms        16242 ms            1`

Master
`ClientIVCBench/Full/6      21585 ms        16480 ms            1`
  • Loading branch information
ledwards2225 authored Apr 26, 2024
1 parent e615a83 commit 0a64279
Show file tree
Hide file tree
Showing 42 changed files with 668 additions and 1,015 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ BB_PROFILE static void test_round_inner(State& state, GoblinUltraProver& prover,
// we need to get the relation_parameters and prover_polynomials from the oink_prover
prover.instance->proving_key = std::move(oink_prover.proving_key);
prover.instance->relation_parameters = oink_prover.relation_parameters;
prover.instance->prover_polynomials = GoblinUltraFlavor::ProverPolynomials(prover.instance->proving_key);
time_if_index(RELATION_CHECK, [&] { prover.execute_relation_check_rounds(); });
time_if_index(ZEROMORPH, [&] { prover.execute_zeromorph_rounds(); });
}
Expand Down
201 changes: 86 additions & 115 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,32 +234,32 @@ class ECCVMFlavor {
static auto get_to_be_shifted(PrecomputedAndWitnessEntitiesSuperset& entities)
{
// NOTE: must match order of ShiftedEntities above!
return RefArray{ entities.transcript_mul,
entities.transcript_msm_count,
entities.transcript_accumulator_x,
entities.transcript_accumulator_y,
entities.precompute_scalar_sum,
entities.precompute_s1hi,
entities.precompute_dx,
entities.precompute_dy,
entities.precompute_tx,
entities.precompute_ty,
entities.msm_transition,
entities.msm_add,
entities.msm_double,
entities.msm_skew,
entities.msm_accumulator_x,
entities.msm_accumulator_y,
entities.msm_count,
entities.msm_round,
entities.msm_add1,
entities.msm_pc,
entities.precompute_pc,
entities.transcript_pc,
entities.precompute_round,
entities.transcript_accumulator_empty,
entities.precompute_select,
entities.z_perm };
return RefArray{ entities.transcript_mul, // column 0
entities.transcript_msm_count, // column 1
entities.transcript_accumulator_x, // column 2
entities.transcript_accumulator_y, // column 3
entities.precompute_scalar_sum, // column 4
entities.precompute_s1hi, // column 5
entities.precompute_dx, // column 6
entities.precompute_dy, // column 7
entities.precompute_tx, // column 8
entities.precompute_ty, // column 9
entities.msm_transition, // column 10
entities.msm_add, // column 11
entities.msm_double, // column 12
entities.msm_skew, // column 13
entities.msm_accumulator_x, // column 14
entities.msm_accumulator_y, // column 15
entities.msm_count, // column 16
entities.msm_round, // column 17
entities.msm_add1, // column 18
entities.msm_pc, // column 19
entities.precompute_pc, // column 20
entities.transcript_pc, // column 21
entities.precompute_round, // column 22
entities.transcript_accumulator_empty, // column 23
entities.precompute_select, // column 24
entities.z_perm }; // column 25
}
/**
* @brief A base class labelling all entities (for instance, all of the polynomials used by the prover during
Expand Down Expand Up @@ -293,72 +293,10 @@ class ECCVMFlavor {

auto get_to_be_shifted() { return ECCVMFlavor::get_to_be_shifted<DataType>(*this); }
auto get_shifted() { return ShiftedEntities<DataType>::get_all(); };
auto get_precomputed() { return PrecomputedEntities<DataType>::get_all(); };
};

public:
/**
* @brief The proving key is responsible for storing the polynomials used by the prover.
* @note TODO(Cody): Maybe multiple inheritance is the right thing here. In that case, nothing should eve
* inherit from ProvingKey.
*/
class ProvingKey : public ProvingKey_<PrecomputedEntities<Polynomial>, WitnessEntities<Polynomial>, CommitmentKey> {
public:
// Expose constructors on the base class
using Base = ProvingKey_<PrecomputedEntities<Polynomial>, WitnessEntities<Polynomial>, CommitmentKey>;
using Base::Base;

ProvingKey(const CircuitBuilder& builder)
: ProvingKey_<PrecomputedEntities<Polynomial>, WitnessEntities<Polynomial>, CommitmentKey>(
builder.get_circuit_subgroup_size(builder.get_num_gates()), 0)
{
const auto [_lagrange_first, _lagrange_last] =
compute_first_and_last_lagrange_polynomials<FF>(circuit_size);
lagrange_first = _lagrange_first;
lagrange_last = _lagrange_last;
{
Polynomial _lagrange_second(circuit_size);
_lagrange_second[1] = 1;
lagrange_second = _lagrange_second.share();
}
}

auto get_to_be_shifted() { return ECCVMFlavor::get_to_be_shifted<Polynomial>(*this); }
// The plookup wires that store plookup read data.
RefArray<Polynomial, 0> get_table_column_wires() { return {}; };
};

/**
* @brief The verification key is responsible for storing the the commitments to the precomputed (non-witnessk)
* polynomials used by the verifier.
*
* @note Note the discrepancy with what sort of data is stored here vs in the proving key. We may want to
* resolve that, and split out separate PrecomputedPolynomials/Commitments data for clarity but also for
* portability of our circuits.
*/
class VerificationKey : public VerificationKey_<PrecomputedEntities<Commitment>, VerifierCommitmentKey> {
public:
std::vector<FF> public_inputs;

VerificationKey(const size_t circuit_size, const size_t num_public_inputs)
: VerificationKey_(circuit_size, num_public_inputs)
{}

VerificationKey(const std::shared_ptr<ProvingKey>& proving_key)
: public_inputs(proving_key->public_inputs)
{
this->pcs_verification_key = std::make_shared<VerifierCommitmentKey>(proving_key->circuit_size);
this->circuit_size = proving_key->circuit_size;
this->log_circuit_size = numeric::get_msb(this->circuit_size);
this->num_public_inputs = proving_key->num_public_inputs;
this->pub_inputs_offset = proving_key->pub_inputs_offset;

for (auto [polynomial, commitment] :
zip_view(proving_key->get_precomputed_polynomials(), this->get_all())) {
commitment = proving_key->commitment_key->commit(polynomial);
}
}
};

/**
* @brief A container for polynomials produced after the first round of sumcheck.
* @todo TODO(#394) Use polynomial classes for guaranteed memory alignment.
Expand Down Expand Up @@ -432,6 +370,13 @@ class ECCVMFlavor {
}
return result;
}
// Set all shifted polynomials based on their to-be-shifted counterpart
void set_shifted()
{
for (auto [shifted, to_be_shifted] : zip_view(get_shifted(), get_to_be_shifted())) {
shifted = to_be_shifted.shifted();
}
}

/**
* @brief Compute the ECCVM flavor polynomial data required to generate an ECCVM Proof
Expand Down Expand Up @@ -513,7 +458,7 @@ class ECCVMFlavor {
table (reads come from msm_x/y3, msm_x/y4)
* @return ProverPolynomials
*/
ProverPolynomials(CircuitBuilder& builder)
ProverPolynomials(const CircuitBuilder& builder)
{
const auto msms = builder.get_msms();
const auto flattened_muls = builder.get_flattened_scalar_muls(msms);
Expand Down Expand Up @@ -652,31 +597,57 @@ class ECCVMFlavor {
msm_slice4[i] = msm_state[i].add_state[3].slice;
}
});
transcript_mul_shift = transcript_mul.shifted();
transcript_msm_count_shift = transcript_msm_count.shifted();
transcript_accumulator_x_shift = transcript_accumulator_x.shifted();
transcript_accumulator_y_shift = transcript_accumulator_y.shifted();
precompute_scalar_sum_shift = precompute_scalar_sum.shifted();
precompute_s1hi_shift = precompute_s1hi.shifted();
precompute_dx_shift = precompute_dx.shifted();
precompute_dy_shift = precompute_dy.shifted();
precompute_tx_shift = precompute_tx.shifted();
precompute_ty_shift = precompute_ty.shifted();
msm_transition_shift = msm_transition.shifted();
msm_add_shift = msm_add.shifted();
msm_double_shift = msm_double.shifted();
msm_skew_shift = msm_skew.shifted();
msm_accumulator_x_shift = msm_accumulator_x.shifted();
msm_accumulator_y_shift = msm_accumulator_y.shifted();
msm_count_shift = msm_count.shifted();
msm_round_shift = msm_round.shifted();
msm_add1_shift = msm_add1.shifted();
msm_pc_shift = msm_pc.shifted();
precompute_pc_shift = precompute_pc.shifted();
transcript_pc_shift = transcript_pc.shifted();
precompute_round_shift = precompute_round.shifted();
transcript_accumulator_empty_shift = transcript_accumulator_empty.shifted();
precompute_select_shift = precompute_select.shifted();
this->set_shifted();
}
};

/**
* @brief The proving key is responsible for storing the polynomials used by the prover.
*
*/
class ProvingKey : public ProvingKey_<FF, CommitmentKey> {
public:
// Expose constructors on the base class
using Base = ProvingKey_<FF, CommitmentKey>;
using Base::Base;

ProverPolynomials polynomials; // storage for all polynomials evaluated by the prover

ProvingKey(const CircuitBuilder& builder)
: Base(builder.get_circuit_subgroup_size(builder.get_num_gates()), 0)
, polynomials(builder)
{}
};

/**
* @brief The verification key is responsible for storing the the commitments to the precomputed (non-witnessk)
* polynomials used by the verifier.
*
* @note Note the discrepancy with what sort of data is stored here vs in the proving key. We may want to
* resolve that, and split out separate PrecomputedPolynomials/Commitments data for clarity but also for
* portability of our circuits.
*/
class VerificationKey : public VerificationKey_<PrecomputedEntities<Commitment>, VerifierCommitmentKey> {
public:
std::vector<FF> public_inputs;

VerificationKey(const size_t circuit_size, const size_t num_public_inputs)
: VerificationKey_(circuit_size, num_public_inputs)
{}

VerificationKey(const std::shared_ptr<ProvingKey>& proving_key)
: public_inputs(proving_key->public_inputs)
{
this->pcs_verification_key = std::make_shared<VerifierCommitmentKey>(proving_key->circuit_size);
this->circuit_size = proving_key->circuit_size;
this->log_circuit_size = numeric::get_msb(this->circuit_size);
this->num_public_inputs = proving_key->num_public_inputs;
this->pub_inputs_offset = proving_key->pub_inputs_offset;

for (auto [polynomial, commitment] :
zip_view(proving_key->polynomials.get_precomputed(), this->get_all())) {
commitment = proving_key->commitment_key->commit(polynomial);
}
}
};

Expand Down
40 changes: 17 additions & 23 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_prover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace bb {

ECCVMProver::ECCVMProver(CircuitBuilder& builder, const std::shared_ptr<Transcript>& transcript)
: transcript(transcript)
, prover_polynomials(builder)
{
BB_OP_COUNT_TIME_NAME("ECCVMProver(CircuitBuilder&)");

Expand All @@ -24,13 +23,6 @@ ECCVMProver::ECCVMProver(CircuitBuilder& builder, const std::shared_ptr<Transcri
// Construct the proving key; populates all polynomials except for witness polys
key = std::make_shared<ProvingKey>(builder);

// Share all unshifted polys from the prover polynomials to the proving key. Note: this means that updating a
// polynomial in one container automatically updates it in the other via the shared memory.
for (auto [prover_poly, key_poly] : zip_view(prover_polynomials.get_unshifted(), key->get_all())) {
ASSERT(flavor_get_label(prover_polynomials, prover_poly) == flavor_get_label(*key, key_poly));
key_poly = prover_poly.share();
}

commitment_key = std::make_shared<CommitmentKey>(key->circuit_size);
}

Expand All @@ -51,7 +43,7 @@ void ECCVMProver::execute_preamble_round()
*/
void ECCVMProver::execute_wire_commitments_round()
{
auto wire_polys = key->get_wires();
auto wire_polys = key->polynomials.get_wires();
auto labels = commitment_labels.get_wires();
for (size_t idx = 0; idx < wire_polys.size(); ++idx) {
transcript->send_to_verifier(labels[idx], commitment_key->commit(wire_polys[idx]));
Expand All @@ -78,8 +70,9 @@ void ECCVMProver::execute_log_derivative_commitments_round()
relation_parameters.eccvm_set_permutation_delta = relation_parameters.eccvm_set_permutation_delta.invert();
// Compute inverse polynomial for our logarithmic-derivative lookup method
compute_logderivative_inverse<Flavor, typename Flavor::LookupRelation>(
prover_polynomials, relation_parameters, key->circuit_size);
transcript->send_to_verifier(commitment_labels.lookup_inverses, commitment_key->commit(key->lookup_inverses));
key->polynomials, relation_parameters, key->circuit_size);
transcript->send_to_verifier(commitment_labels.lookup_inverses,
commitment_key->commit(key->polynomials.lookup_inverses));
}

/**
Expand All @@ -89,9 +82,9 @@ void ECCVMProver::execute_log_derivative_commitments_round()
void ECCVMProver::execute_grand_product_computation_round()
{
// Compute permutation grand product and their commitments
compute_permutation_grand_products<Flavor>(key, prover_polynomials, relation_parameters);
compute_grand_products<Flavor>(key->polynomials, relation_parameters);

transcript->send_to_verifier(commitment_labels.z_perm, commitment_key->commit(key->z_perm));
transcript->send_to_verifier(commitment_labels.z_perm, commitment_key->commit(key->polynomials.z_perm));
}

/**
Expand All @@ -108,7 +101,7 @@ void ECCVMProver::execute_relation_check_rounds()
for (size_t idx = 0; idx < gate_challenges.size(); idx++) {
gate_challenges[idx] = transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
}
sumcheck_output = sumcheck.prove(prover_polynomials, relation_parameters, alpha, gate_challenges);
sumcheck_output = sumcheck.prove(key->polynomials, relation_parameters, alpha, gate_challenges);
}

/**
Expand All @@ -118,8 +111,8 @@ void ECCVMProver::execute_relation_check_rounds()
* */
void ECCVMProver::execute_zeromorph_rounds()
{
ZeroMorph::prove(prover_polynomials.get_unshifted(),
prover_polynomials.get_to_be_shifted(),
ZeroMorph::prove(key->polynomials.get_unshifted(),
key->polynomials.get_to_be_shifted(),
sumcheck_output.claimed_evaluations.get_unshifted(),
sumcheck_output.claimed_evaluations.get_shifted(),
sumcheck_output.challenge,
Expand All @@ -146,11 +139,11 @@ void ECCVMProver::execute_transcript_consistency_univariate_opening_round()
// Get the challenge at which we evaluate the polynomials as univariates
evaluation_challenge_x = transcript->template get_challenge<FF>("Translation:evaluation_challenge_x");

translation_evaluations.op = key->transcript_op.evaluate(evaluation_challenge_x);
translation_evaluations.Px = key->transcript_Px.evaluate(evaluation_challenge_x);
translation_evaluations.Py = key->transcript_Py.evaluate(evaluation_challenge_x);
translation_evaluations.z1 = key->transcript_z1.evaluate(evaluation_challenge_x);
translation_evaluations.z2 = key->transcript_z2.evaluate(evaluation_challenge_x);
translation_evaluations.op = key->polynomials.transcript_op.evaluate(evaluation_challenge_x);
translation_evaluations.Px = key->polynomials.transcript_Px.evaluate(evaluation_challenge_x);
translation_evaluations.Py = key->polynomials.transcript_Py.evaluate(evaluation_challenge_x);
translation_evaluations.z1 = key->polynomials.transcript_z1.evaluate(evaluation_challenge_x);
translation_evaluations.z2 = key->polynomials.transcript_z2.evaluate(evaluation_challenge_x);

// Add the univariate evaluations to the transcript
transcript->send_to_verifier("Translation:op", translation_evaluations.op);
Expand All @@ -164,8 +157,9 @@ void ECCVMProver::execute_transcript_consistency_univariate_opening_round()
FF ipa_batching_challenge = transcript->template get_challenge<FF>("Translation:ipa_batching_challenge");

// Collect the polynomials and evaluations to be batched
RefArray univariate_polynomials{ key->transcript_op, key->transcript_Px, key->transcript_Py,
key->transcript_z1, key->transcript_z2, hack };
RefArray univariate_polynomials{ key->polynomials.transcript_op, key->polynomials.transcript_Px,
key->polynomials.transcript_Py, key->polynomials.transcript_z1,
key->polynomials.transcript_z2, hack };
std::array<FF, univariate_polynomials.size()> univariate_evaluations;
for (auto [eval, polynomial] : zip_view(univariate_evaluations, univariate_polynomials)) {
eval = polynomial.evaluate(evaluation_challenge_x);
Expand Down
5 changes: 1 addition & 4 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_prover.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "barretenberg/eccvm/eccvm_flavor.hpp"
#include "barretenberg/goblin/translation_evaluations.hpp"
#include "barretenberg/honk/proof_system/types/proof.hpp"
#include "barretenberg/plonk_honk_shared/library/grand_product_library.hpp"
#include "barretenberg/relations/relation_parameters.hpp"
#include "barretenberg/sumcheck/sumcheck_output.hpp"
#include "barretenberg/transcript/transcript.hpp"
Expand All @@ -18,7 +19,6 @@ class ECCVMProver {
using CommitmentKey = typename Flavor::CommitmentKey;
using ProvingKey = typename Flavor::ProvingKey;
using Polynomial = typename Flavor::Polynomial;
using ProverPolynomials = typename Flavor::ProverPolynomials;
using CommitmentLabels = typename Flavor::CommitmentLabels;
using Transcript = typename Flavor::Transcript;
using TranslationEvaluations = bb::TranslationEvaluations;
Expand Down Expand Up @@ -50,9 +50,6 @@ class ECCVMProver {

std::shared_ptr<ProvingKey> key;

// Container for spans of all polynomials required by the prover (i.e. all multivariates evaluated by Sumcheck).
ProverPolynomials prover_polynomials;

CommitmentLabels commitment_labels;

// Container for d + 1 Fold polynomials produced by Gemini
Expand Down
Loading

0 comments on commit 0a64279

Please sign in to comment.