Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add bigint solver in ACVM and add a unit test for bigints in Noir #4415

Merged
merged 5 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions noir/acvm-repo/acvm/src/pwg/blackbox/bigint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use std::collections::HashMap;

use acir::{
circuit::opcodes::FunctionInput,
native_types::{Witness, WitnessMap},
BlackBoxFunc, FieldElement,
};

use num_bigint::BigUint;

use crate::pwg::OpcodeResolutionError;

guipublic marked this conversation as resolved.
Show resolved Hide resolved
/// Resolve BigInt opcodes by storing BigInt values (and their moduli) by their ID in a HashMap:
/// - When it encounters a bigint operation opcode, it performs the operation on the stored values
/// and store the result using the provided ID.
/// - When it gets a to_bytes opcode, it simply looks up the value and resolves the output witness accordingly.
#[derive(Default)]
pub(crate) struct BigIntSolver {
bigint_id_to_value: HashMap<u32, BigUint>,
bigint_id_to_modulus: HashMap<u32, BigUint>,
}

impl BigIntSolver {
pub(crate) fn get_bigint(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, OpcodeResolutionError> {
self.bigint_id_to_value
.get(&id)
.ok_or(OpcodeResolutionError::BlackBoxFunctionFailed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}

pub(crate) fn get_modulus(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, OpcodeResolutionError> {
self.bigint_id_to_modulus
.get(&id)
.ok_or(OpcodeResolutionError::BlackBoxFunctionFailed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}
pub(crate) fn bigint_from_bytes(
&mut self,
inputs: &[FunctionInput],
modulus: &[u8],
output: u32,
initial_witness: &mut WitnessMap,
) -> Result<(), OpcodeResolutionError> {
let bytes = inputs
.iter()
.map(|input| initial_witness.get(&input.witness).unwrap().to_u128() as u8)
.collect::<Vec<u8>>();
let bigint = BigUint::from_bytes_le(&bytes);
self.bigint_id_to_value.insert(output, bigint);
let modulus = BigUint::from_bytes_le(modulus);
self.bigint_id_to_modulus.insert(output, modulus);
Ok(())
}

pub(crate) fn bigint_to_bytes(
&self,
input: u32,
outputs: &Vec<Witness>,
initial_witness: &mut WitnessMap,
) -> Result<(), OpcodeResolutionError> {
let bigint = self.get_bigint(input, BlackBoxFunc::BigIntToLeBytes)?;

let mut bytes = bigint.to_bytes_le();
while bytes.len() < outputs.len() {
bytes.push(0);
}
bytes.iter().zip(outputs.iter()).for_each(|(byte, output)| {
initial_witness.insert(*output, FieldElement::from(*byte as u128));
});
Ok(())
}

pub(crate) fn bigint_op(
&mut self,
lhs: u32,
rhs: u32,
output: u32,
func: BlackBoxFunc,
) -> Result<(), OpcodeResolutionError> {
let modulus = self.get_modulus(lhs, func)?;
let lhs = self.get_bigint(lhs, func)?;
let rhs = self.get_bigint(rhs, func)?;
let mut result = match func {
BlackBoxFunc::BigIntAdd => lhs + rhs,
BlackBoxFunc::BigIntNeg => {
if lhs >= rhs {
&lhs - &rhs
} else {
&lhs + &modulus - &rhs
}
}
BlackBoxFunc::BigIntMul => lhs * rhs,
BlackBoxFunc::BigIntDiv => {
lhs * rhs.modpow(&(&modulus - BigUint::from(1_u32)), &modulus)
} //TODO ensure that modulus is prime
_ => unreachable!("ICE - bigint_op must be called for an operation"),
};
if result > modulus {
let q = &result / &modulus;
result -= q * &modulus;
}
self.bigint_id_to_value.insert(output, result);
self.bigint_id_to_modulus.insert(output, modulus);
Ok(())
}
}
22 changes: 15 additions & 7 deletions noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use acir::{
};
use acvm_blackbox_solver::{blake2s, blake3, keccak256, keccakf1600, sha256};

use self::pedersen::pedersen_hash;
use self::{bigint::BigIntSolver, pedersen::pedersen_hash};

use super::{insert_value, OpcodeNotSolvable, OpcodeResolutionError};
use crate::{pwg::witness_to_value, BlackBoxFunctionSolver};

pub(crate) mod bigint;
mod fixed_base_scalar_mul;
mod hash;
mod logic;
Expand Down Expand Up @@ -53,6 +54,7 @@ pub(crate) fn solve(
backend: &impl BlackBoxFunctionSolver,
initial_witness: &mut WitnessMap,
bb_func: &BlackBoxFuncCall,
bigint_solver: &mut BigIntSolver,
) -> Result<(), OpcodeResolutionError> {
let inputs = bb_func.get_inputs_vec();
if !contains_all_inputs(initial_witness, &inputs) {
Expand Down Expand Up @@ -190,12 +192,18 @@ pub(crate) fn solve(
}
// Recursive aggregation will be entirely handled by the backend and is not solved by the ACVM
BlackBoxFuncCall::RecursiveAggregation { .. } => Ok(()),
BlackBoxFuncCall::BigIntAdd { .. } => todo!(),
BlackBoxFuncCall::BigIntNeg { .. } => todo!(),
BlackBoxFuncCall::BigIntMul { .. } => todo!(),
BlackBoxFuncCall::BigIntDiv { .. } => todo!(),
BlackBoxFuncCall::BigIntFromLeBytes { .. } => todo!(),
BlackBoxFuncCall::BigIntToLeBytes { .. } => todo!(),
BlackBoxFuncCall::BigIntAdd { lhs, rhs, output }
| BlackBoxFuncCall::BigIntNeg { lhs, rhs, output }
| BlackBoxFuncCall::BigIntMul { lhs, rhs, output }
| BlackBoxFuncCall::BigIntDiv { lhs, rhs, output } => {
bigint_solver.bigint_op(*lhs, *rhs, *output, bb_func.get_black_box_func())
}
BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus, output } => {
bigint_solver.bigint_from_bytes(inputs, modulus, *output, initial_witness)
}
BlackBoxFuncCall::BigIntToLeBytes { input, outputs } => {
bigint_solver.bigint_to_bytes(*input, outputs, initial_witness)
}
BlackBoxFuncCall::Poseidon2Permutation { .. } => todo!(),
BlackBoxFuncCall::Sha256Compression { .. } => todo!(),
}
Expand Down
17 changes: 13 additions & 4 deletions noir/acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use acir::{
};
use acvm_blackbox_solver::BlackBoxResolutionError;

use self::{arithmetic::ExpressionSolver, directives::solve_directives, memory_op::MemoryOpSolver};
use self::{
arithmetic::ExpressionSolver, blackbox::bigint::BigIntSolver, directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::BlackBoxFunctionSolver;

use thiserror::Error;
Expand Down Expand Up @@ -132,6 +135,8 @@ pub struct ACVM<'a, B: BlackBoxFunctionSolver> {
/// Stores the solver for memory operations acting on blocks of memory disambiguated by [block][`BlockId`].
block_solvers: HashMap<BlockId, MemoryOpSolver>,

bigint_solver: BigIntSolver,

/// A list of opcodes which are to be executed by the ACVM.
opcodes: &'a [Opcode],
/// Index of the next opcode to be executed.
Expand All @@ -149,6 +154,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
status,
backend,
block_solvers: HashMap::default(),
bigint_solver: BigIntSolver::default(),
opcodes,
instruction_pointer: 0,
witness_map: initial_witness,
Expand Down Expand Up @@ -254,9 +260,12 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {

let resolution = match opcode {
Opcode::AssertZero(expr) => ExpressionSolver::solve(&mut self.witness_map, expr),
Opcode::BlackBoxFuncCall(bb_func) => {
blackbox::solve(self.backend, &mut self.witness_map, bb_func)
}
Opcode::BlackBoxFuncCall(bb_func) => blackbox::solve(
self.backend,
&mut self.witness_map,
bb_func,
&mut self.bigint_solver,
),
Opcode::Directive(directive) => solve_directives(&mut self.witness_map, directive),
Opcode::MemoryInit { block_id, init } => {
let solver = self.block_solvers.entry(*block_id).or_default();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,8 @@ impl AcirContext {
for i in const_inputs {
field_inputs.push(i?);
}
let modulus = self.big_int_ctx.modulus(field_inputs[0]);
let bigint = self.big_int_ctx.get(field_inputs[0]);
let modulus = self.big_int_ctx.modulus(bigint.modulus_id());
let bytes_len = ((modulus - BigUint::from(1_u32)).bits() - 1) / 8 + 1;
output_count = bytes_len as usize;
(field_inputs, vec![FieldElement::from(bytes_len as u128)])
Expand Down
38 changes: 33 additions & 5 deletions noir/noir_stdlib/src/bigint.nr
Original file line number Diff line number Diff line change
@@ -1,27 +1,55 @@
use crate::ops::{Add, Sub, Mul, Div, Rem,};


global bn254_fq = [0x47, 0xFD, 0x7C, 0xD8, 0x16, 0x8C, 0x20, 0x3C, 0x8d, 0xca, 0x71, 0x68, 0x91, 0x6a, 0x81, 0x97,
0x5d, 0x58, 0x81, 0x81, 0xb6, 0x45, 0x50, 0xb8, 0x29, 0xa0, 0x31, 0xe1, 0x72, 0x4e, 0x64, 0x30];
global bn254_fr = [0x01, 0x00, 0x00, 0x00, 0x3F, 0x59, 0x1F, 0x43, 0x09, 0x97, 0xB9, 0x79, 0x48, 0xE8, 0x33, 0x28,
0x5D, 0x58, 0x81, 0x81, 0xB6, 0x45, 0x50, 0xB8, 0x29, 0xA0, 0x31, 0xE1, 0x72, 0x4E, 0x64, 0x30];
global secpk1_fr = [0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA,
0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
global secpk1_fq = [0x2F, 0xFC, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
global secpr1_fq = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF];
global secpr1_fr = [0x51, 0x25, 0x63, 0xFC, 0xC2, 0xCA, 0xB9, 0xF3, 0x84, 0x9E, 0x17, 0xA7, 0xAD, 0xFA, 0xE6, 0xBC,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,0xFF, 0xFF, 0xFF, 0xFF];


struct BigInt {
pointer: u32,
modulus: u32,
}

impl BigInt {
#[builtin(bigint_add)]
pub fn bigint_add(self, other: BigInt) -> BigInt {
fn bigint_add(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_neg)]
pub fn bigint_neg(self, other: BigInt) -> BigInt {
fn bigint_neg(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_mul)]
pub fn bigint_mul(self, other: BigInt) -> BigInt {
fn bigint_mul(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_div)]
pub fn bigint_div(self, other: BigInt) -> BigInt {
fn bigint_div(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_from_le_bytes)]
pub fn from_le_bytes(bytes: [u8], modulus: [u8]) -> BigInt {}
fn from_le_bytes(bytes: [u8], modulus: [u8]) -> BigInt {}
#[builtin(bigint_to_le_bytes)]
pub fn to_le_bytes(self) -> [u8] {}

pub fn bn254_fr_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, bn254_fr)
Comment on lines +41 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is okay to merge, but I think we should likely just have different BigInt structures so that we get type safety. Can you open up an issue for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
pub fn bn254_fq_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, bn254_fq)
}
pub fn secpk1_fq_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, secpk1_fq)
}
pub fn secpk1_fr_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, secpk1_fr)
}
}

impl Add for BigInt {
Expand Down
2 changes: 1 addition & 1 deletion noir/noir_stdlib/src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod ops;
mod default;
mod prelude;
mod uint128;
// mod bigint;
mod bigint;

// Oracle calls are required to be wrapped in an unconstrained function
// Thus, the only argument to the `println` oracle is expected to always be an ident
Expand Down
6 changes: 6 additions & 0 deletions noir/test_programs/execution_success/bigint/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "bigint"
type = "bin"
authors = [""]

[dependencies]
2 changes: 2 additions & 0 deletions noir/test_programs/execution_success/bigint/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = [34,3,5,8,4]
y = [44,7,1,8,8]
21 changes: 21 additions & 0 deletions noir/test_programs/execution_success/bigint/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use dep::std::bigint;

fn main(mut x: [u8;5], y: [u8;5]) {
let a = bigint::BigInt::secpk1_fq_from_le_bytes([x[0],x[1],x[2],x[3],x[4]]);
let b = bigint::BigInt::secpk1_fq_from_le_bytes([y[0],y[1],y[2],y[3],y[4]]);

let a_bytes = a.to_le_bytes();
let b_bytes = b.to_le_bytes();
for i in 0..5 {
assert(a_bytes[i] == x[i]);
assert(b_bytes[i] == y[i]);
}

let d = a*b - b;
let d_bytes = d.to_le_bytes();
let d1 = bigint::BigInt::secpk1_fq_from_le_bytes(597243850900842442924.to_le_bytes(10));
let d1_bytes = d1.to_le_bytes();
for i in 0..32 {
assert(d_bytes[i] == d1_bytes[i]);
}
}
Loading