From 44716fae0bae0f78ceee76f7231af49c4abeace1 Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 11 Oct 2023 11:02:25 -0500 Subject: [PATCH] feat: Implement impl specialization (#3087) --- .../src/hir/def_collector/dc_crate.rs | 15 ++- .../src/hir/def_map/item_scope.rs | 5 + .../src/hir/def_map/module_data.rs | 5 + .../noirc_frontend/src/hir/type_check/expr.rs | 3 +- compiler/noirc_frontend/src/hir_def/types.rs | 4 +- compiler/noirc_frontend/src/node_interner.rs | 92 ++++++++++++++----- .../specialization/Nargo.toml | 7 ++ .../specialization/src/main.nr | 15 +++ 8 files changed, 113 insertions(+), 33 deletions(-) create mode 100644 tooling/nargo_cli/tests/compile_success_empty/specialization/Nargo.toml create mode 100644 tooling/nargo_cli/tests/compile_success_empty/specialization/src/main.nr diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 38fb88f742c..86cdd936b0d 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -394,15 +394,12 @@ fn collect_impls( let module = &mut def_maps.get_mut(&crate_id).unwrap().modules[type_module.0]; for (_, method_id, method) in &unresolved.functions { - let result = module.declare_function(method.name_ident().clone(), *method_id); - - if let Err((first_def, second_def)) = result { - let error = DefCollectorErrorKind::Duplicate { - typ: DuplicateType::Function, - first_def, - second_def, - }; - errors.push((error.into(), unresolved.file_id)); + // If this method was already declared, remove it from the module so it cannot + // be accessed with the `TypeName::method` syntax. We'll check later whether the + // object types in each method overlap or not. If they do, we issue an error. + // If not, that is specialization which is allowed. + if module.declare_function(method.name_ident().clone(), *method_id).is_err() { + module.remove_function(method.name_ident()); } } // Prohibit defining impls for primitive types if we're not in the stdlib diff --git a/compiler/noirc_frontend/src/hir/def_map/item_scope.rs b/compiler/noirc_frontend/src/hir/def_map/item_scope.rs index 7dcc5051a0c..78375ab2bc0 100644 --- a/compiler/noirc_frontend/src/hir/def_map/item_scope.rs +++ b/compiler/noirc_frontend/src/hir/def_map/item_scope.rs @@ -85,4 +85,9 @@ impl ItemScope { pub fn values(&self) -> &HashMap { &self.values } + + pub fn remove_definition(&mut self, name: &Ident) { + self.types.remove(name); + self.values.remove(name); + } } diff --git a/compiler/noirc_frontend/src/hir/def_map/module_data.rs b/compiler/noirc_frontend/src/hir/def_map/module_data.rs index a1ed5587306..5528312c0fc 100644 --- a/compiler/noirc_frontend/src/hir/def_map/module_data.rs +++ b/compiler/noirc_frontend/src/hir/def_map/module_data.rs @@ -53,6 +53,11 @@ impl ModuleData { self.declare(name, id.into()) } + pub fn remove_function(&mut self, name: &Ident) { + self.scope.remove_definition(name); + self.definitions.remove_definition(name); + } + pub fn declare_global(&mut self, name: Ident, id: StmtId) -> Result<(), (Ident, Ident)> { self.declare(name, id.into()) } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 3b61f2b2b6b..b7e9fe953fd 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -826,7 +826,8 @@ impl<'interner> TypeChecker<'interner> { ) -> Option { match object_type { Type::Struct(typ, _args) => { - match self.interner.lookup_method(typ.borrow().id, method_name) { + let id = typ.borrow().id; + match self.interner.lookup_method(object_type, id, method_name, false) { Some(method_id) => Some(HirMethodReference::FuncId(method_id)), None => { self.errors.push(TypeCheckError::UnresolvedMethodCall { diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index e6c9d7bee9a..ef321ee2f71 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -782,7 +782,7 @@ impl Type { /// `try_unify` is a bit of a misnomer since although errors are not committed, /// any unified bindings are on success. - fn try_unify(&self, other: &Type) -> Result<(), UnificationError> { + pub fn try_unify(&self, other: &Type) -> Result<(), UnificationError> { use Type::*; use TypeVariableKind as Kind; @@ -995,7 +995,7 @@ impl Type { /// Instantiate this type, replacing any type variables it is quantified /// over with fresh type variables. If this type is not a Type::Forall, /// it is unchanged. - pub fn instantiate(&self, interner: &mut NodeInterner) -> (Type, TypeBindings) { + pub fn instantiate(&self, interner: &NodeInterner) -> (Type, TypeBindings) { match self { Type::Forall(typevars, typ) => { let replacements = typevars diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 67009746c4d..4f26f212afa 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -87,9 +87,6 @@ pub struct NodeInterner { // Each trait definition is possibly shared across multiple type nodes. // It is also mutated through the RefCell during name resolution to append // methods from impls to the type. - // - // TODO: We may be able to remove the Shared wrapper once traits are no longer types. - // We'd just lookup their methods as needed through the NodeInterner. traits: HashMap, // Trait implementation map @@ -108,10 +105,15 @@ pub struct NodeInterner { globals: HashMap, // NOTE: currently only used for checking repeat globals and restricting their scope to a module - next_type_variable_id: usize, + next_type_variable_id: std::cell::Cell, /// A map from a struct type and method name to a function id for the method. - struct_methods: HashMap<(StructId, String), FuncId>, + /// This can resolve to potentially multiple methods if the same method name is + /// specialized for different generics on the same type. E.g. for `Struct`, we + /// may have both `impl Struct { fn foo(){} }` and `impl Struct { fn foo(){} }`. + /// If this happens, the returned Vec will have 2 entries and we'll need to further + /// disambiguate them by checking the type of each function. + struct_methods: HashMap<(StructId, String), Vec>, /// Methods on primitive types defined in the stdlib. primitive_methods: HashMap<(TypeMethodKey, String), FuncId>, @@ -381,7 +383,7 @@ impl Default for NodeInterner { trait_implementations: HashMap::new(), instantiation_bindings: HashMap::new(), field_indices: HashMap::new(), - next_type_variable_id: 0, + next_type_variable_id: std::cell::Cell::new(0), globals: HashMap::new(), struct_methods: HashMap::new(), primitive_methods: HashMap::new(), @@ -829,13 +831,13 @@ impl NodeInterner { *old = Node::Expression(new); } - pub fn next_type_variable_id(&mut self) -> TypeVariableId { - let id = self.next_type_variable_id; - self.next_type_variable_id += 1; + pub fn next_type_variable_id(&self) -> TypeVariableId { + let id = self.next_type_variable_id.get(); + self.next_type_variable_id.set(id + 1); TypeVariableId(id) } - pub fn next_type_variable(&mut self) -> Type { + pub fn next_type_variable(&self) -> Type { Type::type_variable(self.next_type_variable_id()) } @@ -863,9 +865,10 @@ impl NodeInterner { self.function_definition_ids[&function] } - /// Add a method to a type. - /// This will panic for non-struct types currently as we do not support methods - /// for primitives. We could allow this in the future however. + /// Adds a non-trait method to a type. + /// + /// Returns `Some(duplicate)` if a matching method was already defined. + /// Returns `None` otherwise. pub fn add_method( &mut self, self_type: &Type, @@ -874,8 +877,15 @@ impl NodeInterner { ) -> Option { match self_type { Type::Struct(struct_type, _generics) => { - let key = (struct_type.borrow().id, method_name); - self.struct_methods.insert(key, method_id) + let id = struct_type.borrow().id; + + if let Some(existing) = self.lookup_method(self_type, id, &method_name, true) { + return Some(existing); + } + + let key = (id, method_name); + self.struct_methods.entry(key).or_default().push(method_id); + None } Type::Error => None, @@ -899,11 +909,10 @@ impl NodeInterner { ) -> bool { self.trait_implementations.insert(key.clone(), trait_impl.clone()); match &key.typ { - Type::Struct(struct_type, _generics) => { + Type::Struct(..) => { for func_id in &trait_impl.borrow().methods { let method_name = self.function_name(func_id).to_owned(); - let key = (struct_type.borrow().id, method_name); - self.struct_methods.insert(key, *func_id); + self.add_method(&key.typ, method_name, *func_id); } true } @@ -938,9 +947,50 @@ impl NodeInterner { } } - /// Search by name for a method on the given struct - pub fn lookup_method(&self, id: StructId, method_name: &str) -> Option { - self.struct_methods.get(&(id, method_name.to_owned())).copied() + /// Search by name for a method on the given struct. + /// + /// If `check_type` is true, this will force `lookup_method` to check the type + /// of each candidate instead of returning only the first candidate if there is exactly one. + /// This is generally only desired when declaring new methods to check if they overlap any + /// existing methods. + /// + /// Another detail is that this method does not handle auto-dereferencing through `&mut T`. + /// So if an object is of type `self : &mut T` but a method only accepts `self: T` (or + /// vice-versa), the call will not be selected. If this is ever implemented into this method, + /// we can remove the `methods.len() == 1` check and the `check_type` early return. + pub fn lookup_method( + &self, + typ: &Type, + id: StructId, + method_name: &str, + check_type: bool, + ) -> Option { + let methods = self.struct_methods.get(&(id, method_name.to_owned()))?; + + // If there is only one method, just return it immediately. + // It will still be typechecked later. + if !check_type && methods.len() == 1 { + return Some(methods[0]); + } + + // When adding methods we always check they do not overlap, so there should be + // at most 1 matching method in this list. + for method in methods { + match self.function_meta(method).typ.instantiate(self).0 { + Type::Function(args, _, _) => { + if let Some(object) = args.get(0) { + // TODO #3089: This is dangerous! try_unify may commit type bindings even on failure + if object.try_unify(typ).is_ok() { + return Some(*method); + } + } + } + Type::Error => (), + other => unreachable!("Expected function type, found {other}"), + } + } + + None } /// Looks up a given method name on the given primitive type. diff --git a/tooling/nargo_cli/tests/compile_success_empty/specialization/Nargo.toml b/tooling/nargo_cli/tests/compile_success_empty/specialization/Nargo.toml new file mode 100644 index 00000000000..df379491dc9 --- /dev/null +++ b/tooling/nargo_cli/tests/compile_success_empty/specialization/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "specialization" +type = "bin" +authors = [""] +compiler_version = "0.16.0" + +[dependencies] \ No newline at end of file diff --git a/tooling/nargo_cli/tests/compile_success_empty/specialization/src/main.nr b/tooling/nargo_cli/tests/compile_success_empty/specialization/src/main.nr new file mode 100644 index 00000000000..32102ad1437 --- /dev/null +++ b/tooling/nargo_cli/tests/compile_success_empty/specialization/src/main.nr @@ -0,0 +1,15 @@ +struct Foo {} + +impl Foo { + fn foo(_self: Self) -> Field { 1 } +} + +impl Foo { + fn foo(_self: Self) -> Field { 2 } +} + +fn main() { + let f1: Foo = Foo{}; + let f2: Foo = Foo{}; + assert(f1.foo() + f2.foo() == 3); +}