From 836a3f2453afae83dae9cb57df11b7cf4ef90af8 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Mon, 2 Oct 2023 12:10:33 +0000 Subject: [PATCH] feat: Maintain shape of foreign call arguments --- acvm-repo/acvm/src/pwg/brillig.rs | 4 +- acvm-repo/acvm/tests/solver.rs | 15 ++++--- acvm-repo/acvm_js/src/foreign_call/inputs.rs | 8 ++-- acvm-repo/acvm_js/src/foreign_call/outputs.rs | 12 +++--- acvm-repo/brillig/src/foreign_call.rs | 40 ++++++++++++++++--- acvm-repo/brillig/src/lib.rs | 2 +- acvm-repo/brillig_vm/src/lib.rs | 40 ++++++++++--------- .../noirc_evaluator/src/brillig/brillig_ir.rs | 4 +- compiler/noirc_printable_type/src/lib.rs | 28 ++++++------- tooling/nargo/src/ops/foreign_calls.rs | 19 +++++---- 10 files changed, 106 insertions(+), 66 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index 4941e20d5b8..9b0ecd87492 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -1,5 +1,5 @@ use acir::{ - brillig::{RegisterIndex, Value}, + brillig::{ForeignCallParam, RegisterIndex, Value}, circuit::{ brillig::{Brillig, BrilligInputs, BrilligOutputs}, OpcodeLocation, @@ -159,5 +159,5 @@ pub struct ForeignCallWaitInfo { /// An identifier interpreted by the caller process pub function: String, /// Resolved inputs to a foreign call computed in the previous steps of a Brillig VM process - pub inputs: Vec>, + pub inputs: Vec, } diff --git a/acvm-repo/acvm/tests/solver.rs b/acvm-repo/acvm/tests/solver.rs index ca0ca99ba07..d2bd3347945 100644 --- a/acvm-repo/acvm/tests/solver.rs +++ b/acvm-repo/acvm/tests/solver.rs @@ -145,7 +145,8 @@ fn inversion_brillig_oracle_equivalence() { assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); // As caller of VM, need to resolve foreign calls - let foreign_call_result = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse()); + let foreign_call_result = + Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse()); // Alter Brillig oracle opcode with foreign call resolution acvm.resolve_pending_foreign_call(foreign_call_result.into()); @@ -274,7 +275,8 @@ fn double_inversion_brillig_oracle() { acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); - let x_plus_y_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse()); + let x_plus_y_inverse = + Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse()); // Resolve Brillig foreign call acvm.resolve_pending_foreign_call(x_plus_y_inverse.into()); @@ -291,7 +293,8 @@ fn double_inversion_brillig_oracle() { acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); - let i_plus_j_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse()); + let i_plus_j_inverse = + Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse()); assert_ne!(x_plus_y_inverse, i_plus_j_inverse); // Alter Brillig oracle opcode @@ -396,7 +399,8 @@ fn oracle_dependent_execution() { assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); // Resolve Brillig foreign call - let x_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse()); + let x_inverse = + Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse()); acvm.resolve_pending_foreign_call(x_inverse.into()); // After filling data request, continue solving @@ -412,7 +416,8 @@ fn oracle_dependent_execution() { assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); // Resolve Brillig foreign call - let y_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse()); + let y_inverse = + Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse()); acvm.resolve_pending_foreign_call(y_inverse.into()); // We've resolved all the brillig foreign calls so we should be able to complete execution now. diff --git a/acvm-repo/acvm_js/src/foreign_call/inputs.rs b/acvm-repo/acvm_js/src/foreign_call/inputs.rs index 1238f9b1cb2..728fc2d3d2f 100644 --- a/acvm-repo/acvm_js/src/foreign_call/inputs.rs +++ b/acvm-repo/acvm_js/src/foreign_call/inputs.rs @@ -1,12 +1,14 @@ -use acvm::brillig_vm::brillig::Value; +use acvm::brillig_vm::brillig::ForeignCallParam; use crate::js_witness_map::field_element_to_js_string; -pub(super) fn encode_foreign_call_inputs(foreign_call_inputs: &[Vec]) -> js_sys::Array { +pub(super) fn encode_foreign_call_inputs( + foreign_call_inputs: &[ForeignCallParam], +) -> js_sys::Array { let inputs = js_sys::Array::default(); for input in foreign_call_inputs { let input_array = js_sys::Array::default(); - for value in input { + for value in input.values() { let hex_js_string = field_element_to_js_string(&value.to_field()); input_array.push(&hex_js_string); } diff --git a/acvm-repo/acvm_js/src/foreign_call/outputs.rs b/acvm-repo/acvm_js/src/foreign_call/outputs.rs index eea686b9369..630b1afb6fd 100644 --- a/acvm-repo/acvm_js/src/foreign_call/outputs.rs +++ b/acvm-repo/acvm_js/src/foreign_call/outputs.rs @@ -1,20 +1,20 @@ -use acvm::brillig_vm::brillig::{ForeignCallOutput, ForeignCallResult, Value}; +use acvm::brillig_vm::brillig::{ForeignCallParam, ForeignCallResult, Value}; use wasm_bindgen::JsValue; use crate::js_witness_map::js_value_to_field_element; -fn decode_foreign_call_output(output: JsValue) -> Result { +fn decode_foreign_call_output(output: JsValue) -> Result { if output.is_string() { let value = Value::from(js_value_to_field_element(output)?); - Ok(ForeignCallOutput::Single(value)) + Ok(ForeignCallParam::Single(value)) } else if output.is_array() { let output = js_sys::Array::from(&output); let mut values: Vec = Vec::with_capacity(output.length() as usize); for elem in output.iter() { - values.push(Value::from(js_value_to_field_element(elem)?)) + values.push(Value::from(js_value_to_field_element(elem)?)); } - Ok(ForeignCallOutput::Array(values)) + Ok(ForeignCallParam::Array(values)) } else { return Err("Non-string-or-array element in foreign_call_handler return".into()); } @@ -23,7 +23,7 @@ fn decode_foreign_call_output(output: JsValue) -> Result Result { - let mut values: Vec = Vec::with_capacity(js_array.length() as usize); + let mut values: Vec = Vec::with_capacity(js_array.length() as usize); for elem in js_array.iter() { values.push(decode_foreign_call_output(elem)?); } diff --git a/acvm-repo/brillig/src/foreign_call.rs b/acvm-repo/brillig/src/foreign_call.rs index 3c041d00d02..1359d7d604d 100644 --- a/acvm-repo/brillig/src/foreign_call.rs +++ b/acvm-repo/brillig/src/foreign_call.rs @@ -3,32 +3,60 @@ use serde::{Deserialize, Serialize}; /// Single output of a [foreign call][crate::Opcode::ForeignCall]. #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)] -pub enum ForeignCallOutput { +pub enum ForeignCallParam { Single(Value), Array(Vec), } +impl From for ForeignCallParam { + fn from(value: Value) -> Self { + ForeignCallParam::Single(value) + } +} + +impl From> for ForeignCallParam { + fn from(values: Vec) -> Self { + ForeignCallParam::Array(values) + } +} + +impl ForeignCallParam { + pub fn values(&self) -> Vec { + match self { + ForeignCallParam::Single(value) => vec![*value], + ForeignCallParam::Array(values) => values.clone(), + } + } + + pub fn unwrap_value(&self) -> Value { + match self { + ForeignCallParam::Single(value) => *value, + ForeignCallParam::Array(_) => panic!("Expected single value, found array"), + } + } +} + /// Represents the full output of a [foreign call][crate::Opcode::ForeignCall]. #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)] pub struct ForeignCallResult { /// Resolved output values of the foreign call. - pub values: Vec, + pub values: Vec, } impl From for ForeignCallResult { fn from(value: Value) -> Self { - ForeignCallResult { values: vec![ForeignCallOutput::Single(value)] } + ForeignCallResult { values: vec![value.into()] } } } impl From> for ForeignCallResult { fn from(values: Vec) -> Self { - ForeignCallResult { values: vec![ForeignCallOutput::Array(values)] } + ForeignCallResult { values: vec![values.into()] } } } -impl From> for ForeignCallResult { - fn from(values: Vec) -> Self { +impl From> for ForeignCallResult { + fn from(values: Vec) -> Self { ForeignCallResult { values } } } diff --git a/acvm-repo/brillig/src/lib.rs b/acvm-repo/brillig/src/lib.rs index 042caab99c9..422590d1f49 100644 --- a/acvm-repo/brillig/src/lib.rs +++ b/acvm-repo/brillig/src/lib.rs @@ -16,7 +16,7 @@ mod opcodes; mod value; pub use black_box::BlackBoxOp; -pub use foreign_call::{ForeignCallOutput, ForeignCallResult}; +pub use foreign_call::{ForeignCallParam, ForeignCallResult}; pub use opcodes::{ BinaryFieldOp, BinaryIntOp, HeapArray, HeapVector, RegisterIndex, RegisterOrMemory, }; diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 8628a53a27e..e2c8ae6521a 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -12,8 +12,8 @@ //! [acvm]: https://crates.io/crates/acvm use acir::brillig::{ - BinaryFieldOp, BinaryIntOp, ForeignCallOutput, ForeignCallResult, HeapArray, HeapVector, - Opcode, RegisterIndex, RegisterOrMemory, Value, + BinaryFieldOp, BinaryIntOp, ForeignCallParam, ForeignCallResult, HeapArray, HeapVector, Opcode, + RegisterIndex, RegisterOrMemory, Value, }; use acir::FieldElement; // Re-export `brillig`. @@ -54,7 +54,7 @@ pub enum VMStatus { function: String, /// Input values /// Each input is a list of values as an input can be either a single value or a memory pointer - inputs: Vec>, + inputs: Vec, }, } @@ -119,7 +119,11 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { /// Sets the status of the VM to `ForeignCallWait`. /// Indicating that the VM is now waiting for a foreign call to be resolved. - fn wait_for_foreign_call(&mut self, function: String, inputs: Vec>) -> VMStatus { + fn wait_for_foreign_call( + &mut self, + function: String, + inputs: Vec, + ) -> VMStatus { self.status(VMStatus::ForeignCallWait { function, inputs }) } @@ -210,7 +214,7 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { for (destination, output) in destinations.iter().zip(values) { match destination { RegisterOrMemory::RegisterIndex(value_index) => match output { - ForeignCallOutput::Single(value) => { + ForeignCallParam::Single(value) => { self.registers.set(*value_index, *value); } _ => unreachable!( @@ -219,7 +223,7 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { }, RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => { match output { - ForeignCallOutput::Array(values) => { + ForeignCallParam::Array(values) => { if values.len() != *size { invalid_foreign_call_result = true; break; @@ -236,7 +240,7 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { } RegisterOrMemory::HeapVector(HeapVector { pointer: pointer_index, size: size_index }) => { match output { - ForeignCallOutput::Array(values) => { + ForeignCallParam::Array(values) => { // Set our size in the size register self.registers.set(*size_index, Value::from(values.len())); // Convert the destination pointer to a usize @@ -330,14 +334,12 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { self.status.clone() } - fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> Vec { + fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> ForeignCallParam { match input { - RegisterOrMemory::RegisterIndex(value_index) => { - vec![self.registers.get(value_index)] - } + RegisterOrMemory::RegisterIndex(value_index) => self.registers.get(value_index).into(), RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => { let start = self.registers.get(pointer_index); - self.memory.read_slice(start.to_usize(), size).to_vec() + self.memory.read_slice(start.to_usize(), size).to_vec().into() } RegisterOrMemory::HeapVector(HeapVector { pointer: pointer_index, @@ -345,7 +347,7 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { }) => { let start = self.registers.get(pointer_index); let size = self.registers.get(size_index); - self.memory.read_slice(start.to_usize(), size.to_usize()).to_vec() + self.memory.read_slice(start.to_usize(), size.to_usize()).to_vec().into() } } } @@ -922,7 +924,7 @@ mod tests { vm.status, VMStatus::ForeignCallWait { function: "double".into(), - inputs: vec![vec![Value::from(5u128)]] + inputs: vec![Value::from(5u128).into()] } ); @@ -983,7 +985,7 @@ mod tests { vm.status, VMStatus::ForeignCallWait { function: "matrix_2x2_transpose".into(), - inputs: vec![initial_matrix] + inputs: vec![initial_matrix.into()] } ); @@ -1056,13 +1058,13 @@ mod tests { vm.status, VMStatus::ForeignCallWait { function: "string_double".into(), - inputs: vec![input_string.clone()] + inputs: vec![input_string.clone().into()] } ); // Push result we're waiting for vm.foreign_call_results.push(ForeignCallResult { - values: vec![ForeignCallOutput::Array(output_string.clone())], + values: vec![ForeignCallParam::Array(output_string.clone())], }); // Resume VM @@ -1118,7 +1120,7 @@ mod tests { vm.status, VMStatus::ForeignCallWait { function: "matrix_2x2_transpose".into(), - inputs: vec![initial_matrix.clone()] + inputs: vec![initial_matrix.clone().into()] } ); @@ -1203,7 +1205,7 @@ mod tests { vm.status, VMStatus::ForeignCallWait { function: "matrix_2x2_transpose".into(), - inputs: vec![matrix_a, matrix_b] + inputs: vec![matrix_a.into(), matrix_b.into()] } ); diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index d1ce1b551b2..a5647c38b60 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -1009,7 +1009,7 @@ pub(crate) mod tests { use std::vec; use acvm::acir::brillig::{ - BinaryIntOp, ForeignCallOutput, ForeignCallResult, HeapVector, RegisterIndex, + BinaryIntOp, ForeignCallParam, ForeignCallResult, HeapVector, RegisterIndex, RegisterOrMemory, Value, }; use acvm::brillig_vm::{Registers, VMStatus, VM}; @@ -1133,7 +1133,7 @@ pub(crate) mod tests { Registers { inner: vec![] }, vec![], bytecode, - vec![ForeignCallResult { values: vec![ForeignCallOutput::Array(number_sequence)] }], + vec![ForeignCallResult { values: vec![ForeignCallParam::Array(number_sequence)] }], &DummyBlackBoxSolver, ); let status = vm.process_opcodes(); diff --git a/compiler/noirc_printable_type/src/lib.rs b/compiler/noirc_printable_type/src/lib.rs index 46f59f665a0..45d8926fd83 100644 --- a/compiler/noirc_printable_type/src/lib.rs +++ b/compiler/noirc_printable_type/src/lib.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, str}; -use acvm::{brillig_vm::brillig::Value, FieldElement}; +use acvm::{brillig_vm::brillig::ForeignCallParam, FieldElement}; use iter_extended::vecmap; use regex::{Captures, Regex}; use serde::{Deserialize, Serialize}; @@ -75,14 +75,14 @@ pub enum ForeignCallError { ParsingError(#[from] serde_json::Error), } -impl TryFrom<&[Vec]> for PrintableValueDisplay { +impl TryFrom<&[ForeignCallParam]> for PrintableValueDisplay { type Error = ForeignCallError; - fn try_from(foreign_call_inputs: &[Vec]) -> Result { + fn try_from(foreign_call_inputs: &[ForeignCallParam]) -> Result { let (is_fmt_str, foreign_call_inputs) = foreign_call_inputs.split_last().ok_or(ForeignCallError::MissingForeignCallInputs)?; - if is_fmt_str[0].to_field().is_one() { + if is_fmt_str.unwrap_value().to_field().is_one() { convert_fmt_string_inputs(foreign_call_inputs) } else { convert_string_inputs(foreign_call_inputs) @@ -91,7 +91,7 @@ impl TryFrom<&[Vec]> for PrintableValueDisplay { } fn convert_string_inputs( - foreign_call_inputs: &[Vec], + foreign_call_inputs: &[ForeignCallParam], ) -> Result { // Fetch the PrintableType from the foreign call input // The remaining input values should hold what is to be printed @@ -101,7 +101,7 @@ fn convert_string_inputs( // We must use a flat map here as each value in a struct will be in a separate input value let mut input_values_as_fields = - input_values.iter().flat_map(|values| vecmap(values, |value| value.to_field())); + input_values.iter().flat_map(|param| vecmap(param.values(), |value| value.to_field())); let value = decode_value(&mut input_values_as_fields, &printable_type); @@ -109,12 +109,12 @@ fn convert_string_inputs( } fn convert_fmt_string_inputs( - foreign_call_inputs: &[Vec], + foreign_call_inputs: &[ForeignCallParam], ) -> Result { - let (message_as_values, input_and_printable_values) = + let (message, input_and_printable_values) = foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?; - let message_as_fields = vecmap(message_as_values, |value| value.to_field()); + let message_as_fields = vecmap(message.values(), |value| value.to_field()); let message_as_string = decode_string_value(&message_as_fields); let (num_values, input_and_printable_values) = input_and_printable_values @@ -122,7 +122,7 @@ fn convert_fmt_string_inputs( .ok_or(ForeignCallError::MissingForeignCallInputs)?; let mut output = Vec::new(); - let num_values = num_values[0].to_field().to_u128() as usize; + let num_values = num_values.unwrap_value().to_field().to_u128() as usize; for (i, printable_value) in input_and_printable_values .iter() @@ -134,7 +134,7 @@ fn convert_fmt_string_inputs( let mut input_values_as_fields = input_and_printable_values[i..(i + type_size)] .iter() - .flat_map(|values| vecmap(values, |value| value.to_field())); + .flat_map(|param| vecmap(param.values(), |value| value.to_field())); let value = decode_value(&mut input_values_as_fields, &printable_type); @@ -145,9 +145,9 @@ fn convert_fmt_string_inputs( } fn fetch_printable_type( - printable_type_as_values: &[Value], + printable_type: &ForeignCallParam, ) -> Result { - let printable_type_as_fields = vecmap(printable_type_as_values, |value| value.to_field()); + let printable_type_as_fields = vecmap(printable_type.values(), |value| value.to_field()); let printable_type_as_string = decode_string_value(&printable_type_as_fields); let printable_type: PrintableType = serde_json::from_str(&printable_type_as_string)?; @@ -307,7 +307,7 @@ fn decode_value( } } -fn decode_string_value(field_elements: &[FieldElement]) -> String { +pub fn decode_string_value(field_elements: &[FieldElement]) -> String { // TODO: Replace with `into` when Char is supported let string_as_slice = vecmap(field_elements, |e| { let mut field_as_bytes = e.to_be_bytes(); diff --git a/tooling/nargo/src/ops/foreign_calls.rs b/tooling/nargo/src/ops/foreign_calls.rs index db8cdceb20a..68962978fed 100644 --- a/tooling/nargo/src/ops/foreign_calls.rs +++ b/tooling/nargo/src/ops/foreign_calls.rs @@ -1,5 +1,6 @@ use acvm::{ - acir::brillig::{ForeignCallOutput, ForeignCallResult, Value}, + acir::brillig::{ForeignCallResult, Value}, + brillig_vm::brillig::ForeignCallParam, pwg::ForeignCallWaitInfo, }; use iter_extended::vecmap; @@ -52,24 +53,26 @@ impl ForeignCall { Ok(ForeignCallResult { values: vec![] }) } Some(ForeignCall::Sequence) => { - let sequence_length: u128 = foreign_call.inputs[0][0].to_field().to_u128(); + let sequence_length: u128 = + foreign_call.inputs[0].unwrap_value().to_field().to_u128(); let sequence = vecmap(0..sequence_length, Value::from); Ok(ForeignCallResult { values: vec![ - ForeignCallOutput::Single(sequence_length.into()), - ForeignCallOutput::Array(sequence), + ForeignCallParam::Single(sequence_length.into()), + ForeignCallParam::Array(sequence), ], }) } Some(ForeignCall::ReverseSequence) => { - let sequence_length: u128 = foreign_call.inputs[0][0].to_field().to_u128(); + let sequence_length: u128 = + foreign_call.inputs[0].unwrap_value().to_field().to_u128(); let sequence = vecmap((0..sequence_length).rev(), Value::from); Ok(ForeignCallResult { values: vec![ - ForeignCallOutput::Single(sequence_length.into()), - ForeignCallOutput::Array(sequence), + ForeignCallParam::Single(sequence_length.into()), + ForeignCallParam::Array(sequence), ], }) } @@ -77,7 +80,7 @@ impl ForeignCall { } } - fn execute_println(foreign_call_inputs: &[Vec]) -> Result<(), NargoError> { + fn execute_println(foreign_call_inputs: &[ForeignCallParam]) -> Result<(), NargoError> { let display_values: PrintableValueDisplay = foreign_call_inputs.try_into()?; println!("{display_values}"); Ok(())