From 0f9af652efc6b7628784a397a9df674eaa30de61 Mon Sep 17 00:00:00 2001 From: jfecher Date: Mon, 30 Oct 2023 17:15:20 -0500 Subject: [PATCH] feat: Allow traits to have generic functions (#3365) --- .../src/hir/def_collector/dc_crate.rs | 6 +++--- .../noirc_frontend/src/hir/type_check/expr.rs | 1 + compiler/noirc_frontend/src/hir_def/traits.rs | 4 +++- compiler/noirc_frontend/src/node_interner.rs | 5 ++++- .../trait_generics/src/main.nr | 19 ++++++++++++++++++- 5 files changed, 29 insertions(+), 6 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index e59ab3e59f9..95ee08b29c5 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -757,7 +757,7 @@ fn resolve_trait_methods( for item in &unresolved_trait.trait_def.items { if let TraitItem::Function { name, - generics: _, + generics, parameters, return_type, where_clause: _, @@ -769,14 +769,14 @@ fn resolve_trait_methods( Type::TypeVariable(the_trait.self_type_typevar.clone(), TypeVariableKind::Normal); let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file); + resolver.add_generics(generics); resolver.set_self_type(Some(self_type)); let arguments = vecmap(parameters, |param| resolver.resolve_type(param.1.clone())); let resolved_return_type = resolver.resolve_type(return_type.get_type().into_owned()); + let generics = resolver.get_generics().to_vec(); let name = name.clone(); - // TODO - let generics: Generics = vec![]; let span: Span = name.span(); let default_impl_list: Vec<_> = unresolved_trait .fns_with_default_impl diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 09dba2cce86..645082c3713 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -266,6 +266,7 @@ impl<'interner> TypeChecker<'interner> { Box::new(method.return_type.clone()), Box::new(Type::Unit), ); + let (typ, bindings) = typ.instantiate(self.interner); self.interner.store_instantiation_bindings(*expr_id, bindings); typ diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 3965cb2412f..0fbe5520b3f 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use crate::{ graph::CrateId, node_interner::{FuncId, TraitId, TraitMethodId}, @@ -9,7 +11,7 @@ use noirc_errors::Span; #[derive(Clone, Debug, PartialEq, Eq)] pub struct TraitFunction { pub name: Ident, - pub generics: Generics, + pub generics: Vec<(Rc, TypeVariable, Span)>, pub arguments: Vec, pub return_type: Type, pub span: Span, diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index ef5f72df4a2..c527951ea0a 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -953,7 +953,10 @@ impl NodeInterner { ) -> Option> { let impls = self.trait_implementation_map.get(&trait_id)?; for (existing_object_type, impl_id) in impls { - if object_type.try_unify(existing_object_type).is_ok() { + let object_type = object_type.instantiate_named_generics(self); + let existing_object_type = existing_object_type.instantiate_named_generics(self); + + if object_type.try_unify(&existing_object_type).is_ok() { return Some(self.get_trait_implementation(*impl_id)); } } diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr b/tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr index ccda969f930..504132eea8d 100644 --- a/tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr +++ b/tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr @@ -25,7 +25,12 @@ fn main() { // the first matching one instead of erroring assert(z.foo() == 32); - // Ensure we can call a generic impl + call_impl_with_generic_struct(); + call_impl_with_generic_function(); +} + +// Ensure we can call a generic impl +fn call_impl_with_generic_struct() { let x: u8 = 7; let y: i8 = 8; let s2_u8 = S2 { x }; @@ -43,3 +48,15 @@ struct S2 { x: T } impl T2 for S2 { fn t2(self) -> Self { self } } + +fn call_impl_with_generic_function() { + assert(3.t3(7) == 7); +} + +trait T3 { + fn t3(self, x: T) -> T; +} + +impl T3 for u32 { + fn t3(self, y: U) -> U { y } +}