Skip to content

Commit

Permalink
chore: Rework defunctionalize pass to not rely on DFG bugs (#7222)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Jan 29, 2025
1 parent bdcfd38 commit 2d2c73b
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 52 deletions.
8 changes: 6 additions & 2 deletions compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,19 @@ impl DataFlowGraph {
if let Some(existing) = self.functions.get(&function) {
return *existing;
}
self.values.insert(Value::Function(function))
let result = self.values.insert(Value::Function(function));
self.functions.insert(function, result);
result
}

/// Gets or creates a ValueId for the given FunctionId.
pub(crate) fn import_foreign_function(&mut self, function: &str) -> ValueId {
if let Some(existing) = self.foreign_functions.get(function) {
return *existing;
}
self.values.insert(Value::ForeignFunction(function.to_owned()))
let result = self.values.insert(Value::ForeignFunction(function.to_owned()));
self.foreign_functions.insert(function.to_owned(), result);
result
}

/// Gets or creates a ValueId for the given Intrinsic.
Expand Down
157 changes: 111 additions & 46 deletions compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! with a non-literal target can be replaced with a call to an apply function.
//! The apply function is a dispatch function that takes the function id as a parameter
//! and dispatches to the correct target.
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::collections::{BTreeMap, BTreeSet};

use acvm::FieldElement;
use iter_extended::vecmap;
Expand Down Expand Up @@ -80,21 +80,54 @@ impl DefunctionalizationContext {

/// Defunctionalize a single function
fn defunctionalize(&mut self, func: &mut Function) {
let mut call_target_values = HashSet::new();

for block_id in func.reachable_blocks() {
let block = &func.dfg[block_id];
let instructions = block.instructions().to_vec();
let block = &mut func.dfg[block_id];

// Temporarily take the parameters here just to avoid cloning them
let parameters = block.take_parameters();
for parameter in &parameters {
if func.dfg.type_of_value(*parameter) == Type::Function {
func.dfg.set_type_of_value(*parameter, Type::field());
}
}

let block = &mut func.dfg[block_id];
block.set_parameters(parameters);

// Do the same for the terminator
let mut terminator = block.take_terminator();
terminator.map_values_mut(|value| map_function_to_field(func, value).unwrap_or(value));

let block = &mut func.dfg[block_id];
block.set_terminator(terminator);

for instruction_id in instructions {
let instruction = func.dfg[instruction_id].clone();
// Now we can finally change each instruction, replacing
// each first class function with a field value and replacing calls
// to a first class function to a call to the relevant `apply` function.
#[allow(clippy::unnecessary_to_owned)] // clippy is wrong here
for instruction_id in block.instructions().to_vec() {
let mut instruction = func.dfg[instruction_id].clone();
let mut replacement_instruction = None;

if remove_first_class_functions_in_instruction(func, &mut instruction) {
func.dfg[instruction_id] = instruction.clone();
}

#[allow(clippy::unnecessary_to_owned)] // clippy is wrong here
for result in func.dfg.instruction_results(instruction_id).to_vec() {
if func.dfg.type_of_value(result) == Type::Function {
func.dfg.set_type_of_value(result, Type::field());
}
}

// Operate on call instructions
let (target_func_id, arguments) = match &instruction {
Instruction::Call { func: target_func_id, arguments } => {
(*target_func_id, arguments)
}
_ => continue,
_ => {
continue;
}
};

match func.dfg[target_func_id] {
Expand All @@ -116,43 +149,15 @@ impl DefunctionalizationContext {
arguments.insert(0, target_func_id);
}
let func = apply_function_value_id;
call_target_values.insert(func);

replacement_instruction = Some(Instruction::Call { func, arguments });
}
Value::Function(..) => {
call_target_values.insert(target_func_id);
}
_ => {}
}
if let Some(new_instruction) = replacement_instruction {
func.dfg[instruction_id] = new_instruction;
}
}
}

// Change the type of all the values that are not call targets to NativeField
let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id);
for value_id in value_ids {
if let Type::Function = func.dfg[value_id].get_type().as_ref() {
match &func.dfg[value_id] {
// If the value is a static function, transform it to the function id
Value::Function(id) => {
if !call_target_values.contains(&value_id) {
let field = NumericType::NativeField;
let new_value =
func.dfg.make_constant(function_id_to_field(*id), field);
func.dfg.set_value_from_id(value_id, new_value);
}
}
// If the value is a function used as value, just change the type of it
Value::Instruction { .. } | Value::Param { .. } => {
func.dfg.set_type_of_value(value_id, Type::field());
}
_ => {}
}
}
}
}

/// Returns the apply function for the given signature
Expand All @@ -161,6 +166,54 @@ impl DefunctionalizationContext {
}
}

/// Replace any first class functions used in an instruction with a field value.
/// This applies to any function used anywhere else other than the function position
/// of a call instruction. Returns true if the instruction was modified
fn remove_first_class_functions_in_instruction(
func: &mut Function,
instruction: &mut Instruction,
) -> bool {
let mut modified = false;
let mut map_value = |value: ValueId| {
if let Some(new_value) = map_function_to_field(func, value) {
modified = true;
new_value
} else {
value
}
};

if let Instruction::Call { func: _, arguments } = instruction {
for arg in arguments {
*arg = map_value(*arg);
}
} else {
instruction.map_values_mut(map_value);
}

modified
}

/// Try to map the given function literal to a field, returning Some(field) on success.
/// Returns none if the given value was not a function or doesn't need to be mapped.
fn map_function_to_field(func: &mut Function, value: ValueId) -> Option<ValueId> {
if let Type::Function = func.dfg[value].get_type().as_ref() {
match &func.dfg[value] {
// If the value is a static function, transform it to the function id
Value::Function(id) => {
let new_value = function_id_to_field(*id);
return Some(func.dfg.make_constant(new_value, NumericType::NativeField));
}
// If the value is a function used as value, just change the type of it
Value::Instruction { .. } | Value::Param { .. } => {
func.dfg.set_type_of_value(value, Type::field());
}
_ => (),
}
}
None
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> Variants {
let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new();
Expand Down Expand Up @@ -252,13 +305,25 @@ fn create_apply_functions(
variants_map: BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>,
) -> ApplyFunctions {
let mut apply_functions = HashMap::default();
for ((signature, runtime), variants) in variants_map.into_iter() {
for ((mut signature, runtime), variants) in variants_map.into_iter() {
assert!(
!variants.is_empty(),
"ICE: at least one variant should exist for a dynamic call {signature:?}"
);
let dispatches_to_multiple_functions = variants.len() > 1;

for param in &mut signature.params {
if *param == Type::Function {
*param = Type::field();
}
}

for ret in &mut signature.returns {
if *ret == Type::Function {
*ret = Type::field();
}
}

let id = if dispatches_to_multiple_functions {
create_apply_function(ssa, signature.clone(), runtime, variants)
} else {
Expand All @@ -282,7 +347,7 @@ fn create_apply_function(
function_ids: Vec<FunctionId>,
) -> FunctionId {
assert!(!function_ids.is_empty());
let globals = ssa.functions[&function_ids[0]].dfg.globals.clone();
let globals = ssa.main().dfg.globals.clone();
ssa.add_fn(|id| {
let mut function_builder = FunctionBuilder::new("apply".to_string(), id);
function_builder.set_globals(globals);
Expand Down Expand Up @@ -386,10 +451,10 @@ mod tests {
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(f3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
v8 = call f1(f3, v0) -> u32
v9 = add v0, u32 1
v10 = eq v8, v9
constrain v8 == v9
return
}
brillig(inline) fn wrapper f1 {
Expand Down Expand Up @@ -419,10 +484,10 @@ mod tests {
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(Field 3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
v8 = call f1(Field 3, v0) -> u32
v9 = add v0, u32 1
v10 = eq v8, v9
constrain v8 == v9
return
}
brillig(inline) fn wrapper f1 {
Expand Down
8 changes: 4 additions & 4 deletions compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ mod test {
v21 = add v1, v2
v23 = array_set v19, index v21, value Field 128
call f1(v23)
v25 = add v2, u32 1
jmp b1(v25)
v24 = add v2, u32 1
jmp b1(v24)
}
brillig(inline) fn foo f1 {
b0(v0: [Field; 5]):
Expand Down Expand Up @@ -685,8 +685,8 @@ mod test {
v21 = add v1, v2
v23 = array_set v14, index v21, value Field 128
call f1(v23)
v25 = add v2, u32 1
jmp b1(v25)
v24 = add v2, u32 1
jmp b1(v24)
}
brillig(inline) fn foo f1 {
b0(v0: [Field; 5]):
Expand Down

0 comments on commit 2d2c73b

Please sign in to comment.