Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Maintain shape of foreign call arguments #2935

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use acir::{
brillig::{RegisterIndex, Value},
brillig::{ForeignCallParam, RegisterIndex, Value},
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
OpcodeLocation,
Expand Down Expand Up @@ -159,5 +159,5 @@ pub struct ForeignCallWaitInfo {
/// An identifier interpreted by the caller process
pub function: String,
/// Resolved inputs to a foreign call computed in the previous steps of a Brillig VM process
pub inputs: Vec<Vec<Value>>,
pub inputs: Vec<ForeignCallParam>,
}
15 changes: 10 additions & 5 deletions acvm-repo/acvm/tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ fn inversion_brillig_oracle_equivalence() {
assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input");

// As caller of VM, need to resolve foreign calls
let foreign_call_result = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse());
let foreign_call_result =
Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse());
// Alter Brillig oracle opcode with foreign call resolution
acvm.resolve_pending_foreign_call(foreign_call_result.into());

Expand Down Expand Up @@ -274,7 +275,8 @@ fn double_inversion_brillig_oracle() {
acvm.get_pending_foreign_call().expect("should have a brillig foreign call request");
assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input");

let x_plus_y_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse());
let x_plus_y_inverse =
Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse());

// Resolve Brillig foreign call
acvm.resolve_pending_foreign_call(x_plus_y_inverse.into());
Expand All @@ -291,7 +293,8 @@ fn double_inversion_brillig_oracle() {
acvm.get_pending_foreign_call().expect("should have a brillig foreign call request");
assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input");

let i_plus_j_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse());
let i_plus_j_inverse =
Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse());
assert_ne!(x_plus_y_inverse, i_plus_j_inverse);

// Alter Brillig oracle opcode
Expand Down Expand Up @@ -396,7 +399,8 @@ fn oracle_dependent_execution() {
assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input");

// Resolve Brillig foreign call
let x_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse());
let x_inverse =
Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse());
acvm.resolve_pending_foreign_call(x_inverse.into());

// After filling data request, continue solving
Expand All @@ -412,7 +416,8 @@ fn oracle_dependent_execution() {
assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input");

// Resolve Brillig foreign call
let y_inverse = Value::from(foreign_call_wait_info.inputs[0][0].to_field().inverse());
let y_inverse =
Value::from(foreign_call_wait_info.inputs[0].unwrap_value().to_field().inverse());
acvm.resolve_pending_foreign_call(y_inverse.into());

// We've resolved all the brillig foreign calls so we should be able to complete execution now.
Expand Down
8 changes: 5 additions & 3 deletions acvm-repo/acvm_js/src/foreign_call/inputs.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use acvm::brillig_vm::brillig::Value;
use acvm::brillig_vm::brillig::ForeignCallParam;

use crate::js_witness_map::field_element_to_js_string;

pub(super) fn encode_foreign_call_inputs(foreign_call_inputs: &[Vec<Value>]) -> js_sys::Array {
pub(super) fn encode_foreign_call_inputs(
foreign_call_inputs: &[ForeignCallParam],
) -> js_sys::Array {
let inputs = js_sys::Array::default();
for input in foreign_call_inputs {
let input_array = js_sys::Array::default();
for value in input {
for value in input.values() {
let hex_js_string = field_element_to_js_string(&value.to_field());
input_array.push(&hex_js_string);
}
Expand Down
12 changes: 6 additions & 6 deletions acvm-repo/acvm_js/src/foreign_call/outputs.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
use acvm::brillig_vm::brillig::{ForeignCallOutput, ForeignCallResult, Value};
use acvm::brillig_vm::brillig::{ForeignCallParam, ForeignCallResult, Value};
use wasm_bindgen::JsValue;

use crate::js_witness_map::js_value_to_field_element;

fn decode_foreign_call_output(output: JsValue) -> Result<ForeignCallOutput, String> {
fn decode_foreign_call_output(output: JsValue) -> Result<ForeignCallParam, String> {
if output.is_string() {
let value = Value::from(js_value_to_field_element(output)?);
Ok(ForeignCallOutput::Single(value))
Ok(ForeignCallParam::Single(value))
} else if output.is_array() {
let output = js_sys::Array::from(&output);

let mut values: Vec<Value> = Vec::with_capacity(output.length() as usize);
for elem in output.iter() {
values.push(Value::from(js_value_to_field_element(elem)?))
values.push(Value::from(js_value_to_field_element(elem)?));
}
Ok(ForeignCallOutput::Array(values))
Ok(ForeignCallParam::Array(values))
} else {
return Err("Non-string-or-array element in foreign_call_handler return".into());
}
Expand All @@ -23,7 +23,7 @@ fn decode_foreign_call_output(output: JsValue) -> Result<ForeignCallOutput, Stri
pub(super) fn decode_foreign_call_result(
js_array: js_sys::Array,
) -> Result<ForeignCallResult, String> {
let mut values: Vec<ForeignCallOutput> = Vec::with_capacity(js_array.length() as usize);
let mut values: Vec<ForeignCallParam> = Vec::with_capacity(js_array.length() as usize);
for elem in js_array.iter() {
values.push(decode_foreign_call_output(elem)?);
}
Expand Down
40 changes: 34 additions & 6 deletions acvm-repo/brillig/src/foreign_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,60 @@ use serde::{Deserialize, Serialize};

/// Single output of a [foreign call][crate::Opcode::ForeignCall].
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum ForeignCallOutput {
pub enum ForeignCallParam {
Single(Value),
Array(Vec<Value>),
}

impl From<Value> for ForeignCallParam {
fn from(value: Value) -> Self {
ForeignCallParam::Single(value)
}
}

impl From<Vec<Value>> for ForeignCallParam {
fn from(values: Vec<Value>) -> Self {
ForeignCallParam::Array(values)
}
}

impl ForeignCallParam {
pub fn values(&self) -> Vec<Value> {
match self {
ForeignCallParam::Single(value) => vec![*value],
ForeignCallParam::Array(values) => values.clone(),
}
}

pub fn unwrap_value(&self) -> Value {
match self {
ForeignCallParam::Single(value) => *value,
ForeignCallParam::Array(_) => panic!("Expected single value, found array"),
}
}
}

/// Represents the full output of a [foreign call][crate::Opcode::ForeignCall].
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct ForeignCallResult {
/// Resolved output values of the foreign call.
pub values: Vec<ForeignCallOutput>,
pub values: Vec<ForeignCallParam>,
}

impl From<Value> for ForeignCallResult {
fn from(value: Value) -> Self {
ForeignCallResult { values: vec![ForeignCallOutput::Single(value)] }
ForeignCallResult { values: vec![value.into()] }
}
}

impl From<Vec<Value>> for ForeignCallResult {
fn from(values: Vec<Value>) -> Self {
ForeignCallResult { values: vec![ForeignCallOutput::Array(values)] }
ForeignCallResult { values: vec![values.into()] }
}
}

impl From<Vec<ForeignCallOutput>> for ForeignCallResult {
fn from(values: Vec<ForeignCallOutput>) -> Self {
impl From<Vec<ForeignCallParam>> for ForeignCallResult {
fn from(values: Vec<ForeignCallParam>) -> Self {
ForeignCallResult { values }
}
}
2 changes: 1 addition & 1 deletion acvm-repo/brillig/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod opcodes;
mod value;

pub use black_box::BlackBoxOp;
pub use foreign_call::{ForeignCallOutput, ForeignCallResult};
pub use foreign_call::{ForeignCallParam, ForeignCallResult};
pub use opcodes::{
BinaryFieldOp, BinaryIntOp, HeapArray, HeapVector, RegisterIndex, RegisterOrMemory,
};
Expand Down
40 changes: 21 additions & 19 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
//! [acvm]: https://crates.io/crates/acvm

use acir::brillig::{
BinaryFieldOp, BinaryIntOp, ForeignCallOutput, ForeignCallResult, HeapArray, HeapVector,
Opcode, RegisterIndex, RegisterOrMemory, Value,
BinaryFieldOp, BinaryIntOp, ForeignCallParam, ForeignCallResult, HeapArray, HeapVector, Opcode,
RegisterIndex, RegisterOrMemory, Value,
};
use acir::FieldElement;
// Re-export `brillig`.
Expand Down Expand Up @@ -54,7 +54,7 @@ pub enum VMStatus {
function: String,
/// Input values
/// Each input is a list of values as an input can be either a single value or a memory pointer
inputs: Vec<Vec<Value>>,
inputs: Vec<ForeignCallParam>,
},
}

Expand Down Expand Up @@ -119,7 +119,11 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {

/// Sets the status of the VM to `ForeignCallWait`.
/// Indicating that the VM is now waiting for a foreign call to be resolved.
fn wait_for_foreign_call(&mut self, function: String, inputs: Vec<Vec<Value>>) -> VMStatus {
fn wait_for_foreign_call(
&mut self,
function: String,
inputs: Vec<ForeignCallParam>,
) -> VMStatus {
self.status(VMStatus::ForeignCallWait { function, inputs })
}

Expand Down Expand Up @@ -210,7 +214,7 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {
for (destination, output) in destinations.iter().zip(values) {
match destination {
RegisterOrMemory::RegisterIndex(value_index) => match output {
ForeignCallOutput::Single(value) => {
ForeignCallParam::Single(value) => {
self.registers.set(*value_index, *value);
}
_ => unreachable!(
Expand All @@ -219,7 +223,7 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {
},
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => {
match output {
ForeignCallOutput::Array(values) => {
ForeignCallParam::Array(values) => {
if values.len() != *size {
invalid_foreign_call_result = true;
break;
Expand All @@ -236,7 +240,7 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {
}
RegisterOrMemory::HeapVector(HeapVector { pointer: pointer_index, size: size_index }) => {
match output {
ForeignCallOutput::Array(values) => {
ForeignCallParam::Array(values) => {
// Set our size in the size register
self.registers.set(*size_index, Value::from(values.len()));
// Convert the destination pointer to a usize
Expand Down Expand Up @@ -330,22 +334,20 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {
self.status.clone()
}

fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> Vec<Value> {
fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> ForeignCallParam {
match input {
RegisterOrMemory::RegisterIndex(value_index) => {
vec![self.registers.get(value_index)]
}
RegisterOrMemory::RegisterIndex(value_index) => self.registers.get(value_index).into(),
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => {
let start = self.registers.get(pointer_index);
self.memory.read_slice(start.to_usize(), size).to_vec()
self.memory.read_slice(start.to_usize(), size).to_vec().into()
}
RegisterOrMemory::HeapVector(HeapVector {
pointer: pointer_index,
size: size_index,
}) => {
let start = self.registers.get(pointer_index);
let size = self.registers.get(size_index);
self.memory.read_slice(start.to_usize(), size.to_usize()).to_vec()
self.memory.read_slice(start.to_usize(), size.to_usize()).to_vec().into()
}
}
}
Expand Down Expand Up @@ -922,7 +924,7 @@ mod tests {
vm.status,
VMStatus::ForeignCallWait {
function: "double".into(),
inputs: vec![vec![Value::from(5u128)]]
inputs: vec![Value::from(5u128).into()]
}
);

Expand Down Expand Up @@ -983,7 +985,7 @@ mod tests {
vm.status,
VMStatus::ForeignCallWait {
function: "matrix_2x2_transpose".into(),
inputs: vec![initial_matrix]
inputs: vec![initial_matrix.into()]
}
);

Expand Down Expand Up @@ -1056,13 +1058,13 @@ mod tests {
vm.status,
VMStatus::ForeignCallWait {
function: "string_double".into(),
inputs: vec![input_string.clone()]
inputs: vec![input_string.clone().into()]
}
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult {
values: vec![ForeignCallOutput::Array(output_string.clone())],
values: vec![ForeignCallParam::Array(output_string.clone())],
});

// Resume VM
Expand Down Expand Up @@ -1118,7 +1120,7 @@ mod tests {
vm.status,
VMStatus::ForeignCallWait {
function: "matrix_2x2_transpose".into(),
inputs: vec![initial_matrix.clone()]
inputs: vec![initial_matrix.clone().into()]
}
);

Expand Down Expand Up @@ -1203,7 +1205,7 @@ mod tests {
vm.status,
VMStatus::ForeignCallWait {
function: "matrix_2x2_transpose".into(),
inputs: vec![matrix_a, matrix_b]
inputs: vec![matrix_a.into(), matrix_b.into()]
}
);

Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ pub(crate) mod tests {
use std::vec;

use acvm::acir::brillig::{
BinaryIntOp, ForeignCallOutput, ForeignCallResult, HeapVector, RegisterIndex,
BinaryIntOp, ForeignCallParam, ForeignCallResult, HeapVector, RegisterIndex,
RegisterOrMemory, Value,
};
use acvm::brillig_vm::{Registers, VMStatus, VM};
Expand Down Expand Up @@ -1133,7 +1133,7 @@ pub(crate) mod tests {
Registers { inner: vec![] },
vec![],
bytecode,
vec![ForeignCallResult { values: vec![ForeignCallOutput::Array(number_sequence)] }],
vec![ForeignCallResult { values: vec![ForeignCallParam::Array(number_sequence)] }],
&DummyBlackBoxSolver,
);
let status = vm.process_opcodes();
Expand Down
Loading