Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: backpropagate constants in ACIR during optimization #3926

Merged
merged 24 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
49aef5c
feat: backpropagate constants in ACIR during optimization
TomAFrench Jan 2, 2024
7ec72eb
fix: skip optimizations on unconstrained programs
TomAFrench Jan 8, 2024
4ecf580
fix: don't panic on unsolvable opcodes
TomAFrench Jan 2, 2024
85f4b2b
chore: fix test
TomAFrench Jan 2, 2024
8b9b181
chore: run two rounds of backpropagation
TomAFrench Jan 8, 2024
b61697f
chore: remove `get_outputs_vec` for directives
TomAFrench Feb 9, 2024
c225495
chore: remove decomposition of circuit
TomAFrench Feb 9, 2024
833ca8e
chore: improve documentation
TomAFrench Feb 9, 2024
4e623b0
chore: cspell
TomAFrench Feb 9, 2024
6bbb9c9
chore: docs
TomAFrench Feb 9, 2024
d01d60d
Merge branch 'master' into tf/full-backprop
TomAFrench Feb 9, 2024
111ac0f
Merge branch 'master' into tf/full-backprop
kevaundray Feb 12, 2024
5abfa61
Merge branch 'master' into tf/full-backprop
TomAFrench Feb 13, 2024
ed6d941
Update acvm-repo/acvm/src/compiler/optimizers/general.rs
TomAFrench Feb 13, 2024
c2dc0ea
Update acvm-repo/acvm/src/compiler/optimizers/constant_backpropagatio…
TomAFrench Feb 13, 2024
505aac9
chore: update based on review comments
TomAFrench Feb 14, 2024
ce3675d
chore: add comment
TomAFrench Feb 14, 2024
58b37a6
fix: preserve opcodes for which outputs are required witnesses
TomAFrench Feb 15, 2024
333cfbf
chore: fix clippy warning
TomAFrench Feb 20, 2024
89223c2
chore: cargo fmt
TomAFrench Feb 20, 2024
d56c9f9
Merge branch 'master' into tf/full-backprop
TomAFrench Feb 20, 2024
d9c244f
Merge branch 'master' into tf/full-backprop
TomAFrench Feb 26, 2024
74b870c
fix: include all inputs in inputs vector for `Sha256Compression`
TomAFrench Mar 2, 2024
25258e3
Merge branch 'master' into tf/full-backprop
TomAFrench Mar 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions acvm-repo/acvm/src/compiler/optimizers/constant_backpropagation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
use std::collections::{BTreeMap, BTreeSet, HashMap};

use crate::{
compiler::optimizers::GeneralOptimizer,
pwg::{
arithmetic::ExpressionSolver, directives::solve_directives, BrilligSolver,
BrilligSolverStatus,
},
};
use acir::{
circuit::{
brillig::{Brillig, BrilligInputs},
directives::Directive,
opcodes::BlackBoxFuncCall,
Circuit, Opcode,
},
native_types::{Expression, Witness, WitnessMap},
};
use acvm_blackbox_solver::StubbedBlackBoxSolver;

/// `ConstantBackpropagationOptimizer` will attempt to determine any constant witnesses within the program.
/// It does this by attempting to solve the program without any inputs (i.e. using an empty witness map),
/// any values which it can determine are then enforced to be constant values.
///
/// The optimizer will then replace any witnesses wherever they appear within the circuit with these constant values.
/// This is repeated until the circuit stabilizes.
pub(crate) struct ConstantBackpropagationOptimizer {
circuit: Circuit,
}

impl ConstantBackpropagationOptimizer {
/// Creates a new `ConstantBackpropagationOptimizer`
pub(crate) fn new(circuit: Circuit) -> Self {
Self { circuit }
}

fn gather_known_witnesses(&self) -> WitnessMap {
// We do not want to affect the circuit's interface so avoid optimizing away these witnesses.
let mut required_witnesses: BTreeSet<Witness> = self
.circuit
.private_parameters
.union(&self.circuit.public_parameters.0)
.chain(&self.circuit.return_values.0)
.copied()
.collect();

for opcode in &self.circuit.opcodes {
match &opcode {
Opcode::BlackBoxFuncCall(func_call) => {
required_witnesses.extend(
func_call.get_inputs_vec().into_iter().map(|func_input| func_input.witness),
);
required_witnesses.extend(func_call.get_outputs_vec());
}

Opcode::MemoryInit { init, .. } => {
required_witnesses.extend(init);
}

Opcode::MemoryOp { op, .. } => {
required_witnesses.insert(op.index.to_witness().unwrap());
required_witnesses.insert(op.value.to_witness().unwrap());
}

_ => (),
};
}

let mut known_witnesses = WitnessMap::new();
for opcode in self.circuit.opcodes.iter().rev() {
if let Opcode::AssertZero(expr) = opcode {
let _ = ExpressionSolver::solve(&mut known_witnesses, expr);
}
}

// We want to retain any references to required witnesses so we "forget" these assignments.
let known_witnesses: BTreeMap<_, _> = known_witnesses
.into_iter()
.filter(|(witness, _)| !required_witnesses.contains(witness))
.collect();

known_witnesses.into()
}

/// Returns a `Circuit` where with any constant witnesses replaced with the constant they resolve to.
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn backpropagate_constants(
circuit: Circuit,
order_list: Vec<usize>,
) -> (Circuit, Vec<usize>) {
let old_circuit_size = circuit.opcodes.len();

let optimizer = Self::new(circuit);
let (circuit, order_list) = optimizer.backpropagate_constants_iteration(order_list);

let new_circuit_size = circuit.opcodes.len();
if new_circuit_size < old_circuit_size {
Self::backpropagate_constants(circuit, order_list)
} else {
(circuit, order_list)
}
}

/// Applies a single round of constant backpropagation to a `Circuit`.
pub(crate) fn backpropagate_constants_iteration(
mut self,
order_list: Vec<usize>,
) -> (Circuit, Vec<usize>) {
let mut known_witnesses = self.gather_known_witnesses();

let opcodes = std::mem::take(&mut self.circuit.opcodes);

fn remap_expression(known_witnesses: &WitnessMap, expression: Expression) -> Expression {
GeneralOptimizer::optimize(ExpressionSolver::evaluate(&expression, known_witnesses))
}

let mut new_order_list = Vec::with_capacity(order_list.len());
let mut new_opcodes = Vec::with_capacity(opcodes.len());
for (idx, opcode) in opcodes.into_iter().enumerate() {
let new_opcode = match opcode {
Opcode::AssertZero(expression) => {
let new_expr = remap_expression(&known_witnesses, expression);
if new_expr.is_zero() {
continue;
}

// Attempt to solve the opcode to see if we can determine the value of any witnesses in the expression.
// We only do this _after_ we apply any simplifications to create the new opcode as we want to
// keep the constraint on the witness which we are solving for here.
let solve_result = ExpressionSolver::solve(&mut known_witnesses, &new_expr);
// It doesn't matter what the result is. We expect most opcodes to not be solved successfully so we discard errors.
// At the same time, if the expression can be solved then we track this by the updates to `known_witnesses`
drop(solve_result);

Opcode::AssertZero(new_expr)
}
Opcode::Brillig(brillig) => {
let remapped_inputs = brillig
.inputs
.into_iter()
.map(|input| match input {
BrilligInputs::Single(expr) => {
BrilligInputs::Single(remap_expression(&known_witnesses, expr))
}
BrilligInputs::Array(expr_array) => {
let new_input: Vec<_> = expr_array
.into_iter()
.map(|expr| remap_expression(&known_witnesses, expr))
.collect();

BrilligInputs::Array(new_input)
}
input @ BrilligInputs::MemoryArray(_) => input,
})
.collect();

let remapped_predicate = brillig
.predicate
.map(|predicate| remap_expression(&known_witnesses, predicate));

let new_brillig = Brillig {
inputs: remapped_inputs,
predicate: remapped_predicate,
..brillig
};

if let Ok(mut solver) = BrilligSolver::new(
&known_witnesses,
&HashMap::new(),
&new_brillig,
&StubbedBlackBoxSolver,
idx,
) {
match solver.solve() {
Ok(BrilligSolverStatus::Finished) => {
// Write execution outputs
match solver.finalize(&mut known_witnesses, &new_brillig) {
Ok(()) => {
// If we've managed to execute the brillig opcode at compile time, we can now just write in the
// results as constants for the rest of the circuit.
//
// TODO: what if these are required to be witnesses e.g. inputs to black box functions + array get/set?
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
continue
},
_ => Opcode::Brillig(new_brillig),
}
}
_ => Opcode::Brillig(new_brillig),
}
} else {
Opcode::Brillig(new_brillig)
}
}

Opcode::Directive(Directive::ToLeRadix { a, b, radix }) => {
if b.iter().all(|output| known_witnesses.contains_key(output)) {
continue;
} else {
let directive = Directive::ToLeRadix { a, b, radix };
let result = solve_directives(&mut known_witnesses, &directive);
let Directive::ToLeRadix { a, b, radix } = directive;
match result {
Ok(()) => continue,
Err(_) => Opcode::Directive(Directive::ToLeRadix {
a: remap_expression(&known_witnesses, a),
b,
radix,
}),
}
}
}

Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
match known_witnesses.get(&input.witness) {
Some(known_value) if known_value.num_bits() <= input.num_bits => {
continue;
}
_ => opcode,
}
}

Opcode::BlackBoxFuncCall(_)
| Opcode::MemoryOp { .. }
| Opcode::MemoryInit { .. } => opcode,
};

new_opcodes.push(new_opcode);
new_order_list.push(order_list[idx]);
}

self.circuit.opcodes = new_opcodes;

(self.circuit, new_order_list)
}
}
20 changes: 19 additions & 1 deletion acvm-repo/acvm/src/compiler/optimizers/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ impl GeneralOptimizer {
pub(crate) fn optimize(opcode: Expression) -> Expression {
// XXX: Perhaps this optimization can be done on the fly
let opcode = remove_zero_coefficients(opcode);
simplify_mul_terms(opcode)
let opcode = simplify_mul_terms(opcode);
simplify_linear_terms(opcode)
}
}

Expand Down Expand Up @@ -42,3 +43,20 @@ fn simplify_mul_terms(mut gate: Expression) -> Expression {
gate.mul_terms = hash_map.into_iter().map(|((w_l, w_r), scale)| (scale, w_l, w_r)).collect();
gate
}

// Simplifies all linear terms with the same variables
fn simplify_linear_terms(mut gate: Expression) -> Expression {
let mut hash_map: IndexMap<Witness, FieldElement> = IndexMap::new();

// Canonicalize the ordering of the terms, lets just order by variable name
for (scale, witness) in gate.linear_combinations.into_iter() {
*hash_map.entry(witness).or_insert_with(FieldElement::zero) += scale;
}

gate.linear_combinations = hash_map
.into_iter()
.filter(|(_, scale)| scale != &FieldElement::zero())
.map(|(witness, scale)| (scale, witness))
.collect();
gate
}
21 changes: 17 additions & 4 deletions acvm-repo/acvm/src/compiler/optimizers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use acir::circuit::{Circuit, Opcode};

mod constant_backpropagation;
mod general;
mod redundant_range;
mod unused_memory;
Expand All @@ -8,6 +9,7 @@ pub(crate) use general::GeneralOptimizer;
pub(crate) use redundant_range::RangeOptimizer;
use tracing::info;

use self::constant_backpropagation::ConstantBackpropagationOptimizer;
use self::unused_memory::UnusedMemoryOptimizer;

use super::{transform_assert_messages, AcirTransformationMap};
Expand All @@ -26,6 +28,15 @@ pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) {
/// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] independent optimizations to a [`Circuit`].
#[tracing::instrument(level = "trace", name = "optimize_acir" skip(acir))]
pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, Vec<usize>) {
// Track original acir opcode positions throughout the transformation passes of the compilation
// by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert)
let acir_opcode_positions = (0..acir.opcodes.len()).collect();

if acir.opcodes.len() == 1 && matches!(acir.opcodes[0], Opcode::Brillig(_)) {
info!("Program is fully unconstrained, skipping optimization pass");
return (acir, acir_opcode_positions);
}

info!("Number of opcodes before: {}", acir.opcodes.len());

// General optimizer pass
Expand All @@ -42,20 +53,22 @@ pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, Vec<usize>) {
.collect();
let acir = Circuit { opcodes, ..acir };

// Track original acir opcode positions throughout the transformation passes of the compilation
// by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert)
let acir_opcode_positions = (0..acir.opcodes.len()).collect();

// Unused memory optimization pass
let memory_optimizer = UnusedMemoryOptimizer::new(acir);
let (acir, acir_opcode_positions) =
memory_optimizer.remove_unused_memory_initializations(acir_opcode_positions);

let (acir, acir_opcode_positions) =
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
ConstantBackpropagationOptimizer::backpropagate_constants(acir, acir_opcode_positions);

// Range optimization pass
let range_optimizer = RangeOptimizer::new(acir);
let (acir, acir_opcode_positions) =
range_optimizer.replace_redundant_ranges(acir_opcode_positions);

let (acir, acir_opcode_positions) =
ConstantBackpropagationOptimizer::backpropagate_constants(acir, acir_opcode_positions);

info!("Number of opcodes after: {}", acir.opcodes.len());

(acir, acir_opcode_positions)
Expand Down
Loading
Loading