From 0d0d8f7d2b401eb6b534dbb175dfd4b26d2a5f7d Mon Sep 17 00:00:00 2001 From: Alex Vitkov <44268717+alexvitkov@users.noreply.github.com> Date: Tue, 3 Oct 2023 20:31:52 +0300 Subject: [PATCH] feat(traits): Improve support for traits static method resolution (#2958) --- .../src/hir/def_collector/dc_crate.rs | 10 ++ .../src/hir/def_collector/dc_mod.rs | 18 ++- .../src/hir/resolution/resolver.rs | 133 ++++++++++++++---- .../noirc_frontend/src/hir/type_check/expr.rs | 23 ++- compiler/noirc_frontend/src/hir_def/expr.rs | 10 +- compiler/noirc_frontend/src/hir_def/traits.rs | 11 +- .../src/monomorphization/mod.rs | 15 +- compiler/noirc_frontend/src/node_interner.rs | 12 ++ .../execution_success/trait_self/src/main.nr | 14 ++ .../trait_where_clause/src/main.nr | 20 ++- .../trait_where_clause/src/the_trait.nr | 6 + 11 files changed, 218 insertions(+), 54 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 07ad6f598bd..eef3bbb3700 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -36,6 +36,7 @@ use std::vec; pub struct UnresolvedFunctions { pub file_id: FileId, pub functions: Vec<(LocalModuleId, FuncId, NoirFunction)>, + pub trait_id: Option, } impl UnresolvedFunctions { @@ -485,7 +486,9 @@ fn collect_trait_impl_methods( errors.push((error.into(), trait_impl.file_id)); } } + trait_impl.methods.functions = ordered_methods; + trait_impl.methods.trait_id = Some(trait_id); errors } @@ -993,6 +996,12 @@ fn resolve_trait_impls( errors, ); + if let Some(trait_id) = maybe_trait_id { + for (_, func) in &impl_methods { + interner.set_function_trait(*func, self_type.clone(), trait_id); + } + } + let mut new_resolver = Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id); new_resolver.set_self_type(Some(self_type.clone())); @@ -1156,6 +1165,7 @@ fn resolve_function_set( // TypeVariables for the same generic, causing it to instantiate incorrectly. resolver.set_generics(impl_generics.clone()); resolver.set_self_type(self_type.clone()); + resolver.set_trait_id(unresolved_functions.trait_id); let (hir_func, func_meta, errs) = resolver.resolve_function(func, func_id); interner.push_fn_meta(func_meta, func_id); diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 189cfaa1569..748c1dd26cd 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -117,8 +117,11 @@ impl<'a> ModCollector<'a> { let module_id = ModuleId { krate, local_id: self.module_id }; for r#impl in impls { - let mut unresolved_functions = - UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() }; + let mut unresolved_functions = UnresolvedFunctions { + file_id: self.file_id, + functions: Vec::new(), + trait_id: None, + }; for method in r#impl.methods { let func_id = context.def_interner.push_empty_fn(); @@ -171,7 +174,7 @@ impl<'a> ModCollector<'a> { krate: CrateId, ) -> UnresolvedFunctions { let mut unresolved_functions = - UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() }; + UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None }; let module = ModuleId { krate, local_id: self.module_id }; @@ -193,7 +196,7 @@ impl<'a> ModCollector<'a> { krate: CrateId, ) -> Vec<(CompilationError, FileId)> { let mut unresolved_functions = - UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() }; + UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None }; let mut errors = vec![]; let module = ModuleId { krate, local_id: self.module_id }; @@ -351,8 +354,11 @@ impl<'a> ModCollector<'a> { } // Add all functions that have a default implementation in the trait - let mut unresolved_functions = - UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() }; + let mut unresolved_functions = UnresolvedFunctions { + file_id: self.file_id, + functions: Vec::new(), + trait_id: None, + }; for trait_item in &trait_definition.items { // TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629 if let TraitItem::Function { diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 94ef9ec9d4e..ef66ba5e032 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -37,8 +37,8 @@ use crate::{ }; use crate::{ ArrayLiteral, ContractFunctionType, Distinctness, Generics, LValue, NoirStruct, NoirTypeAlias, - Path, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable, UnaryOp, - UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, + Path, PathKind, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable, + UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, Visibility, ERROR_IDENT, }; use fm::FileId; @@ -78,6 +78,8 @@ pub struct Resolver<'a> { scopes: ScopeForest, path_resolver: &'a dyn PathResolver, def_maps: &'a BTreeMap, + trait_id: Option, + trait_bounds: Vec, pub interner: &'a mut NodeInterner, errors: Vec, file: FileId, @@ -120,6 +122,8 @@ impl<'a> Resolver<'a> { Self { path_resolver, def_maps, + trait_id: None, + trait_bounds: Vec::new(), scopes: ScopeForest::default(), interner, self_type: None, @@ -134,6 +138,10 @@ impl<'a> Resolver<'a> { self.self_type = self_type; } + pub fn set_trait_id(&mut self, trait_id: Option) { + self.trait_id = trait_id; + } + pub fn get_self_type(&mut self) -> Option<&Type> { self.self_type.as_ref() } @@ -158,12 +166,14 @@ impl<'a> Resolver<'a> { self.resolve_local_globals(); self.add_generics(&func.def.generics); + self.trait_bounds = func.def.where_clause.clone(); let (hir_func, func_meta) = self.intern_function(func, func_id); let func_scope_tree = self.scopes.end_function(); self.check_for_unused_variables_in_scope_tree(func_scope_tree); + self.trait_bounds.clear(); (hir_func, func_meta, self.errors) } @@ -1075,39 +1085,43 @@ impl<'a> Resolver<'a> { Literal::Unit => HirLiteral::Unit, }), ExpressionKind::Variable(path) => { - // If the Path is being used as an Expression, then it is referring to a global from a separate module - // Otherwise, then it is referring to an Identifier - // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; - // If the expression is a singular indent, we search the resolver's current scope as normal. - let (hir_ident, var_scope_index) = self.get_ident_from_path(path); - - if hir_ident.id != DefinitionId::dummy_id() { - match self.interner.definition(hir_ident.id).kind { - DefinitionKind::Function(id) => { - if self.interner.function_visibility(id) == Visibility::Private { - let span = hir_ident.location.span; - self.check_can_reference_private_function(id, span); + if let Some(expr) = self.resolve_trait_generic_path(&path) { + expr + } else { + // If the Path is being used as an Expression, then it is referring to a global from a separate module + // Otherwise, then it is referring to an Identifier + // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; + // If the expression is a singular indent, we search the resolver's current scope as normal. + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(id) => { + if self.interner.function_visibility(id) == Visibility::Private { + let span = hir_ident.location.span; + self.check_can_reference_private_function(id, span); + } } - } - DefinitionKind::Global(_) => {} - DefinitionKind::GenericType(_) => { - // Initialize numeric generics to a polymorphic integer type in case - // they're used in expressions. We must do this here since the type - // checker does not check definition kinds and otherwise expects - // parameters to already be typed. - if self.interner.id_type(hir_ident.id) == Type::Error { - let typ = Type::polymorphic_integer(self.interner); - self.interner.push_definition_type(hir_ident.id, typ); + DefinitionKind::Global(_) => {} + DefinitionKind::GenericType(_) => { + // Initialize numeric generics to a polymorphic integer type in case + // they're used in expressions. We must do this here since the type + // checker does not check definition kinds and otherwise expects + // parameters to already be typed. + if self.interner.id_type(hir_ident.id) == Type::Error { + let typ = Type::polymorphic_integer(self.interner); + self.interner.push_definition_type(hir_ident.id, typ); + } + } + DefinitionKind::Local(_) => { + // only local variables can be captured by closures. + self.resolve_local_variable(hir_ident, var_scope_index); } - } - DefinitionKind::Local(_) => { - // only local variables can be captured by closures. - self.resolve_local_variable(hir_ident, var_scope_index); } } - } - HirExpression::Ident(hir_ident) + HirExpression::Ident(hir_ident) + } } ExpressionKind::Prefix(prefix) => { let operator = prefix.operator; @@ -1445,6 +1459,65 @@ impl<'a> Resolver<'a> { self.lookup(path).ok().map(|id| self.interner.get_type_alias(id)) } + // this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type) + fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option { + if let Some(trait_id) = self.trait_id { + if path.kind == PathKind::Plain && path.segments.len() == 2 { + let name = &path.segments[0].0.contents; + let method = &path.segments[1]; + + if name == SELF_TYPE_NAME { + let the_trait = self.interner.get_trait(trait_id); + + if let Some(method) = the_trait.find_method(method.clone()) { + let self_type = Type::TypeVariable( + the_trait.self_type_typevar, + crate::TypeVariableKind::Normal, + ); + return Some(HirExpression::TraitMethodReference(self_type, method)); + } + } + } + } + None + } + + // this resolves a static trait method T::trait_method by iterating over the where clause + fn resolve_trait_method_by_named_generic(&mut self, path: &Path) -> Option { + if path.segments.len() != 2 { + return None; + } + + for UnresolvedTraitConstraint { typ, trait_bound } in self.trait_bounds.clone() { + if let UnresolvedTypeData::Named(constraint_path, _) = &typ.typ { + // if `path` is `T::method_name`, we're looking for constraint of the form `T: SomeTrait` + if constraint_path.segments.len() == 1 + && path.segments[0] != constraint_path.last_segment() + { + continue; + } + + if let Ok(ModuleDefId::TraitId(trait_id)) = + self.path_resolver.resolve(self.def_maps, trait_bound.trait_path.clone()) + { + let the_trait = self.interner.get_trait(trait_id); + if let Some(method) = + the_trait.find_method(path.segments.last().unwrap().clone()) + { + let self_type = self.resolve_type(typ.clone()); + return Some(HirExpression::TraitMethodReference(self_type, method)); + } + } + } + } + None + } + + fn resolve_trait_generic_path(&mut self, path: &Path) -> Option { + self.resolve_trait_static_method_by_self(path) + .or_else(|| self.resolve_trait_method_by_named_generic(path)) + } + fn resolve_path(&mut self, path: Path) -> Result { self.path_resolver.resolve(self.def_maps, path).map_err(ResolverError::PathResolutionError) } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 02241434729..c802482d9e0 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -174,7 +174,7 @@ impl<'interner> TypeChecker<'interner> { } let (function_id, function_call) = method_call.into_function_call( - method_ref, + method_ref.clone(), location, self.interner, ); @@ -291,7 +291,19 @@ impl<'interner> TypeChecker<'interner> { Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } - HirExpression::TraitMethodReference(_) => unreachable!("unexpected TraitMethodReference - they should be added after initial type checking"), + HirExpression::TraitMethodReference(_, method) => { + let the_trait = self.interner.get_trait(method.trait_id); + let method = &the_trait.methods[method.method_index]; + + let typ = Type::Function( + method.arguments.clone(), + Box::new(method.return_type.clone()), + Box::new(Type::Unit), + ); + let (typ, bindings) = typ.instantiate(self.interner); + self.interner.store_instantiation_bindings(*expr_id, bindings); + typ + } }; self.interner.push_expr_type(expr_id, typ.clone()); @@ -498,7 +510,7 @@ impl<'interner> TypeChecker<'interner> { (func_meta.typ, param_len) } - HirMethodReference::TraitMethodId(method) => { + HirMethodReference::TraitMethodId(_, method) => { let the_trait = self.interner.get_trait(method.trait_id); let method = &the_trait.methods[method.method_index]; @@ -863,7 +875,10 @@ impl<'interner> TypeChecker<'interner> { if method.name.0.contents == method_name { let trait_method = TraitMethodId { trait_id: constraint.trait_id, method_index }; - return Some(HirMethodReference::TraitMethodId(trait_method)); + return Some(HirMethodReference::TraitMethodId( + object_type.clone(), + trait_method, + )); } } } diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 4989dd12bd6..8ec106c8c37 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -30,7 +30,7 @@ pub enum HirExpression { If(HirIfExpression), Tuple(Vec), Lambda(HirLambda), - TraitMethodReference(TraitMethodId), + TraitMethodReference(Type, TraitMethodId), Error, } @@ -151,7 +151,7 @@ pub struct HirMethodCallExpression { pub location: Location, } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub enum HirMethodReference { /// A method can be defined in a regular `impl` block, in which case /// it's syntax sugar for a normal function call, and can be @@ -161,7 +161,7 @@ pub enum HirMethodReference { /// Or a method can come from a Trait impl block, in which case /// the actual function called will depend on the instantiated type, /// which can be only known during monomorphizaiton. - TraitMethodId(TraitMethodId), + TraitMethodId(Type, TraitMethodId), } impl HirMethodCallExpression { @@ -179,8 +179,8 @@ impl HirMethodCallExpression { let id = interner.function_definition_id(func_id); HirExpression::Ident(HirIdent { location, id }) } - HirMethodReference::TraitMethodId(method_id) => { - HirExpression::TraitMethodReference(method_id) + HirMethodReference::TraitMethodId(typ, method_id) => { + HirExpression::TraitMethodReference(typ, method_id) } }; let func = interner.push_expr(expr); diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index b51405122cd..11e9dde6846 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -1,6 +1,6 @@ use crate::{ graph::CrateId, - node_interner::{FuncId, TraitId}, + node_interner::{FuncId, TraitId, TraitMethodId}, Generics, Ident, NoirFunction, Type, TypeVariable, TypeVariableId, }; use noirc_errors::Span; @@ -111,6 +111,15 @@ impl Trait { pub fn set_methods(&mut self, methods: Vec) { self.methods = methods; } + + pub fn find_method(&self, name: Ident) -> Option { + for (idx, method) in self.methods.iter().enumerate() { + if method.name == name { + return Some(TraitMethodId { trait_id: self.id, method_index: idx }); + } + } + None + } } impl std::fmt::Display for Trait { diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 0a62b71f105..2af0ac433d1 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -203,6 +203,11 @@ impl<'interner> Monomorphizer<'interner> { } fn function(&mut self, f: node_interner::FuncId, id: FuncId) { + if let Some((self_type, trait_id)) = self.interner.get_function_trait(&f) { + let the_trait = self.interner.get_trait(trait_id); + *the_trait.self_type_typevar.borrow_mut() = TypeBinding::Bound(self_type); + } + let meta = self.interner.function_meta(&f); let modifiers = self.interner.function_modifiers(&f); let name = self.interner.function_name(&f).to_owned(); @@ -378,10 +383,9 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Lambda(lambda) => self.lambda(lambda, expr), - HirExpression::TraitMethodReference(method) => { - if let Type::Function(args, _, _) = self.interner.id_type(expr) { - let self_type = args[0].clone(); - self.resolve_trait_method_reference(self_type, expr, method) + HirExpression::TraitMethodReference(typ, method) => { + if let Type::Function(_, _, _) = self.interner.id_type(expr) { + self.resolve_trait_method_reference(typ, expr, method) } else { unreachable!( "Calling a non-function, this should've been caught in typechecking" @@ -799,9 +803,6 @@ impl<'interner> Monomorphizer<'interner> { ) -> ast::Expression { let function_type = self.interner.id_type(expr_id); - // the substitute() here is to replace all internal occurences of the 'Self' typevar - // with whatever 'Self' is currently bound to, so we don't lose type information - // if we need to rebind the trait. let trait_impl = self .interner .get_trait_implementation(&TraitImplKey { diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 5c893682143..84752553585 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -115,6 +115,9 @@ pub struct NodeInterner { /// Methods on primitive types defined in the stdlib. primitive_methods: HashMap<(TypeMethodKey, String), FuncId>, + + // For trait implementation functions, this is their self type and trait they belong to + func_id_to_trait: HashMap, } /// All the information from a function that is filled out during definition collection rather than @@ -364,6 +367,7 @@ impl Default for NodeInterner { function_definition_ids: HashMap::new(), function_modifiers: HashMap::new(), function_modules: HashMap::new(), + func_id_to_trait: HashMap::new(), id_to_location: HashMap::new(), definitions: vec![], id_to_type: HashMap::new(), @@ -628,6 +632,14 @@ impl NodeInterner { self.push_definition(name, false, DefinitionKind::Function(func)) } + pub fn set_function_trait(&mut self, func: FuncId, self_type: Type, trait_id: TraitId) { + self.func_id_to_trait.insert(func, (self_type, trait_id)); + } + + pub fn get_function_trait(&self, func: &FuncId) -> Option<(Type, TraitId)> { + self.func_id_to_trait.get(func).cloned() + } + /// Returns the visibility of the given function. /// /// The underlying function_visibilities map is populated during def collection, diff --git a/tooling/nargo_cli/tests/execution_success/trait_self/src/main.nr b/tooling/nargo_cli/tests/execution_success/trait_self/src/main.nr index c116795a128..cb9619730f7 100644 --- a/tooling/nargo_cli/tests/execution_success/trait_self/src/main.nr +++ b/tooling/nargo_cli/tests/execution_success/trait_self/src/main.nr @@ -1,5 +1,13 @@ trait ATrait { fn asd() -> Self; + + fn static_method() -> Field { + Self::static_method_2() + }; + + fn static_method_2() -> Field { + 100 + }; } struct Foo { @@ -21,7 +29,13 @@ impl ATrait for Bar { fn asd() -> Bar { Bar{x: 100} } + + fn static_method_2() -> Field { + 200 + } } fn main() { + assert(Foo::static_method() == 100); + assert(Bar::static_method() == 200); } \ No newline at end of file diff --git a/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr b/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr index aac2362c1d9..891290061c6 100644 --- a/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr +++ b/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/main.nr @@ -5,9 +5,10 @@ // - trait impl blocks (impl Foo for Bar where T...) // - structs (struct Foo where T: ...) -// import the trait from another module to ensure the where clauses are ok with that +// import the traits from another module to ensure the where clauses are ok with that mod the_trait; use crate::the_trait::Asd; +use crate::the_trait::StaticTrait; struct Add10 { x: Field, } struct Add20 { x: Field, } @@ -24,10 +25,24 @@ impl Asd for AddXY { } } +struct Static100 {} +impl StaticTrait for Static100 { + // use default implementatino for static_function, which returns 100 +} + +struct Static200 {} +impl StaticTrait for Static200 { + fn static_function(slf: Self) -> Field { 200 } +} + fn assert_asd_eq_100(t: T) where T: crate::the_trait::Asd { assert(t.asd() == 100); } +fn add_one_to_static_function(t: T) -> Field where T: StaticTrait { + T::static_function(t) + 1 +} + fn main() { let x = Add10{ x: 90 }; let z = Add20{ x: 80 }; @@ -38,4 +53,7 @@ fn main() { assert_asd_eq_100(z); assert_asd_eq_100(a); assert_asd_eq_100(xy); + + assert(add_one_to_static_function(Static100{}) == 101); + assert(add_one_to_static_function(Static200{}) == 201); } diff --git a/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/the_trait.nr b/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/the_trait.nr index 1b8803fddfd..d84210c4b44 100644 --- a/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/the_trait.nr +++ b/tooling/nargo_cli/tests/execution_success/trait_where_clause/src/the_trait.nr @@ -1,3 +1,9 @@ trait Asd { fn asd(self) -> Field; +} + +trait StaticTrait { + fn static_function(slf: Self) -> Field { + 100 + } } \ No newline at end of file