Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: check trait where clause #6325

Merged
merged 3 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,11 +1027,14 @@ impl<'context> Elaborator<'context> {
self.file = trait_impl.file_id;
self.local_module = trait_impl.module_id;

self.check_parent_traits_are_implemented(&trait_impl);

self.generics = trait_impl.resolved_generics;
self.generics = trait_impl.resolved_generics.clone();
self.current_trait_impl = trait_impl.impl_id;

self.add_trait_impl_assumed_trait_implementations(trait_impl.impl_id);
self.check_trait_impl_where_clause_matches_trait_where_clause(&trait_impl);
self.check_parent_traits_are_implemented(&trait_impl);
self.remove_trait_impl_assumed_trait_implementations(trait_impl.impl_id);

for (module, function, _) in &trait_impl.methods.functions {
self.local_module = *module;
let errors = check_trait_impl_method_matches_declaration(self.interner, *function);
Expand All @@ -1045,6 +1048,95 @@ impl<'context> Elaborator<'context> {
self.generics.clear();
}

fn add_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option<TraitImplId>) {
if let Some(impl_id) = impl_id {
if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id)
{
for trait_constrain in &trait_implementation.borrow().where_clause {
let trait_bound = &trait_constrain.trait_bound;
self.interner.add_assumed_trait_implementation(
trait_constrain.typ.clone(),
trait_bound.trait_id,
trait_bound.trait_generics.clone(),
);
}
}
}
}

fn remove_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option<TraitImplId>) {
if let Some(impl_id) = impl_id {
if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id)
{
for trait_constrain in &trait_implementation.borrow().where_clause {
self.interner.remove_assumed_trait_implementations_for_trait(
trait_constrain.trait_bound.trait_id,
);
}
}
}
}

fn check_trait_impl_where_clause_matches_trait_where_clause(
&mut self,
trait_impl: &UnresolvedTraitImpl,
) {
let Some(trait_id) = trait_impl.trait_id else {
return;
};

let Some(the_trait) = self.interner.try_get_trait(trait_id) else {
return;
};

if the_trait.where_clause.is_empty() {
return;
}

let impl_trait = the_trait.name.to_string();
let the_trait_file = the_trait.location.file;

let mut bindings = TypeBindings::new();
bind_ordered_generics(
&the_trait.generics,
&trait_impl.resolved_trait_generics,
&mut bindings,
);

// Check that each of the trait's where clause constraints is satisfied
for trait_constraint in the_trait.where_clause.clone() {
let Some(trait_constraint_trait) =
self.interner.try_get_trait(trait_constraint.trait_bound.trait_id)
else {
continue;
};

let trait_constraint_type = trait_constraint.typ.substitute(&bindings);
let trait_bound = &trait_constraint.trait_bound;

if self
.interner
.try_lookup_trait_implementation(
&trait_constraint_type,
trait_bound.trait_id,
&trait_bound.trait_generics.ordered,
&trait_bound.trait_generics.named,
)
.is_err()
{
let missing_trait =
format!("{}{}", trait_constraint_trait.name, trait_bound.trait_generics);
self.push_err(ResolverError::TraitNotImplemented {
impl_trait: impl_trait.clone(),
missing_trait,
type_missing_trait: trait_constraint_type.to_string(),
span: trait_impl.object_type.span,
missing_trait_location: Location::new(trait_bound.span, the_trait_file),
});
}
}
}

fn check_parent_traits_are_implemented(&mut self, trait_impl: &UnresolvedTraitImpl) {
let Some(trait_id) = trait_impl.trait_id else {
return;
Expand Down Expand Up @@ -1168,7 +1260,7 @@ impl<'context> Elaborator<'context> {
trait_id,
trait_generics,
file: trait_impl.file_id,
where_clause: where_clause.clone(),
where_clause,
methods,
});

Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ impl<'context> Elaborator<'context> {
&resolved_generics,
);

let where_clause =
this.resolve_trait_constraints(&unresolved_trait.trait_def.where_clause);

// Each associated type in this trait is also an implicit generic
for associated_type in &this.interner.get_trait(*trait_id).associated_types {
this.generics.push(associated_type.clone());
Expand All @@ -48,6 +51,7 @@ impl<'context> Elaborator<'context> {
this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_methods(methods);
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});
});

Expand Down
6 changes: 6 additions & 0 deletions compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pub struct Trait {

/// The resolved trait bounds (for example in `trait Foo: Bar + Baz`, this would be `Bar + Baz`)
pub trait_bounds: Vec<ResolvedTraitBound>,

pub where_clause: Vec<TraitConstraint>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -154,6 +156,10 @@ impl Trait {
self.trait_bounds = trait_bounds;
}

pub fn set_where_clause(&mut self, where_clause: Vec<TraitConstraint>) {
self.where_clause = where_clause;
}

pub fn find_method(&self, name: &str) -> Option<TraitMethodId> {
for (idx, method) in self.methods.iter().enumerate() {
if &method.name == name {
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@
interned_statement_kinds: noirc_arena::Arena<StatementKind>,

// Interned `UnresolvedTypeData`s during comptime code.
interned_unresolved_type_datas: noirc_arena::Arena<UnresolvedTypeData>,

Check warning on line 222 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)

// Interned `Pattern`s during comptime code.
interned_patterns: noirc_arena::Arena<Pattern>,

/// Determins whether to run in LSP mode. In LSP mode references are tracked.

Check warning on line 227 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Determins)
pub(crate) lsp_mode: bool,

/// Store the location of the references in the graph.
Expand Down Expand Up @@ -669,7 +669,7 @@
quoted_types: Default::default(),
interned_expression_kinds: Default::default(),
interned_statement_kinds: Default::default(),
interned_unresolved_type_datas: Default::default(),

Check warning on line 672 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)
interned_patterns: Default::default(),
lsp_mode: false,
location_indices: LocationIndices::default(),
Expand Down Expand Up @@ -735,6 +735,7 @@
method_ids: unresolved_trait.method_ids.clone(),
associated_types,
trait_bounds: Vec::new(),
where_clause: Vec::new(),
};

self.traits.insert(type_id, new_trait);
Expand Down Expand Up @@ -2179,11 +2180,11 @@
&mut self,
typ: UnresolvedTypeData,
) -> InternedUnresolvedTypeData {
InternedUnresolvedTypeData(self.interned_unresolved_type_datas.insert(typ))

Check warning on line 2183 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)
}

pub fn get_unresolved_type_data(&self, id: InternedUnresolvedTypeData) -> &UnresolvedTypeData {
&self.interned_unresolved_type_datas[id.0]

Check warning on line 2187 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)
}

/// Returns the type of an operator (which is always a function), along with its return type.
Expand Down
15 changes: 9 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1974,7 +1974,7 @@
}

// TODO(https://github.com/noir-lang/noir/issues/6238):
// The EvaluatedGlobalIsntU32 warning is a stopgap

Check warning on line 1977 in compiler/noirc_frontend/src/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Isnt)
// (originally from https://github.com/noir-lang/noir/issues/6125)
#[test]
fn numeric_generic_field_larger_than_u32() {
Expand All @@ -1991,7 +1991,7 @@
assert_eq!(errors.len(), 2);
assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::EvaluatedGlobalIsntU32 { .. }),

Check warning on line 1994 in compiler/noirc_frontend/src/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Isnt)
));
assert!(matches!(
errors[1].0,
Expand All @@ -2000,7 +2000,7 @@
}

// TODO(https://github.com/noir-lang/noir/issues/6238):
// The EvaluatedGlobalIsntU32 warning is a stopgap

Check warning on line 2003 in compiler/noirc_frontend/src/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Isnt)
// (originally from https://github.com/noir-lang/noir/issues/6126)
#[test]
fn numeric_generic_field_arithmetic_larger_than_u32() {
Expand Down Expand Up @@ -2969,9 +2969,7 @@
}
}

struct Bar {

}
struct Bar {}

impl Foo for Bar {

Expand All @@ -2983,12 +2981,17 @@
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
assert_eq!(errors.len(), 2);

let CompilationError::ResolverError(ResolverError::TraitNotImplemented { .. }) = &errors[0].0
else {
panic!("Expected a trait not implemented error, got {:?}", errors[0].0);
};

let CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { method_name, .. }) =
&errors[0].0
&errors[1].0
else {
panic!("Expected an unresolved method call error, got {:?}", errors[0].0);
panic!("Expected an unresolved method call error, got {:?}", errors[1].0);
};

assert_eq!(method_name, "trait_func");
Expand Down
113 changes: 113 additions & 0 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,116 @@ fn trait_inheritance_missing_parent_implementation() {
assert_eq!(typ, "Struct");
assert_eq!(impl_trait, "Bar");
}

#[test]
fn errors_on_unknown_type_in_trait_where_clause() {
let src = r#"
pub trait Foo<T> where T: Unknown {}

fn main() {
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
}

#[test]
fn does_not_error_if_impl_trait_constraint_is_satisfied_for_concrete_type() {
let src = r#"
pub trait Greeter {
fn greet(self);
}

pub trait Foo<T>
where
T: Greeter,
{
fn greet<U>(object: U)
where
U: Greeter,
{
object.greet();
}
}

pub struct SomeGreeter;
impl Greeter for SomeGreeter {
fn greet(self) {}
}

pub struct Bar;

impl Foo<SomeGreeter> for Bar {}

fn main() {}
"#;
assert_no_errors(src);
}

#[test]
fn does_not_error_if_impl_trait_constraint_is_satisfied_for_type_variable() {
let src = r#"
pub trait Greeter {
fn greet(self);
}

pub trait Foo<T> where T: Greeter {
fn greet(object: T) {
object.greet();
}
}

pub struct Bar;

impl<T> Foo<T> for Bar where T: Greeter {
}

fn main() {
}
"#;
assert_no_errors(src);
}
#[test]
fn errors_if_impl_trait_constraint_is_not_satisfied() {
let src = r#"
pub trait Greeter {
fn greet(self);
}

pub trait Foo<T>
where
T: Greeter,
{
fn greet<U>(object: U)
where
U: Greeter,
{
object.greet();
}
}

pub struct SomeGreeter;

pub struct Bar;

impl Foo<SomeGreeter> for Bar {}

fn main() {}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

let CompilationError::ResolverError(ResolverError::TraitNotImplemented {
impl_trait,
missing_trait: the_trait,
type_missing_trait: typ,
..
}) = &errors[0].0
else {
panic!("Expected a TraitNotImplemented error, got {:?}", &errors[0].0);
};

assert_eq!(the_trait, "Greeter");
assert_eq!(typ, "SomeGreeter");
assert_eq!(impl_trait, "Foo");
}
Loading