Skip to content

Commit

Permalink
fix(brillig): Globals entry point reachability analysis (#7188)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
vezenovm and TomAFrench authored Jan 29, 2025
1 parent 506641b commit bdcfd38
Show file tree
Hide file tree
Showing 16 changed files with 509 additions and 55 deletions.
3 changes: 2 additions & 1 deletion compiler/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ pub(crate) fn gen_brillig_for(
brillig: &Brillig,
) -> Result<GeneratedBrillig<FieldElement>, InternalError> {
// Create the entry point artifact
let globals_memory_size = brillig.globals_memory_size.get(&func.id()).copied().unwrap_or(0);
let mut entry_point = BrilligContext::new_entry_point_artifact(
arguments,
FunctionContext::return_values(func),
func.id(),
true,
brillig.globals_memory_size,
globals_memory_size,
);
entry_point.name = func.name().to_string();

Expand Down
430 changes: 420 additions & 10 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions compiler/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use acvm::{
};
use debug_show::DebugShow;

use super::{GlobalSpace, ProcedureId};
use super::{FunctionId, GlobalSpace, ProcedureId};

/// The Brillig VM does not apply a limit to the memory address space,
/// As a convention, we take use 32 bits. This means that we assume that
Expand Down Expand Up @@ -221,11 +221,14 @@ impl<F: AcirField + DebugToString> BrilligContext<F, ScratchSpace> {

/// Special brillig context to codegen global values initialization
impl<F: AcirField + DebugToString> BrilligContext<F, GlobalSpace> {
pub(crate) fn new_for_global_init(enable_debug_trace: bool) -> BrilligContext<F, GlobalSpace> {
pub(crate) fn new_for_global_init(
enable_debug_trace: bool,
entry_point: FunctionId,
) -> BrilligContext<F, GlobalSpace> {
BrilligContext {
obj: BrilligArtifact::default(),
registers: GlobalSpace::new(),
context_label: Label::globals_init(),
context_label: Label::globals_init(entry_point),
current_section: 0,
next_section: 1,
debug_show: DebugShow::new(enable_debug_trace),
Expand Down
16 changes: 12 additions & 4 deletions compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ pub(crate) enum LabelType {
/// Labels for intrinsic procedures
Procedure(ProcedureId),
/// Label for initialization of globals
GlobalInit,
/// Stores a function ID referencing the entry point
GlobalInit(FunctionId),
}

impl std::fmt::Display for LabelType {
Expand All @@ -91,7 +92,9 @@ impl std::fmt::Display for LabelType {
}
LabelType::Entrypoint => write!(f, "Entrypoint"),
LabelType::Procedure(procedure_id) => write!(f, "Procedure({:?})", procedure_id),
LabelType::GlobalInit => write!(f, "Globals Initialization"),
LabelType::GlobalInit(function_id) => {
write!(f, "Globals Initialization({function_id:?})")
}
}
}
}
Expand Down Expand Up @@ -127,8 +130,8 @@ impl Label {
Label { label_type: LabelType::Procedure(procedure_id), section: None }
}

pub(crate) fn globals_init() -> Self {
Label { label_type: LabelType::GlobalInit, section: None }
pub(crate) fn globals_init(function_id: FunctionId) -> Self {
Label { label_type: LabelType::GlobalInit(function_id), section: None }
}
}

Expand Down Expand Up @@ -334,4 +337,9 @@ impl<F: Clone + std::fmt::Debug> BrilligArtifact<F> {
pub(crate) fn set_call_stack(&mut self, call_stack: CallStack) {
self.call_stack = call_stack;
}

#[cfg(test)]
pub(crate) fn take_labels(&mut self) -> HashMap<Label, usize> {
std::mem::take(&mut self.labels)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl<F: AcirField + DebugToString> BrilligContext<F, Stack> {
context.codegen_entry_point(&arguments, &return_parameters);

if globals_init {
context.add_globals_init_instruction();
context.add_globals_init_instruction(target_function);
}

context.add_external_call_instruction(target_function);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ impl<F: AcirField + DebugToString, Registers: RegisterAllocator> BrilligContext<
self.obj.add_unresolved_external_call(BrilligOpcode::Call { location: 0 }, proc_label);
}

pub(super) fn add_globals_init_instruction(&mut self) {
let globals_init_label = Label::globals_init();
pub(super) fn add_globals_init_instruction(&mut self, func_id: FunctionId) {
let globals_init_label = Label::globals_init(func_id);
self.debug_show.add_external_call_instruction(globals_init_label.to_string());
self.obj
.add_unresolved_external_call(BrilligOpcode::Call { location: 0 }, globals_init_label);
Expand Down
4 changes: 3 additions & 1 deletion compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,14 @@ impl RegisterAllocator for ScratchSpace {
/// Globals have a separate memory space
/// This memory space is initialized once at the beginning of a program
/// and is read-only.
#[derive(Default)]
pub(crate) struct GlobalSpace {
storage: DeallocationListAllocator,
max_memory_address: usize,
}

impl GlobalSpace {
pub(super) fn new() -> Self {
pub(crate) fn new() -> Self {
Self {
storage: DeallocationListAllocator::new(Self::start()),
max_memory_address: Self::start(),
Expand Down Expand Up @@ -224,6 +225,7 @@ impl RegisterAllocator for GlobalSpace {
}
}

#[derive(Default)]
struct DeallocationListAllocator {
/// A free-list of registers that have been deallocated and can be used again.
deallocated_registers: BTreeSet<usize>,
Expand Down
42 changes: 29 additions & 13 deletions compiler/noirc_evaluator/src/brillig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub(crate) mod brillig_gen;
pub(crate) mod brillig_ir;

use acvm::FieldElement;
use brillig_gen::brillig_globals::convert_ssa_globals;
use brillig_gen::brillig_globals::BrilligGlobals;
use brillig_ir::{artifact::LabelType, brillig_variable::BrilligVariable, registers::GlobalSpace};

use self::{
Expand All @@ -12,15 +12,18 @@ use self::{
procedures::compile_procedure,
},
};

use crate::ssa::{
ir::{
dfg::DataFlowGraph,
function::{Function, FunctionId},
value::ValueId,
instruction::Instruction,
value::{Value, ValueId},
},
opt::inlining::called_functions_vec,
ssa_gen::Ssa,
};
use fxhash::FxHashMap as HashMap;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use std::{borrow::Cow, collections::BTreeSet};

pub use self::brillig_ir::procedures::ProcedureId;
Expand All @@ -31,8 +34,8 @@ pub use self::brillig_ir::procedures::ProcedureId;
pub struct Brillig {
/// Maps SSA function labels to their brillig artifact
ssa_function_to_brillig: HashMap<FunctionId, BrilligArtifact<FieldElement>>,
globals: BrilligArtifact<FieldElement>,
globals_memory_size: usize,
globals: HashMap<FunctionId, BrilligArtifact<FieldElement>>,
globals_memory_size: HashMap<FunctionId, usize>,
}

impl Brillig {
Expand All @@ -58,7 +61,7 @@ impl Brillig {
}
// Procedures are compiled as needed
LabelType::Procedure(procedure_id) => Some(Cow::Owned(compile_procedure(procedure_id))),
LabelType::GlobalInit => Some(Cow::Borrowed(&self.globals)),
LabelType::GlobalInit(function_id) => self.globals.get(&function_id).map(Cow::Borrowed),
_ => unreachable!("ICE: Expected a function or procedure label"),
}
}
Expand All @@ -72,9 +75,18 @@ impl std::ops::Index<FunctionId> for Brillig {
}

impl Ssa {
/// Compile Brillig functions and ACIR functions reachable from them
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn to_brillig(&self, enable_debug_trace: bool) -> Brillig {
self.to_brillig_with_globals(enable_debug_trace, HashMap::default())
}

/// Compile Brillig functions and ACIR functions reachable from them
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn to_brillig_with_globals(
&self,
enable_debug_trace: bool,
used_globals_map: HashMap<FunctionId, HashSet<ValueId>>,
) -> Brillig {
// Collect all the function ids that are reachable from brillig
// That means all the functions marked as brillig and ACIR functions called by them
let brillig_reachable_function_ids = self
Expand All @@ -89,17 +101,21 @@ impl Ssa {
return brillig;
}

// Globals are computed once at compile time and shared across all functions,
let mut brillig_globals =
BrilligGlobals::new(&self.functions, used_globals_map, self.main_id);

// SSA Globals are computed once at compile time and shared across all functions,
// thus we can just fetch globals from the main function.
// This same globals graph will then be used to declare Brillig globals for the respective entry points.
let globals = (*self.functions[&self.main_id].dfg.globals).clone();
let (artifact, brillig_globals, globals_size) =
convert_ssa_globals(enable_debug_trace, globals, &self.used_global_values);
brillig.globals = artifact;
brillig.globals_memory_size = globals_size;
let globals_dfg = DataFlowGraph::from(globals);
brillig_globals.declare_globals(&globals_dfg, &mut brillig, enable_debug_trace);

for brillig_function_id in brillig_reachable_function_ids {
let globals_allocations = brillig_globals.get_brillig_globals(brillig_function_id);

let func = &self.functions[&brillig_function_id];
brillig.compile(func, enable_debug_trace, &brillig_globals);
brillig.compile(func, enable_debug_trace, &globals_allocations);
}

brillig
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::acir::{Artifacts, GeneratedAcir};
mod checks;
pub(super) mod function_builder;
pub mod ir;
mod opt;
pub(crate) mod opt;
#[cfg(test)]
pub(crate) mod parser;
pub mod ssa_gen;
Expand Down Expand Up @@ -122,8 +122,9 @@ pub(crate) fn optimize_into_acir(

drop(ssa_gen_span_guard);

let used_globals_map = std::mem::take(&mut ssa.used_globals);
let brillig = time("SSA to Brillig", options.print_codegen_timings, || {
ssa.to_brillig(options.enable_brillig_logging)
ssa.to_brillig_with_globals(options.enable_brillig_logging, used_globals_map)
});

let ssa_gen_span = span!(Level::TRACE, "ssa_generation");
Expand Down
32 changes: 21 additions & 11 deletions compiler/noirc_evaluator/src/ssa/opt/die.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,36 @@ impl Ssa {
}

fn dead_instruction_elimination_inner(mut self, flattened: bool) -> Ssa {
let mut used_global_values: HashSet<_> = self
let mut used_globals_map: HashMap<_, _> = self
.functions
.par_iter_mut()
.flat_map(|(_, func)| func.dead_instruction_elimination(true, flattened))
.filter_map(|(id, func)| {
let set = func.dead_instruction_elimination(true, flattened);
if func.runtime().is_brillig() {
Some((*id, set))
} else {
None
}
})
.collect();

let globals = &self.functions[&self.main_id].dfg.globals;
// Check which globals are used across all functions
for (id, value) in globals.values_iter().rev() {
if used_global_values.contains(&id) {
if let Value::Instruction { instruction, .. } = &value {
let instruction = &globals[*instruction];
instruction.for_each_value(|value_id| {
used_global_values.insert(value_id);
});
for used_global_values in used_globals_map.values_mut() {
// DIE only tracks used instruction results, however, globals include constants.
// Back track globals for internal values which may be in use.
for (id, value) in globals.values_iter().rev() {
if used_global_values.contains(&id) {
if let Value::Instruction { instruction, .. } = &value {
let instruction = &globals[*instruction];
instruction.for_each_value(|value_id| {
used_global_values.insert(value_id);
});
}
}
}
}

self.used_global_values = used_global_values;
self.used_globals = used_globals_map;

self
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ struct PerFunctionContext<'function> {
/// Utility function to find out the direct calls of a function.
///
/// Returns the function IDs from all `Call` instructions without deduplication.
fn called_functions_vec(func: &Function) -> Vec<FunctionId> {
pub(crate) fn called_functions_vec(func: &Function) -> Vec<FunctionId> {
let mut called_function_ids = Vec::new();
for block_id in func.reachable_blocks() {
for instruction_id in func.dfg[block_id].instructions() {
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod defunctionalize;
mod die;
pub(crate) mod flatten_cfg;
mod hint;
mod inlining;
pub(crate) mod inlining;
mod loop_invariant;
mod make_constrain_not_equal;
mod mem2reg;
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ impl Ssa {
if let Some(max_incr_pct) = max_bytecode_increase_percent {
if global_cache.is_none() {
let globals = (*function.dfg.globals).clone();
// DIE is run at the end of our SSA optimizations, so we mark all globals as in use here.
let used_globals = &globals.values_iter().map(|(id, _)| id).collect();
let globals_dfg = DataFlowGraph::from(globals);
// DIE is run at the end of our SSA optimizations, so we mark all globals as in use here.
let (_, brillig_globals, _) =
convert_ssa_globals(false, globals, used_globals);
convert_ssa_globals(false, &globals_dfg, used_globals, function.id());
global_cache = Some(brillig_globals);
}
let brillig_globals = global_cache.as_ref().unwrap();
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ impl Translator {
}
RuntimeType::Brillig(inline_type) => {
self.builder.new_brillig_function(external_name, function_id, inline_type);
self.builder.set_globals(self.globals_graph.clone());
}
}

Expand Down
6 changes: 3 additions & 3 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::BTreeMap;

use acvm::acir::circuit::ErrorSelector;
use fxhash::FxHashSet as HashSet;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use iter_extended::btree_map;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
Expand All @@ -20,7 +20,7 @@ use super::ValueId;
pub(crate) struct Ssa {
#[serde_as(as = "Vec<(_, _)>")]
pub(crate) functions: BTreeMap<FunctionId, Function>,
pub(crate) used_global_values: HashSet<ValueId>,
pub(crate) used_globals: HashMap<FunctionId, HashSet<ValueId>>,
pub(crate) main_id: FunctionId,
#[serde(skip)]
pub(crate) next_id: AtomicCounter<Function>,
Expand Down Expand Up @@ -59,7 +59,7 @@ impl Ssa {
error_selector_to_type: error_types,
// This field is set only after running DIE and is utilized
// for optimizing implementation of globals post-SSA.
used_global_values: HashSet::default(),
used_globals: HashMap::default(),
}
}

Expand Down
1 change: 1 addition & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
"Elligator",
"endianness",
"envrc",
"EXPONENTIATE",
"Flamegraph",
"flate",
"fmtstr",
Expand Down

0 comments on commit bdcfd38

Please sign in to comment.