Skip to content

Commit

Permalink
feat(traits): Improve support for traits static method resolution (#2958
Browse files Browse the repository at this point in the history
)
  • Loading branch information
alexvitkov authored Oct 3, 2023
1 parent 6bb337d commit 0d0d8f7
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 54 deletions.
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use std::vec;
pub struct UnresolvedFunctions {
pub file_id: FileId,
pub functions: Vec<(LocalModuleId, FuncId, NoirFunction)>,
pub trait_id: Option<TraitId>,
}

impl UnresolvedFunctions {
Expand Down Expand Up @@ -485,7 +486,9 @@ fn collect_trait_impl_methods(
errors.push((error.into(), trait_impl.file_id));
}
}

trait_impl.methods.functions = ordered_methods;
trait_impl.methods.trait_id = Some(trait_id);
errors
}

Expand Down Expand Up @@ -993,6 +996,12 @@ fn resolve_trait_impls(
errors,
);

if let Some(trait_id) = maybe_trait_id {
for (_, func) in &impl_methods {
interner.set_function_trait(*func, self_type.clone(), trait_id);
}
}

let mut new_resolver =
Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id);
new_resolver.set_self_type(Some(self_type.clone()));
Expand Down Expand Up @@ -1156,6 +1165,7 @@ fn resolve_function_set(
// TypeVariables for the same generic, causing it to instantiate incorrectly.
resolver.set_generics(impl_generics.clone());
resolver.set_self_type(self_type.clone());
resolver.set_trait_id(unresolved_functions.trait_id);

let (hir_func, func_meta, errs) = resolver.resolve_function(func, func_id);
interner.push_fn_meta(func_meta, func_id);
Expand Down
18 changes: 12 additions & 6 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@ impl<'a> ModCollector<'a> {
let module_id = ModuleId { krate, local_id: self.module_id };

for r#impl in impls {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
let mut unresolved_functions = UnresolvedFunctions {
file_id: self.file_id,
functions: Vec::new(),
trait_id: None,
};

for method in r#impl.methods {
let func_id = context.def_interner.push_empty_fn();
Expand Down Expand Up @@ -171,7 +174,7 @@ impl<'a> ModCollector<'a> {
krate: CrateId,
) -> UnresolvedFunctions {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None };

let module = ModuleId { krate, local_id: self.module_id };

Expand All @@ -193,7 +196,7 @@ impl<'a> ModCollector<'a> {
krate: CrateId,
) -> Vec<(CompilationError, FileId)> {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None };
let mut errors = vec![];

let module = ModuleId { krate, local_id: self.module_id };
Expand Down Expand Up @@ -351,8 +354,11 @@ impl<'a> ModCollector<'a> {
}

// Add all functions that have a default implementation in the trait
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
let mut unresolved_functions = UnresolvedFunctions {
file_id: self.file_id,
functions: Vec::new(),
trait_id: None,
};
for trait_item in &trait_definition.items {
// TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629
if let TraitItem::Function {
Expand Down
133 changes: 103 additions & 30 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ use crate::{
};
use crate::{
ArrayLiteral, ContractFunctionType, Distinctness, Generics, LValue, NoirStruct, NoirTypeAlias,
Path, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable, UnaryOp,
UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
Path, PathKind, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable,
UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
use fm::FileId;
Expand Down Expand Up @@ -78,6 +78,8 @@ pub struct Resolver<'a> {
scopes: ScopeForest,
path_resolver: &'a dyn PathResolver,
def_maps: &'a BTreeMap<CrateId, CrateDefMap>,
trait_id: Option<TraitId>,
trait_bounds: Vec<UnresolvedTraitConstraint>,
pub interner: &'a mut NodeInterner,
errors: Vec<ResolverError>,
file: FileId,
Expand Down Expand Up @@ -120,6 +122,8 @@ impl<'a> Resolver<'a> {
Self {
path_resolver,
def_maps,
trait_id: None,
trait_bounds: Vec::new(),
scopes: ScopeForest::default(),
interner,
self_type: None,
Expand All @@ -134,6 +138,10 @@ impl<'a> Resolver<'a> {
self.self_type = self_type;
}

pub fn set_trait_id(&mut self, trait_id: Option<TraitId>) {
self.trait_id = trait_id;
}

pub fn get_self_type(&mut self) -> Option<&Type> {
self.self_type.as_ref()
}
Expand All @@ -158,12 +166,14 @@ impl<'a> Resolver<'a> {
self.resolve_local_globals();

self.add_generics(&func.def.generics);
self.trait_bounds = func.def.where_clause.clone();

let (hir_func, func_meta) = self.intern_function(func, func_id);
let func_scope_tree = self.scopes.end_function();

self.check_for_unused_variables_in_scope_tree(func_scope_tree);

self.trait_bounds.clear();
(hir_func, func_meta, self.errors)
}

Expand Down Expand Up @@ -1075,39 +1085,43 @@ impl<'a> Resolver<'a> {
Literal::Unit => HirLiteral::Unit,
}),
ExpressionKind::Variable(path) => {
// If the Path is being used as an Expression, then it is referring to a global from a separate module
// Otherwise, then it is referring to an Identifier
// This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10;
// If the expression is a singular indent, we search the resolver's current scope as normal.
let (hir_ident, var_scope_index) = self.get_ident_from_path(path);

if hir_ident.id != DefinitionId::dummy_id() {
match self.interner.definition(hir_ident.id).kind {
DefinitionKind::Function(id) => {
if self.interner.function_visibility(id) == Visibility::Private {
let span = hir_ident.location.span;
self.check_can_reference_private_function(id, span);
if let Some(expr) = self.resolve_trait_generic_path(&path) {
expr
} else {
// If the Path is being used as an Expression, then it is referring to a global from a separate module
// Otherwise, then it is referring to an Identifier
// This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10;
// If the expression is a singular indent, we search the resolver's current scope as normal.
let (hir_ident, var_scope_index) = self.get_ident_from_path(path);

if hir_ident.id != DefinitionId::dummy_id() {
match self.interner.definition(hir_ident.id).kind {
DefinitionKind::Function(id) => {
if self.interner.function_visibility(id) == Visibility::Private {
let span = hir_ident.location.span;
self.check_can_reference_private_function(id, span);
}
}
}
DefinitionKind::Global(_) => {}
DefinitionKind::GenericType(_) => {
// Initialize numeric generics to a polymorphic integer type in case
// they're used in expressions. We must do this here since the type
// checker does not check definition kinds and otherwise expects
// parameters to already be typed.
if self.interner.id_type(hir_ident.id) == Type::Error {
let typ = Type::polymorphic_integer(self.interner);
self.interner.push_definition_type(hir_ident.id, typ);
DefinitionKind::Global(_) => {}
DefinitionKind::GenericType(_) => {
// Initialize numeric generics to a polymorphic integer type in case
// they're used in expressions. We must do this here since the type
// checker does not check definition kinds and otherwise expects
// parameters to already be typed.
if self.interner.id_type(hir_ident.id) == Type::Error {
let typ = Type::polymorphic_integer(self.interner);
self.interner.push_definition_type(hir_ident.id, typ);
}
}
DefinitionKind::Local(_) => {
// only local variables can be captured by closures.
self.resolve_local_variable(hir_ident, var_scope_index);
}
}
DefinitionKind::Local(_) => {
// only local variables can be captured by closures.
self.resolve_local_variable(hir_ident, var_scope_index);
}
}
}

HirExpression::Ident(hir_ident)
HirExpression::Ident(hir_ident)
}
}
ExpressionKind::Prefix(prefix) => {
let operator = prefix.operator;
Expand Down Expand Up @@ -1445,6 +1459,65 @@ impl<'a> Resolver<'a> {
self.lookup(path).ok().map(|id| self.interner.get_type_alias(id))
}

// this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type)
fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option<HirExpression> {
if let Some(trait_id) = self.trait_id {
if path.kind == PathKind::Plain && path.segments.len() == 2 {
let name = &path.segments[0].0.contents;
let method = &path.segments[1];

if name == SELF_TYPE_NAME {
let the_trait = self.interner.get_trait(trait_id);

if let Some(method) = the_trait.find_method(method.clone()) {
let self_type = Type::TypeVariable(
the_trait.self_type_typevar,
crate::TypeVariableKind::Normal,
);
return Some(HirExpression::TraitMethodReference(self_type, method));
}
}
}
}
None
}

// this resolves a static trait method T::trait_method by iterating over the where clause
fn resolve_trait_method_by_named_generic(&mut self, path: &Path) -> Option<HirExpression> {
if path.segments.len() != 2 {
return None;
}

for UnresolvedTraitConstraint { typ, trait_bound } in self.trait_bounds.clone() {
if let UnresolvedTypeData::Named(constraint_path, _) = &typ.typ {
// if `path` is `T::method_name`, we're looking for constraint of the form `T: SomeTrait`
if constraint_path.segments.len() == 1
&& path.segments[0] != constraint_path.last_segment()
{
continue;
}

if let Ok(ModuleDefId::TraitId(trait_id)) =
self.path_resolver.resolve(self.def_maps, trait_bound.trait_path.clone())
{
let the_trait = self.interner.get_trait(trait_id);
if let Some(method) =
the_trait.find_method(path.segments.last().unwrap().clone())
{
let self_type = self.resolve_type(typ.clone());
return Some(HirExpression::TraitMethodReference(self_type, method));
}
}
}
}
None
}

fn resolve_trait_generic_path(&mut self, path: &Path) -> Option<HirExpression> {
self.resolve_trait_static_method_by_self(path)
.or_else(|| self.resolve_trait_method_by_named_generic(path))
}

fn resolve_path(&mut self, path: Path) -> Result<ModuleDefId, ResolverError> {
self.path_resolver.resolve(self.def_maps, path).map_err(ResolverError::PathResolutionError)
}
Expand Down
23 changes: 19 additions & 4 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl<'interner> TypeChecker<'interner> {
}

let (function_id, function_call) = method_call.into_function_call(
method_ref,
method_ref.clone(),
location,
self.interner,
);
Expand Down Expand Up @@ -291,7 +291,19 @@ impl<'interner> TypeChecker<'interner> {

Type::Function(params, Box::new(lambda.return_type), Box::new(env_type))
}
HirExpression::TraitMethodReference(_) => unreachable!("unexpected TraitMethodReference - they should be added after initial type checking"),
HirExpression::TraitMethodReference(_, method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let method = &the_trait.methods[method.method_index];

let typ = Type::Function(
method.arguments.clone(),
Box::new(method.return_type.clone()),
Box::new(Type::Unit),
);
let (typ, bindings) = typ.instantiate(self.interner);
self.interner.store_instantiation_bindings(*expr_id, bindings);
typ
}
};

self.interner.push_expr_type(expr_id, typ.clone());
Expand Down Expand Up @@ -498,7 +510,7 @@ impl<'interner> TypeChecker<'interner> {

(func_meta.typ, param_len)
}
HirMethodReference::TraitMethodId(method) => {
HirMethodReference::TraitMethodId(_, method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let method = &the_trait.methods[method.method_index];

Expand Down Expand Up @@ -863,7 +875,10 @@ impl<'interner> TypeChecker<'interner> {
if method.name.0.contents == method_name {
let trait_method =
TraitMethodId { trait_id: constraint.trait_id, method_index };
return Some(HirMethodReference::TraitMethodId(trait_method));
return Some(HirMethodReference::TraitMethodId(
object_type.clone(),
trait_method,
));
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub enum HirExpression {
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
TraitMethodReference(TraitMethodId),
TraitMethodReference(Type, TraitMethodId),
Error,
}

Expand Down Expand Up @@ -151,7 +151,7 @@ pub struct HirMethodCallExpression {
pub location: Location,
}

#[derive(Debug, Copy, Clone)]
#[derive(Debug, Clone)]
pub enum HirMethodReference {
/// A method can be defined in a regular `impl` block, in which case
/// it's syntax sugar for a normal function call, and can be
Expand All @@ -161,7 +161,7 @@ pub enum HirMethodReference {
/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphizaiton.
TraitMethodId(TraitMethodId),
TraitMethodId(Type, TraitMethodId),
}

impl HirMethodCallExpression {
Expand All @@ -179,8 +179,8 @@ impl HirMethodCallExpression {
let id = interner.function_definition_id(func_id);
HirExpression::Ident(HirIdent { location, id })
}
HirMethodReference::TraitMethodId(method_id) => {
HirExpression::TraitMethodReference(method_id)
HirMethodReference::TraitMethodId(typ, method_id) => {
HirExpression::TraitMethodReference(typ, method_id)
}
};
let func = interner.push_expr(expr);
Expand Down
11 changes: 10 additions & 1 deletion compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
graph::CrateId,
node_interner::{FuncId, TraitId},
node_interner::{FuncId, TraitId, TraitMethodId},
Generics, Ident, NoirFunction, Type, TypeVariable, TypeVariableId,
};
use noirc_errors::Span;
Expand Down Expand Up @@ -111,6 +111,15 @@ impl Trait {
pub fn set_methods(&mut self, methods: Vec<TraitFunction>) {
self.methods = methods;
}

pub fn find_method(&self, name: Ident) -> Option<TraitMethodId> {
for (idx, method) in self.methods.iter().enumerate() {
if method.name == name {
return Some(TraitMethodId { trait_id: self.id, method_index: idx });
}
}
None
}
}

impl std::fmt::Display for Trait {
Expand Down
Loading

0 comments on commit 0d0d8f7

Please sign in to comment.