From 7cb7dd156a6e5b6430c6f05f4833faa1a89902b6 Mon Sep 17 00:00:00 2001 From: Lucas Xia Date: Thu, 19 Sep 2024 19:56:33 -0400 Subject: [PATCH] Revert "feat: Sync from noir (#8653)" This reverts commit 03b9e71e5ebb3d46827671b2197697b5d294d04e. --- .noir-sync-commit | 2 +- .../src/brillig/brillig_ir/codegen_stack.rs | 286 +----------------- .../noirc_frontend/src/elaborator/types.rs | 18 +- .../src/hir/comptime/interpreter.rs | 20 +- .../compiler/noirc_frontend/src/tests.rs | 24 -- .../noir_stdlib/src/embedded_curve_ops.nr | 15 - noir/noir-repo/noir_stdlib/src/schnorr.nr | 65 ---- .../Nargo.toml | 7 - .../src/main.nr | 19 -- .../macro_result_type/Nargo.toml | 0 .../macro_result_type/src/main.nr | 0 .../macro_result_type/t.rs | 12 + .../execution_success/schnorr/src/main.nr | 104 ++++++- 13 files changed, 137 insertions(+), 435 deletions(-) delete mode 100644 noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/Nargo.toml delete mode 100644 noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/src/main.nr rename noir/noir-repo/test_programs/{compile_failure => compile_success_empty}/macro_result_type/Nargo.toml (100%) rename noir/noir-repo/test_programs/{compile_failure => compile_success_empty}/macro_result_type/src/main.nr (100%) create mode 100644 noir/noir-repo/test_programs/compile_success_empty/macro_result_type/t.rs diff --git a/.noir-sync-commit b/.noir-sync-commit index 23a34206cf8e..87a3bf56846e 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -0864e7c945089cc06f8cc9e5c7d933c465d8c892 +1df102a1ee0eb39dcbada50e10b226c7f7be0f26 \ No newline at end of file diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs index b7b25c6db494..945b768efcf9 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs @@ -1,290 +1,26 @@ use acvm::{acir::brillig::MemoryAddress, AcirField}; -use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use super::{debug_show::DebugToString, registers::RegisterAllocator, BrilligContext}; impl BrilligContext { /// This function moves values from a set of registers to another set of registers. - /// The only requirement is that every destination needs to be written at most once. + /// It first moves all sources to new allocated registers to avoid overwriting. pub(crate) fn codegen_mov_registers_to_registers( &mut self, sources: Vec, destinations: Vec, ) { - assert_eq!(sources.len(), destinations.len()); - // Remove all no-ops - let movements: Vec<_> = sources - .into_iter() - .zip(destinations) - .filter(|(source, destination)| source != destination) + let new_sources: Vec<_> = sources + .iter() + .map(|source| { + let new_source = self.allocate_register(); + self.mov_instruction(new_source, *source); + new_source + }) .collect(); - - // Now we need to detect all cycles. - // First build a map of the movements. Note that a source could have multiple destinations - let mut movements_map: HashMap> = - movements.into_iter().fold(HashMap::default(), |mut map, (source, destination)| { - map.entry(source).or_default().insert(destination); - map - }); - - let destinations_set: HashSet<_> = movements_map.values().flatten().copied().collect(); - assert_eq!( - destinations_set.len(), - movements_map.values().flatten().count(), - "Multiple moves to the same register found" - ); - - let mut loop_detector = LoopDetector::default(); - loop_detector.collect_loops(&movements_map); - let loops = loop_detector.loops; - // In order to break the loops we need to store one register from each in a temporary and then use that temporary as source. - let mut temporaries = Vec::with_capacity(loops.len()); - for loop_found in loops { - let temp_register = self.allocate_register(); - temporaries.push(temp_register); - let first_source = loop_found.iter().next().unwrap(); - self.mov_instruction(temp_register, *first_source); - let destinations_of_temp = movements_map.remove(first_source).unwrap(); - movements_map.insert(temp_register, destinations_of_temp); - } - // After removing loops we should have an DAG with each node having only one ancestor (but could have multiple successors) - // Now we should be able to move the registers just by performing a DFS on the movements map - let heads: Vec<_> = movements_map - .keys() - .filter(|source| !destinations_set.contains(source)) - .copied() - .collect(); - for head in heads { - self.perform_movements(&movements_map, head); - } - - // Deallocate all temporaries - for temp in temporaries { - self.deallocate_register(temp); + for (new_source, destination) in new_sources.iter().zip(destinations.iter()) { + self.mov_instruction(*destination, *new_source); + self.deallocate_register(*new_source); } } - - fn perform_movements( - &mut self, - movements: &HashMap>, - current_source: MemoryAddress, - ) { - if let Some(destinations) = movements.get(¤t_source) { - for destination in destinations { - self.perform_movements(movements, *destination); - } - for destination in destinations { - self.mov_instruction(*destination, current_source); - } - } - } -} - -#[derive(Default)] -struct LoopDetector { - visited_sources: HashSet, - loops: Vec>, -} - -impl LoopDetector { - fn collect_loops(&mut self, movements: &HashMap>) { - for source in movements.keys() { - self.find_loop_recursive(*source, movements, im::OrdSet::default()); - } - } - - fn find_loop_recursive( - &mut self, - source: MemoryAddress, - movements: &HashMap>, - mut previous_sources: im::OrdSet, - ) { - if self.visited_sources.contains(&source) { - return; - } - // Mark as visited - self.visited_sources.insert(source); - - previous_sources.insert(source); - // Get all destinations - if let Some(destinations) = movements.get(&source) { - for destination in destinations { - if previous_sources.contains(destination) { - // Found a loop - let loop_sources = previous_sources.clone(); - self.loops.push(loop_sources); - } else { - self.find_loop_recursive(*destination, movements, previous_sources.clone()); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use acvm::{ - acir::brillig::{MemoryAddress, Opcode}, - FieldElement, - }; - use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; - - use crate::{ - brillig::brillig_ir::{artifact::Label, registers::Stack, BrilligContext}, - ssa::ir::function::FunctionId, - }; - - // Tests for the loop finder - - fn generate_movements_map( - movements: Vec<(usize, usize)>, - ) -> HashMap> { - movements.into_iter().fold(HashMap::default(), |mut map, (source, destination)| { - map.entry(MemoryAddress(source)).or_default().insert(MemoryAddress(destination)); - map - }) - } - - #[test] - fn test_loop_detector_basic_loop() { - let movements = vec![(0, 1), (1, 2), (2, 3), (3, 0)]; - let movements_map = generate_movements_map(movements); - let mut loop_detector = super::LoopDetector::default(); - loop_detector.collect_loops(&movements_map); - assert_eq!(loop_detector.loops.len(), 1); - assert_eq!(loop_detector.loops[0].len(), 4); - } - - #[test] - fn test_loop_detector_no_loop() { - let movements = vec![(0, 1), (1, 2), (2, 3), (3, 4)]; - let movements_map = generate_movements_map(movements); - let mut loop_detector = super::LoopDetector::default(); - loop_detector.collect_loops(&movements_map); - assert_eq!(loop_detector.loops.len(), 0); - } - - #[test] - fn test_loop_detector_loop_with_branch() { - let movements = vec![(0, 1), (1, 2), (2, 0), (0, 3), (3, 4)]; - let movements_map = generate_movements_map(movements); - let mut loop_detector = super::LoopDetector::default(); - loop_detector.collect_loops(&movements_map); - assert_eq!(loop_detector.loops.len(), 1); - assert_eq!(loop_detector.loops[0].len(), 3); - } - - #[test] - fn test_loop_detector_two_loops() { - let movements = vec![(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)]; - let movements_map = generate_movements_map(movements); - let mut loop_detector = super::LoopDetector::default(); - loop_detector.collect_loops(&movements_map); - assert_eq!(loop_detector.loops.len(), 2); - assert_eq!(loop_detector.loops[0].len(), 3); - assert_eq!(loop_detector.loops[1].len(), 3); - } - - // Tests for mov_registers_to_registers - - fn movements_to_source_and_destinations( - movements: Vec<(usize, usize)>, - ) -> (Vec, Vec) { - let sources = movements.iter().map(|(source, _)| MemoryAddress::from(*source)).collect(); - let destinations = - movements.iter().map(|(_, destination)| MemoryAddress::from(*destination)).collect(); - (sources, destinations) - } - - pub(crate) fn create_context() -> BrilligContext { - let mut context = BrilligContext::new(true); - context.enter_context(Label::function(FunctionId::test_new(0))); - context - } - - #[test] - #[should_panic(expected = "Multiple moves to the same register found")] - fn test_mov_registers_to_registers_overwrite() { - let movements = vec![(10, 11), (12, 11), (10, 13)]; - let (sources, destinations) = movements_to_source_and_destinations(movements); - let mut context = create_context(); - - context.codegen_mov_registers_to_registers(sources, destinations); - } - - #[test] - fn test_mov_registers_to_registers_no_loop() { - let movements = vec![(10, 11), (11, 12), (12, 13), (13, 14)]; - let (sources, destinations) = movements_to_source_and_destinations(movements); - let mut context = create_context(); - - context.codegen_mov_registers_to_registers(sources, destinations); - let opcodes = context.artifact().byte_code; - assert_eq!( - opcodes, - vec![ - Opcode::Mov { destination: MemoryAddress(14), source: MemoryAddress(13) }, - Opcode::Mov { destination: MemoryAddress(13), source: MemoryAddress(12) }, - Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, - Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(10) }, - ] - ); - } - #[test] - fn test_mov_registers_to_registers_no_op_filter() { - let movements = vec![(10, 11), (11, 11), (11, 12)]; - let (sources, destinations) = movements_to_source_and_destinations(movements); - let mut context = create_context(); - - context.codegen_mov_registers_to_registers(sources, destinations); - let opcodes = context.artifact().byte_code; - assert_eq!( - opcodes, - vec![ - Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, - Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(10) }, - ] - ); - } - - #[test] - fn test_mov_registers_to_registers_loop() { - let movements = vec![(10, 11), (11, 12), (12, 13), (13, 10)]; - let (sources, destinations) = movements_to_source_and_destinations(movements); - let mut context = create_context(); - - context.codegen_mov_registers_to_registers(sources, destinations); - let opcodes = context.artifact().byte_code; - assert_eq!( - opcodes, - vec![ - Opcode::Mov { destination: MemoryAddress(3), source: MemoryAddress(10) }, - Opcode::Mov { destination: MemoryAddress(10), source: MemoryAddress(13) }, - Opcode::Mov { destination: MemoryAddress(13), source: MemoryAddress(12) }, - Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, - Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(3) } - ] - ); - } - - #[test] - fn test_mov_registers_to_registers_loop_and_branch() { - let movements = vec![(10, 11), (11, 12), (12, 10), (10, 13), (13, 14)]; - let (sources, destinations) = movements_to_source_and_destinations(movements); - let mut context = create_context(); - - context.codegen_mov_registers_to_registers(sources, destinations); - let opcodes = context.artifact().byte_code; - assert_eq!( - opcodes, - vec![ - Opcode::Mov { destination: MemoryAddress(3), source: MemoryAddress(10) }, // Temporary - Opcode::Mov { destination: MemoryAddress(14), source: MemoryAddress(13) }, // Branch - Opcode::Mov { destination: MemoryAddress(10), source: MemoryAddress(12) }, // Loop - Opcode::Mov { destination: MemoryAddress(12), source: MemoryAddress(11) }, // Loop - Opcode::Mov { destination: MemoryAddress(13), source: MemoryAddress(3) }, // Finish branch - Opcode::Mov { destination: MemoryAddress(11), source: MemoryAddress(3) } // Finish loop - ] - ); - } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs index 264b83956f85..6be18df7b52f 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs @@ -633,7 +633,7 @@ impl<'context> Elaborator<'context> { 0 } - pub(super) fn unify( + pub fn unify( &mut self, actual: &Type, expected: &Type, @@ -644,22 +644,6 @@ impl<'context> Elaborator<'context> { } } - /// Do not apply type bindings even after a successful unification. - /// This function is used by the interpreter for some comptime code - /// which can change types e.g. on each iteration of a for loop. - pub fn unify_without_applying_bindings( - &mut self, - actual: &Type, - expected: &Type, - file: fm::FileId, - make_error: impl FnOnce() -> TypeCheckError, - ) { - let mut bindings = TypeBindings::new(); - if actual.try_unify(expected, &mut bindings).is_err() { - self.errors.push((make_error().into(), file)); - } - } - /// Wrapper of Type::unify_with_coercions using self.errors pub(super) fn unify_with_coercions( &mut self, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index e920073b4530..b5ed8126e331 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -1303,11 +1303,9 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { // Macro calls are typed as type variables during type checking. // Now that we know the type we need to further unify it in case there // are inconsistencies or the type needs to be known. - // We don't commit any type bindings made this way in case the type of - // the macro result changes across loop iterations. let expected_type = self.elaborator.interner.id_type(id); let actual_type = result.get_type(); - self.unify_without_binding(&actual_type, &expected_type, location); + self.unify(&actual_type, &expected_type, location); } Ok(result) } @@ -1321,14 +1319,16 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } } - fn unify_without_binding(&mut self, actual: &Type, expected: &Type, location: Location) { - self.elaborator.unify_without_applying_bindings(actual, expected, location.file, || { - TypeCheckError::TypeMismatch { - expected_typ: expected.to_string(), - expr_typ: actual.to_string(), - expr_span: location.span, - } + fn unify(&mut self, actual: &Type, expected: &Type, location: Location) { + // We need to swap out the elaborator's file since we may be + // in a different one currently, and it uses that for the error location. + let old_file = std::mem::replace(&mut self.elaborator.file, location.file); + self.elaborator.unify(actual, expected, || TypeCheckError::TypeMismatch { + expected_typ: expected.to_string(), + expr_typ: actual.to_string(), + expr_span: location.span, }); + self.elaborator.file = old_file; } fn evaluate_method_call( diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs index 672328c05bda..22de18b64618 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs @@ -3722,29 +3722,5 @@ fn use_numeric_generic_in_trait_method() { "#; let errors = get_program_errors(src); - println!("{errors:?}"); assert_eq!(errors.len(), 0); } - -#[test] -fn macro_result_type_mismatch() { - let src = r#" - fn main() { - comptime { - let x = unquote!(quote { "test" }); - let _: Field = x; - } - } - - comptime fn unquote(q: Quoted) -> Quoted { - q - } - "#; - - let errors = get_program_errors(src); - assert_eq!(errors.len(), 1); - assert!(matches!( - errors[0].0, - CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) - )); -} diff --git a/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr b/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr index 38bc1764b643..d93b4f41cf0b 100644 --- a/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr +++ b/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr @@ -71,21 +71,6 @@ impl EmbeddedCurveScalar { let (a,b) = crate::field::bn254::decompose(scalar); EmbeddedCurveScalar { lo: a, hi: b } } - - //Bytes to scalar: take the first (after the specified offset) 16 bytes of the input as the lo value, and the next 16 bytes as the hi value - #[field(bn254)] - fn from_bytes(bytes: [u8; 64], offset: u32) -> EmbeddedCurveScalar { - let mut v = 1; - let mut lo = 0 as Field; - let mut hi = 0 as Field; - for i in 0..16 { - lo = lo + (bytes[offset+31 - i] as Field) * v; - hi = hi + (bytes[offset+15 - i] as Field) * v; - v = v * 256; - } - let sig_s = crate::embedded_curve_ops::EmbeddedCurveScalar { lo, hi }; - sig_s - } } impl Eq for EmbeddedCurveScalar { diff --git a/noir/noir-repo/noir_stdlib/src/schnorr.nr b/noir/noir-repo/noir_stdlib/src/schnorr.nr index 336041fec19e..24ca514025c1 100644 --- a/noir/noir-repo/noir_stdlib/src/schnorr.nr +++ b/noir/noir-repo/noir_stdlib/src/schnorr.nr @@ -1,6 +1,3 @@ -use crate::collections::vec::Vec; -use crate::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar}; - #[foreign(schnorr_verify)] // docs:start:schnorr_verify pub fn verify_signature( @@ -23,65 +20,3 @@ pub fn verify_signature_slice( // docs:end:schnorr_verify_slice {} -pub fn verify_signature_noir(public_key: EmbeddedCurvePoint, signature: [u8; 64], message: [u8; N]) -> bool { - //scalar lo/hi from bytes - let sig_s = EmbeddedCurveScalar::from_bytes(signature, 0); - let sig_e = EmbeddedCurveScalar::from_bytes(signature, 32); - // pub_key is on Grumpkin curve - let mut is_ok = (public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17) - & (!public_key.is_infinite); - - if ((sig_s.lo != 0) | (sig_s.hi != 0)) & ((sig_e.lo != 0) | (sig_e.hi != 0)) { - let (r_is_infinite, result) = calculate_signature_challenge(public_key, sig_s, sig_e, message); - - is_ok = !r_is_infinite; - for i in 0..32 { - is_ok &= result[i] == signature[32 + i]; - } - } - is_ok -} - -pub fn assert_valid_signature(public_key: EmbeddedCurvePoint, signature: [u8; 64], message: [u8; N]) { - //scalar lo/hi from bytes - let sig_s = EmbeddedCurveScalar::from_bytes(signature, 0); - let sig_e = EmbeddedCurveScalar::from_bytes(signature, 32); - - // assert pub_key is on Grumpkin curve - assert(public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17); - assert(public_key.is_infinite == false); - // assert signature is not null - assert((sig_s.lo != 0) | (sig_s.hi != 0)); - assert((sig_e.lo != 0) | (sig_e.hi != 0)); - - let (r_is_infinite, result) = calculate_signature_challenge(public_key, sig_s, sig_e, message); - - assert(!r_is_infinite); - for i in 0..32 { - assert(result[i] == signature[32 + i]); - } -} - -fn calculate_signature_challenge( - public_key: EmbeddedCurvePoint, - sig_s: EmbeddedCurveScalar, - sig_e: EmbeddedCurveScalar, - message: [u8; N] -) -> (bool, [u8; 32]) { - let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; - let r = crate::embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]); - // compare the _hashes_ rather than field elements modulo r - let pedersen_hash = crate::hash::pedersen_hash([r.x, public_key.x, public_key.y]); - let pde: [u8; 32] = pedersen_hash.to_be_bytes(); - - let mut hash_input = [0; N + 32]; - for i in 0..32 { - hash_input[i] = pde[i]; - } - for i in 0..N { - hash_input[32+i] = message[i]; - } - - let result = crate::hash::blake2s(hash_input); - (r.is_infinite, result) -} diff --git a/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/Nargo.toml b/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/Nargo.toml deleted file mode 100644 index 38e72395bb53..000000000000 --- a/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/Nargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "comptime_change_type_each_iteration" -type = "bin" -authors = [""] -compiler_version = ">=0.34.0" - -[dependencies] diff --git a/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/src/main.nr deleted file mode 100644 index 7b34c112d4ff..000000000000 --- a/noir/noir-repo/test_programs/compile_success_empty/comptime_change_type_each_iteration/src/main.nr +++ /dev/null @@ -1,19 +0,0 @@ -fn main() { - comptime - { - for i in 9..11 { - // Lengths are different on each iteration: - // foo9, foo10 - let name = f"foo{i}".as_ctstring().as_quoted_str!(); - - // So to call `from_signature` we need to delay the type check - // by quoting the function call so that we re-typecheck on each iteration - let hash = std::meta::unquote!(quote { from_signature($name) }); - assert(hash > 3); - } - } -} - -fn from_signature(_signature: str) -> u32 { - N -} diff --git a/noir/noir-repo/test_programs/compile_failure/macro_result_type/Nargo.toml b/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/Nargo.toml similarity index 100% rename from noir/noir-repo/test_programs/compile_failure/macro_result_type/Nargo.toml rename to noir/noir-repo/test_programs/compile_success_empty/macro_result_type/Nargo.toml diff --git a/noir/noir-repo/test_programs/compile_failure/macro_result_type/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/src/main.nr similarity index 100% rename from noir/noir-repo/test_programs/compile_failure/macro_result_type/src/main.nr rename to noir/noir-repo/test_programs/compile_success_empty/macro_result_type/src/main.nr diff --git a/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/t.rs b/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/t.rs new file mode 100644 index 000000000000..bcd91d7bf5dc --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/macro_result_type/t.rs @@ -0,0 +1,12 @@ + +trait Foo { + fn foo() {} +} + +impl Foo<3> for () { + fn foo() {} +} + +fn main() { + let _ = Foo::foo(); +} diff --git a/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr b/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr index 835ea2ffb1ff..b64078e6b46b 100644 --- a/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr +++ b/noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr @@ -12,6 +12,11 @@ fn main( // Regression for issue #2421 // We want to make sure that we can accurately verify a signature whose message is a slice vs. an array let message_field_bytes: [u8; 10] = message_field.to_be_bytes(); + let mut message2 = [0; 42]; + for i in 0..10 { + assert(message[i] == message_field_bytes[i]); + message2[i] = message[i]; + } // Is there ever a situation where someone would want // to ensure that a signature was invalid? @@ -22,7 +27,102 @@ fn main( let valid_signature = std::schnorr::verify_signature(pub_key_x, pub_key_y, signature, message); assert(valid_signature); let pub_key = embedded_curve_ops::EmbeddedCurvePoint { x: pub_key_x, y: pub_key_y, is_infinite: false }; - let valid_signature = std::schnorr::verify_signature_noir(pub_key, signature, message); + let valid_signature = verify_signature_noir(pub_key, signature, message2); assert(valid_signature); - std::schnorr::assert_valid_signature(pub_key, signature, message); + assert_valid_signature(pub_key, signature, message2); +} + +// TODO: to put in the stdlib once we have numeric generics +// Meanwhile, you have to use a message with 32 additional bytes: +// If you want to verify a signature on a message of 10 bytes, you need to pass a message of length 42, +// where the first 10 bytes are the one from the original message (the other bytes are not used) +pub fn verify_signature_noir( + public_key: embedded_curve_ops::EmbeddedCurvePoint, + signature: [u8; 64], + message: [u8; M] +) -> bool { + let N = message.len() - 32; + + //scalar lo/hi from bytes + let sig_s = bytes_to_scalar(signature, 0); + let sig_e = bytes_to_scalar(signature, 32); + // pub_key is on Grumpkin curve + let mut is_ok = (public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17) + & (!public_key.is_infinite); + + if ((sig_s.lo != 0) | (sig_s.hi != 0)) & ((sig_e.lo != 0) | (sig_e.hi != 0)) { + let g1 = embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + let r = embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]); + // compare the _hashes_ rather than field elements modulo r + let pedersen_hash = std::hash::pedersen_hash([r.x, public_key.x, public_key.y]); + let mut hash_input = [0; M]; + let pde: [u8; 32] = pedersen_hash.to_be_bytes(); + + for i in 0..32 { + hash_input[i] = pde[i]; + } + for i in 0..N { + hash_input[32+i] = message[i]; + } + let result = std::hash::blake2s(hash_input); + + is_ok = !r.is_infinite; + for i in 0..32 { + if result[i] != signature[32 + i] { + is_ok = false; + } + } + } + is_ok +} + +pub fn bytes_to_scalar(bytes: [u8; 64], offset: u32) -> embedded_curve_ops::EmbeddedCurveScalar { + let mut v = 1; + let mut lo = 0 as Field; + let mut hi = 0 as Field; + for i in 0..16 { + lo = lo + (bytes[offset+31 - i] as Field) * v; + hi = hi + (bytes[offset+15 - i] as Field) * v; + v = v * 256; + } + let sig_s = embedded_curve_ops::EmbeddedCurveScalar { lo, hi }; + sig_s +} + +pub fn assert_valid_signature( + public_key: embedded_curve_ops::EmbeddedCurvePoint, + signature: [u8; 64], + message: [u8; M] +) { + let N = message.len() - 32; + //scalar lo/hi from bytes + let sig_s = bytes_to_scalar(signature, 0); + let sig_e = bytes_to_scalar(signature, 32); + + // assert pub_key is on Grumpkin curve + assert(public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17); + assert(public_key.is_infinite == false); + // assert signature is not null + assert((sig_s.lo != 0) | (sig_s.hi != 0)); + assert((sig_e.lo != 0) | (sig_e.hi != 0)); + + let g1 = embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + let r = embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]); + // compare the _hashes_ rather than field elements modulo r + let pedersen_hash = std::hash::pedersen_hash([r.x, public_key.x, public_key.y]); + let mut hash_input = [0; M]; + let pde: [u8; 32] = pedersen_hash.to_be_bytes(); + + for i in 0..32 { + hash_input[i] = pde[i]; + } + for i in 0..N { + hash_input[32+i] = message[i]; + } + let result = std::hash::blake2s(hash_input); + + assert(!r.is_infinite); + for i in 0..32 { + assert(result[i] == signature[32 + i]); + } }