diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index c537531a748..261a97ec3fb 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -205,7 +205,6 @@ impl<'context> Elaborator<'context> { } this.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls); - this.collect_traits(items.traits); // Must resolve structs before we resolve globals. @@ -224,7 +223,7 @@ impl<'context> Elaborator<'context> { // // These are resolved after trait impls so that struct methods are chosen // over trait methods if there are name conflicts. - for ((typ, module), impls) in &items.impls { + for ((typ, module), impls) in &mut items.impls { this.collect_impls(typ, *module, impls); } @@ -238,8 +237,8 @@ impl<'context> Elaborator<'context> { this.elaborate_functions(functions); } - for ((typ, module), impls) in items.impls { - this.elaborate_impls(typ, module, impls); + for impls in items.impls.into_values() { + this.elaborate_impls(impls); } for trait_impl in items.trait_impls { @@ -251,15 +250,27 @@ impl<'context> Elaborator<'context> { this.errors } + /// Runs `f` and if it modifies `self.generics`, `self.generics` is truncated + /// back to the previous length. + fn recover_generics(&mut self, f: impl FnOnce(&mut Self) -> T) -> T { + let generics_count = self.generics.len(); + let ret = f(self); + self.generics.truncate(generics_count); + ret + } + fn elaborate_functions(&mut self, functions: UnresolvedFunctions) { self.file = functions.file_id; self.trait_id = functions.trait_id; // TODO: Resolve? + self.self_type = functions.self_type; + for (local_module, id, func) in functions.functions { self.local_module = local_module; - let generics_count = self.generics.len(); - self.elaborate_function(func, id); - self.generics.truncate(generics_count); + self.recover_generics(|this| this.elaborate_function(func, id)); } + + self.self_type = None; + self.trait_id = None; } fn elaborate_function(&mut self, mut function: NoirFunction, id: FuncId) { @@ -277,12 +288,6 @@ impl<'context> Elaborator<'context> { self.resolve_local_globals(); self.trait_bounds = function.def.where_clause.clone(); - let is_low_level_or_oracle = function - .attributes() - .function - .as_ref() - .map_or(false, |func| func.is_low_level() || func.is_oracle()); - if function.def.is_unconstrained { self.in_unconstrained_fn = true; } @@ -297,7 +302,7 @@ impl<'context> Elaborator<'context> { self.elaborate_pattern(parameter.pattern.clone(), param2.1.clone(), definition_kind); } - self.add_generics(&function.def.generics); + self.generics = func_meta.all_generics.clone(); self.desugar_impl_trait_args(&mut function, id); self.declare_numeric_generics(&func_meta.parameters, func_meta.return_type()); self.add_trait_constraints_to_scope(&func_meta); @@ -315,7 +320,8 @@ impl<'context> Elaborator<'context> { } }; - if !func_meta.can_ignore_return_type() { + // Don't verify the return type for builtin functions & trait function declarations + if !func_meta.is_stub() { self.type_check_function_body(body_type, &func_meta, hir_func.as_expr()); } @@ -356,7 +362,7 @@ impl<'context> Elaborator<'context> { let func_scope_tree = self.scopes.end_function(); // The arguments to low-level and oracle functions are always unused so we do not produce warnings for them. - if !is_low_level_or_oracle { + if !func_meta.is_stub() { self.check_for_unused_variables_in_scope_tree(func_scope_tree); } @@ -525,7 +531,12 @@ impl<'context> Elaborator<'context> { /// to be used in analysis and intern the function parameters /// Prerequisite: any implicit generics, including any generics from the impl, /// have already been added to scope via `self.add_generics`. - fn define_function_meta(&mut self, func: &mut NoirFunction, func_id: FuncId) { + fn define_function_meta( + &mut self, + func: &mut NoirFunction, + func_id: FuncId, + is_trait_function: bool, + ) { self.current_function = Some(func_id); self.resolve_where_clause(&mut func.def.where_clause); @@ -652,6 +663,7 @@ impl<'context> Elaborator<'context> { location, typ, direct_generics, + all_generics: self.generics.clone(), trait_impl: self.current_trait_impl, parameters: parameters.into(), return_type: func.def.return_type.clone(), @@ -659,6 +671,7 @@ impl<'context> Elaborator<'context> { has_body: !func.def.body.is_empty(), trait_constraints: self.resolve_trait_constraints(&func.def.where_clause), is_entry_point, + is_trait_function, has_inline_attribute, }; @@ -840,42 +853,10 @@ impl<'context> Elaborator<'context> { } } - fn elaborate_impls( - &mut self, - typ: UnresolvedType, - module: LocalModuleId, - impls: Vec<(Vec, Span, UnresolvedFunctions)>, - ) { - self.local_module = module; - - for (generics, _, functions) in impls { + fn elaborate_impls(&mut self, impls: Vec<(Vec, Span, UnresolvedFunctions)>) { + for (_, _, functions) in impls { self.file = functions.file_id; - let old_generics_length = self.generics.len(); - self.add_generics(&generics); - let self_type = self.resolve_type(typ.clone()); - self.self_type = Some(self_type.clone()); - - let function_ids = vecmap(&functions.functions, |(_, id, _)| *id); - self.elaborate_functions(functions); - - if self_type != Type::Error { - for method_id in function_ids { - let method_name = self.interner.function_name(&method_id).to_owned(); - - if let Some(first_fn) = - self.interner.add_method(&self_type, method_name.clone(), method_id, false) - { - let error = ResolverError::DuplicateDefinition { - name: method_name, - first_span: self.interner.function_ident(&first_fn).span(), - second_span: self.interner.function_ident(&method_id).span(), - }; - self.push_err(error); - } - } - } - - self.generics.truncate(old_generics_length); + self.recover_generics(|this| this.elaborate_functions(functions)); } } @@ -886,7 +867,7 @@ impl<'context> Elaborator<'context> { let unresolved_type = trait_impl.object_type; let self_type_span = unresolved_type.span; let old_generics_length = self.generics.len(); - self.add_generics(&trait_impl.generics); + self.generics = trait_impl.resolved_generics; let trait_generics = vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone())); @@ -895,7 +876,6 @@ impl<'context> Elaborator<'context> { let impl_id = trait_impl.impl_id.expect("An impls' id should be set during define_function_metas"); - self.self_type = Some(self_type.clone()); self.current_trait_impl = trait_impl.impl_id; let methods = trait_impl.methods.function_ids(); @@ -960,13 +940,16 @@ impl<'context> Elaborator<'context> { &mut self, self_type: &UnresolvedType, module: LocalModuleId, - impls: &[(Vec, Span, UnresolvedFunctions)], + impls: &mut [(Vec, Span, UnresolvedFunctions)], ) { self.local_module = module; for (generics, span, unresolved) in impls { self.file = unresolved.file_id; - self.declare_method_on_struct(self_type, generics, false, unresolved, *span); + self.recover_generics(|this| { + this.add_generics(generics); + this.declare_methods_on_struct(self_type, false, unresolved, *span); + }); } } @@ -980,8 +963,12 @@ impl<'context> Elaborator<'context> { let span = trait_impl.object_type.span.expect("All trait self types should have spans"); let object_type = &trait_impl.object_type; - let generics = &trait_impl.generics; - self.declare_method_on_struct(object_type, generics, true, &trait_impl.methods, span); + + self.recover_generics(|this| { + this.add_generics(&trait_impl.generics); + trait_impl.resolved_generics = this.generics.clone(); + this.declare_methods_on_struct(object_type, true, &mut trait_impl.methods, span); + }); } } @@ -990,33 +977,33 @@ impl<'context> Elaborator<'context> { &mut self.def_maps.get_mut(&module.krate).expect(message).modules[module.local_id.0] } - fn declare_method_on_struct( + fn declare_methods_on_struct( &mut self, self_type: &UnresolvedType, - generics: &UnresolvedGenerics, is_trait_impl: bool, - functions: &UnresolvedFunctions, + functions: &mut UnresolvedFunctions, span: Span, ) { - let generic_count = self.generics.len(); - self.add_generics(generics); - let typ = self.resolve_type(self_type.clone()); + let self_type = self.resolve_type(self_type.clone()); + + functions.self_type = Some(self_type.clone()); + + let function_ids = functions.function_ids(); - if let Type::Struct(struct_type, _generics) = typ { - let struct_type = struct_type.borrow(); + if let Type::Struct(struct_type, _) = &self_type { + let struct_ref = struct_type.borrow(); // `impl`s are only allowed on types defined within the current crate - if !is_trait_impl && struct_type.id.krate() != self.crate_id { - let type_name = struct_type.name.to_string(); + if !is_trait_impl && struct_ref.id.krate() != self.crate_id { + let type_name = struct_ref.name.to_string(); self.push_err(DefCollectorErrorKind::ForeignImpl { span, type_name }); - self.generics.truncate(generic_count); return; } // Grab the module defined by the struct type. Note that impls are a case // where the module the methods are added to is not the same as the module // they are resolved in. - let module = self.get_module_mut(struct_type.id.module_id()); + let module = self.get_module_mut(struct_ref.id.module_id()); for (_, method_id, method) in &functions.functions { // If this method was already declared, remove it from the module so it cannot @@ -1028,11 +1015,33 @@ impl<'context> Elaborator<'context> { module.remove_function(method.name_ident()); } } - // Prohibit defining impls for primitive types if we're not in the stdlib - } else if !is_trait_impl && typ != Type::Error && !self.crate_id.is_stdlib() { - self.push_err(DefCollectorErrorKind::NonStructTypeInImpl { span }); + + self.declare_struct_methods(&self_type, &function_ids); + // We can define methods on primitive types only if we're in the stdlib + } else if !is_trait_impl && self_type != Type::Error { + if self.crate_id.is_stdlib() { + self.declare_struct_methods(&self_type, &function_ids); + } else { + self.push_err(DefCollectorErrorKind::NonStructTypeInImpl { span }); + } + } + } + + fn declare_struct_methods(&mut self, self_type: &Type, function_ids: &[FuncId]) { + for method_id in function_ids { + let method_name = self.interner.function_name(method_id).to_owned(); + + if let Some(first_fn) = + self.interner.add_method(self_type, method_name.clone(), *method_id, false) + { + let error = ResolverError::DuplicateDefinition { + name: method_name, + first_span: self.interner.function_ident(&first_fn).span(), + second_span: self.interner.function_ident(method_id).span(), + }; + self.push_err(error); + } } - self.generics.truncate(generic_count); } fn collect_trait_impl_methods( @@ -1203,18 +1212,20 @@ impl<'context> Elaborator<'context> { unresolved: NoirStruct, struct_id: StructId, ) -> (Generics, Vec<(Ident, Type)>) { - let generics = self.add_generics(&unresolved.generics); + self.recover_generics(|this| { + let generics = this.add_generics(&unresolved.generics); - // Check whether the struct definition has globals in the local module and add them to the scope - self.resolve_local_globals(); + // Check whether the struct definition has globals in the local module and add them to the scope + this.resolve_local_globals(); - self.current_item = Some(DependencyId::Struct(struct_id)); + this.current_item = Some(DependencyId::Struct(struct_id)); - self.resolving_ids.insert(struct_id); - let fields = vecmap(unresolved.fields, |(ident, typ)| (ident, self.resolve_type(typ))); - self.resolving_ids.remove(&struct_id); + this.resolving_ids.insert(struct_id); + let fields = vecmap(unresolved.fields, |(ident, typ)| (ident, this.resolve_type(typ))); + this.resolving_ids.remove(&struct_id); - (generics, fields) + (generics, fields) + }) } fn elaborate_global(&mut self, global: UnresolvedGlobal) { @@ -1256,11 +1267,16 @@ impl<'context> Elaborator<'context> { self.define_function_metas_for_functions(function_set); } - for ((_typ, local_module), function_sets) in impls { + for ((self_type, local_module), function_sets) in impls { self.local_module = *local_module; - for (_generics, _, function_set) in function_sets { + for (generics, _, function_set) in function_sets { + self.add_generics(generics); + let self_type = self.resolve_type(self_type.clone()); + function_set.self_type = Some(self_type.clone()); + self.self_type = Some(self_type); self.define_function_metas_for_functions(function_set); + self.generics.clear(); } } @@ -1269,11 +1285,12 @@ impl<'context> Elaborator<'context> { self.local_module = trait_impl.module_id; let unresolved_type = &trait_impl.object_type; - let old_generics_length = self.generics.len(); self.add_generics(&trait_impl.generics); let self_type = self.resolve_type(unresolved_type.clone()); + self.self_type = Some(self_type.clone()); + trait_impl.methods.self_type = Some(self_type); let impl_id = self.interner.next_trait_impl_id(); self.current_trait_impl = Some(impl_id); @@ -1282,7 +1299,7 @@ impl<'context> Elaborator<'context> { trait_impl.resolved_object_type = self.self_type.take(); trait_impl.impl_id = self.current_trait_impl.take(); - self.generics.truncate(old_generics_length); + self.generics.clear(); } } @@ -1291,9 +1308,9 @@ impl<'context> Elaborator<'context> { for (local_module, id, func) in &mut function_set.functions { self.local_module = *local_module; - let old_generics_length = self.generics.len(); - self.define_function_meta(func, *id); - self.generics.truncate(old_generics_length); + self.recover_generics(|this| { + this.define_function_meta(func, *id, false); + }); } } } diff --git a/compiler/noirc_frontend/src/elaborator/traits.rs b/compiler/noirc_frontend/src/elaborator/traits.rs index c2f9a83e559..76cdc592276 100644 --- a/compiler/noirc_frontend/src/elaborator/traits.rs +++ b/compiler/noirc_frontend/src/elaborator/traits.rs @@ -178,7 +178,7 @@ impl<'context> Elaborator<'context> { }; let mut function = NoirFunction { kind, def }; - self.define_function_meta(&mut function, func_id); + self.define_function_meta(&mut function, func_id, true); self.elaborate_function(function, func_id); let _ = self.scopes.end_function(); // Don't check the scope tree for unused variables, they can't be used in a declaration anyway. diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 60b841699f1..838aac3f067 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -5,7 +5,7 @@ use crate::graph::CrateId; use crate::hir::comptime::{Interpreter, InterpreterError}; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleId}; use crate::hir::resolution::errors::ResolverError; -use crate::Type; +use crate::{Type, TypeVariable}; use crate::hir::resolution::import::{resolve_import, ImportDirective, PathResolution}; use crate::hir::resolution::{ @@ -33,6 +33,7 @@ use iter_extended::vecmap; use noirc_errors::{CustomDiagnostic, Span}; use std::collections::{BTreeMap, HashMap}; +use std::rc::Rc; use std::vec; #[derive(Default)] @@ -50,6 +51,9 @@ pub struct UnresolvedFunctions { pub file_id: FileId, pub functions: Vec<(LocalModuleId, FuncId, NoirFunction)>, pub trait_id: Option, + + // The object type this set of functions was declared on, if there is one. + pub self_type: Option, } impl UnresolvedFunctions { @@ -117,10 +121,11 @@ pub struct UnresolvedTraitImpl { pub generics: UnresolvedGenerics, pub where_clause: Vec, - // These fields are filled in later during elaboration + // These fields are filled in later pub trait_id: Option, pub impl_id: Option, pub resolved_object_type: Option, + pub resolved_generics: Vec<(Rc, TypeVariable, Span)>, } #[derive(Clone)] diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 645c41ad33b..536332d8f8a 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -146,6 +146,7 @@ impl<'a> ModCollector<'a> { file_id: self.file_id, functions: Vec::new(), trait_id: None, + self_type: None, }; for (method, _) in r#impl.methods { @@ -195,6 +196,7 @@ impl<'a> ModCollector<'a> { trait_id: None, impl_id: None, resolved_object_type: None, + resolved_generics: Vec::new(), }; self.def_collector.items.trait_impls.push(unresolved_trait_impl); @@ -207,8 +209,12 @@ impl<'a> ModCollector<'a> { trait_impl: &NoirTraitImpl, krate: CrateId, ) -> UnresolvedFunctions { - let mut unresolved_functions = - UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None }; + let mut unresolved_functions = UnresolvedFunctions { + file_id: self.file_id, + functions: Vec::new(), + trait_id: None, + self_type: None, + }; let module = ModuleId { krate, local_id: self.module_id }; @@ -230,8 +236,12 @@ impl<'a> ModCollector<'a> { functions: Vec, krate: CrateId, ) -> Vec<(CompilationError, FileId)> { - let mut unresolved_functions = - UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None }; + let mut unresolved_functions = UnresolvedFunctions { + file_id: self.file_id, + functions: Vec::new(), + trait_id: None, + self_type: None, + }; let mut errors = vec![]; let module = ModuleId { krate, local_id: self.module_id }; @@ -404,6 +414,7 @@ impl<'a> ModCollector<'a> { file_id: self.file_id, functions: Vec::new(), trait_id: None, + self_type: None, }; let mut method_ids = HashMap::new(); diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 8beac340c4b..9fe0d473f15 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -1106,6 +1106,10 @@ impl<'a> Resolver<'a> { trait_constraints: self.resolve_trait_constraints(&func.def.where_clause), is_entry_point: self.is_entry_point_function(func), has_inline_attribute, + + // This is only used by the elaborator + all_generics: Vec::new(), + is_trait_function: false, } } diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index b2a76828c88..70d7c4021ed 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -49,7 +49,7 @@ pub struct TypeChecker<'interner> { pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec { let meta = interner.function_meta(&func_id); let declared_return_type = meta.return_type().clone(); - let can_ignore_ret = meta.can_ignore_return_type(); + let can_ignore_ret = meta.is_stub(); let function_body_id = &interner.function(&func_id).as_expr(); @@ -549,7 +549,9 @@ pub mod test { trait_constraints: Vec::new(), direct_generics: Vec::new(), is_entry_point: true, + is_trait_function: false, has_inline_attribute: false, + all_generics: Vec::new(), }; interner.push_fn_meta(func_meta, func_id); diff --git a/compiler/noirc_frontend/src/hir_def/function.rs b/compiler/noirc_frontend/src/hir_def/function.rs index ceec9ad8580..9e03f074ffe 100644 --- a/compiler/noirc_frontend/src/hir_def/function.rs +++ b/compiler/noirc_frontend/src/hir_def/function.rs @@ -109,6 +109,12 @@ pub struct FuncMeta { /// such as a trait's `Self` type variable. pub direct_generics: Vec<(Rc, TypeVariable)>, + /// All the generics used by this function, which includes any implicit generics or generics + /// from outer scopes, such as those introduced by an impl. + /// This is stored when the FuncMeta is first created to later be used to set the current + /// generics when the function's body is later resolved. + pub all_generics: Vec<(Rc, TypeVariable, Span)>, + pub location: Location, // This flag is needed for the attribute check pass @@ -123,6 +129,11 @@ pub struct FuncMeta { /// For non-contracts, this means the function is `main`. pub is_entry_point: bool, + /// True if this function was defined within a trait (not a trait impl!). + /// Trait functions are just stubs and shouldn't have their return type checked + /// against their body type, nor should unused variables be checked. + pub is_trait_function: bool, + /// True if this function is marked with an attribute /// that indicates it should be inlined differently than the default (inline everything). /// For example, such as `fold` (never inlined) or `no_predicates` (inlined after flattening) @@ -130,12 +141,13 @@ pub struct FuncMeta { } impl FuncMeta { - /// Builtin, LowLevel and Oracle functions usually have the return type - /// declared, however their function bodies will be empty - /// So this method tells the type checker to ignore the return - /// of the empty function, which is unit - pub fn can_ignore_return_type(&self) -> bool { - self.kind.can_ignore_return_type() + /// A stub function does not have a body. This includes Builtin, LowLevel, + /// and Oracle functions in addition to method declarations within a trait. + /// + /// We don't check the return type of these functions since it will always have + /// an empty body, and we don't check for unused parameters. + pub fn is_stub(&self) -> bool { + self.kind.can_ignore_return_type() || self.is_trait_function } pub fn function_signature(&self) -> FunctionSignature {