Skip to content

Commit

Permalink
fix(experimental elaborator): Only call add_generics once (#5091)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves an issue introduced by
#5082 where `add_generics` would
be called multiple times on a given item. Once when declaring the
FuncMeta beforehand, and again when defining it afterward. This is an
issue because `add_generics` will give a fresh set of generics each
time. This meant the generics in a function's signature weren't the same
type variables as the ones used internally, which led to type errors.

## Summary\*



## Additional Context

Down to ~~400~~ 226 errors in the stdlib down from 1050

Edit: Fixed a case where structs never popped their generics out of
scope

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Maxim Vezenov <[email protected]>
  • Loading branch information
jfecher and vezenovm authored May 28, 2024
1 parent ffcb410 commit f5d2946
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 103 deletions.
195 changes: 106 additions & 89 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ impl<'context> Elaborator<'context> {
}

this.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls);

this.collect_traits(items.traits);

// Must resolve structs before we resolve globals.
Expand All @@ -224,7 +223,7 @@ impl<'context> Elaborator<'context> {
//
// These are resolved after trait impls so that struct methods are chosen
// over trait methods if there are name conflicts.
for ((typ, module), impls) in &items.impls {
for ((typ, module), impls) in &mut items.impls {
this.collect_impls(typ, *module, impls);
}

Expand All @@ -238,8 +237,8 @@ impl<'context> Elaborator<'context> {
this.elaborate_functions(functions);
}

for ((typ, module), impls) in items.impls {
this.elaborate_impls(typ, module, impls);
for impls in items.impls.into_values() {
this.elaborate_impls(impls);
}

for trait_impl in items.trait_impls {
Expand All @@ -251,15 +250,27 @@ impl<'context> Elaborator<'context> {
this.errors
}

/// Runs `f` and if it modifies `self.generics`, `self.generics` is truncated
/// back to the previous length.
fn recover_generics<T>(&mut self, f: impl FnOnce(&mut Self) -> T) -> T {
let generics_count = self.generics.len();
let ret = f(self);
self.generics.truncate(generics_count);
ret
}

fn elaborate_functions(&mut self, functions: UnresolvedFunctions) {
self.file = functions.file_id;
self.trait_id = functions.trait_id; // TODO: Resolve?
self.self_type = functions.self_type;

for (local_module, id, func) in functions.functions {
self.local_module = local_module;
let generics_count = self.generics.len();
self.elaborate_function(func, id);
self.generics.truncate(generics_count);
self.recover_generics(|this| this.elaborate_function(func, id));
}

self.self_type = None;
self.trait_id = None;
}

fn elaborate_function(&mut self, mut function: NoirFunction, id: FuncId) {
Expand All @@ -277,12 +288,6 @@ impl<'context> Elaborator<'context> {
self.resolve_local_globals();
self.trait_bounds = function.def.where_clause.clone();

let is_low_level_or_oracle = function
.attributes()
.function
.as_ref()
.map_or(false, |func| func.is_low_level() || func.is_oracle());

if function.def.is_unconstrained {
self.in_unconstrained_fn = true;
}
Expand All @@ -297,7 +302,7 @@ impl<'context> Elaborator<'context> {
self.elaborate_pattern(parameter.pattern.clone(), param2.1.clone(), definition_kind);
}

self.add_generics(&function.def.generics);
self.generics = func_meta.all_generics.clone();
self.desugar_impl_trait_args(&mut function, id);
self.declare_numeric_generics(&func_meta.parameters, func_meta.return_type());
self.add_trait_constraints_to_scope(&func_meta);
Expand All @@ -315,7 +320,8 @@ impl<'context> Elaborator<'context> {
}
};

if !func_meta.can_ignore_return_type() {
// Don't verify the return type for builtin functions & trait function declarations
if !func_meta.is_stub() {
self.type_check_function_body(body_type, &func_meta, hir_func.as_expr());
}

Expand Down Expand Up @@ -356,7 +362,7 @@ impl<'context> Elaborator<'context> {
let func_scope_tree = self.scopes.end_function();

// The arguments to low-level and oracle functions are always unused so we do not produce warnings for them.
if !is_low_level_or_oracle {
if !func_meta.is_stub() {
self.check_for_unused_variables_in_scope_tree(func_scope_tree);
}

Expand Down Expand Up @@ -525,7 +531,12 @@ impl<'context> Elaborator<'context> {
/// to be used in analysis and intern the function parameters
/// Prerequisite: any implicit generics, including any generics from the impl,
/// have already been added to scope via `self.add_generics`.
fn define_function_meta(&mut self, func: &mut NoirFunction, func_id: FuncId) {
fn define_function_meta(
&mut self,
func: &mut NoirFunction,
func_id: FuncId,
is_trait_function: bool,
) {
self.current_function = Some(func_id);
self.resolve_where_clause(&mut func.def.where_clause);

Expand Down Expand Up @@ -652,13 +663,15 @@ impl<'context> Elaborator<'context> {
location,
typ,
direct_generics,
all_generics: self.generics.clone(),
trait_impl: self.current_trait_impl,
parameters: parameters.into(),
return_type: func.def.return_type.clone(),
return_visibility: func.def.return_visibility,
has_body: !func.def.body.is_empty(),
trait_constraints: self.resolve_trait_constraints(&func.def.where_clause),
is_entry_point,
is_trait_function,
has_inline_attribute,
};

Expand Down Expand Up @@ -840,42 +853,10 @@ impl<'context> Elaborator<'context> {
}
}

fn elaborate_impls(
&mut self,
typ: UnresolvedType,
module: LocalModuleId,
impls: Vec<(Vec<Ident>, Span, UnresolvedFunctions)>,
) {
self.local_module = module;

for (generics, _, functions) in impls {
fn elaborate_impls(&mut self, impls: Vec<(Vec<Ident>, Span, UnresolvedFunctions)>) {
for (_, _, functions) in impls {
self.file = functions.file_id;
let old_generics_length = self.generics.len();
self.add_generics(&generics);
let self_type = self.resolve_type(typ.clone());
self.self_type = Some(self_type.clone());

let function_ids = vecmap(&functions.functions, |(_, id, _)| *id);
self.elaborate_functions(functions);

if self_type != Type::Error {
for method_id in function_ids {
let method_name = self.interner.function_name(&method_id).to_owned();

if let Some(first_fn) =
self.interner.add_method(&self_type, method_name.clone(), method_id, false)
{
let error = ResolverError::DuplicateDefinition {
name: method_name,
first_span: self.interner.function_ident(&first_fn).span(),
second_span: self.interner.function_ident(&method_id).span(),
};
self.push_err(error);
}
}
}

self.generics.truncate(old_generics_length);
self.recover_generics(|this| this.elaborate_functions(functions));
}
}

Expand All @@ -886,7 +867,7 @@ impl<'context> Elaborator<'context> {
let unresolved_type = trait_impl.object_type;
let self_type_span = unresolved_type.span;
let old_generics_length = self.generics.len();
self.add_generics(&trait_impl.generics);
self.generics = trait_impl.resolved_generics;

let trait_generics =
vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone()));
Expand All @@ -895,7 +876,6 @@ impl<'context> Elaborator<'context> {
let impl_id =
trait_impl.impl_id.expect("An impls' id should be set during define_function_metas");

self.self_type = Some(self_type.clone());
self.current_trait_impl = trait_impl.impl_id;

let methods = trait_impl.methods.function_ids();
Expand Down Expand Up @@ -960,13 +940,16 @@ impl<'context> Elaborator<'context> {
&mut self,
self_type: &UnresolvedType,
module: LocalModuleId,
impls: &[(Vec<Ident>, Span, UnresolvedFunctions)],
impls: &mut [(Vec<Ident>, Span, UnresolvedFunctions)],
) {
self.local_module = module;

for (generics, span, unresolved) in impls {
self.file = unresolved.file_id;
self.declare_method_on_struct(self_type, generics, false, unresolved, *span);
self.recover_generics(|this| {
this.add_generics(generics);
this.declare_methods_on_struct(self_type, false, unresolved, *span);
});
}
}

Expand All @@ -980,8 +963,12 @@ impl<'context> Elaborator<'context> {

let span = trait_impl.object_type.span.expect("All trait self types should have spans");
let object_type = &trait_impl.object_type;
let generics = &trait_impl.generics;
self.declare_method_on_struct(object_type, generics, true, &trait_impl.methods, span);

self.recover_generics(|this| {
this.add_generics(&trait_impl.generics);
trait_impl.resolved_generics = this.generics.clone();
this.declare_methods_on_struct(object_type, true, &mut trait_impl.methods, span);
});
}
}

Expand All @@ -990,33 +977,33 @@ impl<'context> Elaborator<'context> {
&mut self.def_maps.get_mut(&module.krate).expect(message).modules[module.local_id.0]
}

fn declare_method_on_struct(
fn declare_methods_on_struct(
&mut self,
self_type: &UnresolvedType,
generics: &UnresolvedGenerics,
is_trait_impl: bool,
functions: &UnresolvedFunctions,
functions: &mut UnresolvedFunctions,
span: Span,
) {
let generic_count = self.generics.len();
self.add_generics(generics);
let typ = self.resolve_type(self_type.clone());
let self_type = self.resolve_type(self_type.clone());

functions.self_type = Some(self_type.clone());

let function_ids = functions.function_ids();

if let Type::Struct(struct_type, _generics) = typ {
let struct_type = struct_type.borrow();
if let Type::Struct(struct_type, _) = &self_type {
let struct_ref = struct_type.borrow();

// `impl`s are only allowed on types defined within the current crate
if !is_trait_impl && struct_type.id.krate() != self.crate_id {
let type_name = struct_type.name.to_string();
if !is_trait_impl && struct_ref.id.krate() != self.crate_id {
let type_name = struct_ref.name.to_string();
self.push_err(DefCollectorErrorKind::ForeignImpl { span, type_name });
self.generics.truncate(generic_count);
return;
}

// Grab the module defined by the struct type. Note that impls are a case
// where the module the methods are added to is not the same as the module
// they are resolved in.
let module = self.get_module_mut(struct_type.id.module_id());
let module = self.get_module_mut(struct_ref.id.module_id());

for (_, method_id, method) in &functions.functions {
// If this method was already declared, remove it from the module so it cannot
Expand All @@ -1028,11 +1015,33 @@ impl<'context> Elaborator<'context> {
module.remove_function(method.name_ident());
}
}
// Prohibit defining impls for primitive types if we're not in the stdlib
} else if !is_trait_impl && typ != Type::Error && !self.crate_id.is_stdlib() {
self.push_err(DefCollectorErrorKind::NonStructTypeInImpl { span });

self.declare_struct_methods(&self_type, &function_ids);
// We can define methods on primitive types only if we're in the stdlib
} else if !is_trait_impl && self_type != Type::Error {
if self.crate_id.is_stdlib() {
self.declare_struct_methods(&self_type, &function_ids);
} else {
self.push_err(DefCollectorErrorKind::NonStructTypeInImpl { span });
}
}
}

fn declare_struct_methods(&mut self, self_type: &Type, function_ids: &[FuncId]) {
for method_id in function_ids {
let method_name = self.interner.function_name(method_id).to_owned();

if let Some(first_fn) =
self.interner.add_method(self_type, method_name.clone(), *method_id, false)
{
let error = ResolverError::DuplicateDefinition {
name: method_name,
first_span: self.interner.function_ident(&first_fn).span(),
second_span: self.interner.function_ident(method_id).span(),
};
self.push_err(error);
}
}
self.generics.truncate(generic_count);
}

fn collect_trait_impl_methods(
Expand Down Expand Up @@ -1203,18 +1212,20 @@ impl<'context> Elaborator<'context> {
unresolved: NoirStruct,
struct_id: StructId,
) -> (Generics, Vec<(Ident, Type)>) {
let generics = self.add_generics(&unresolved.generics);
self.recover_generics(|this| {
let generics = this.add_generics(&unresolved.generics);

// Check whether the struct definition has globals in the local module and add them to the scope
self.resolve_local_globals();
// Check whether the struct definition has globals in the local module and add them to the scope
this.resolve_local_globals();

self.current_item = Some(DependencyId::Struct(struct_id));
this.current_item = Some(DependencyId::Struct(struct_id));

self.resolving_ids.insert(struct_id);
let fields = vecmap(unresolved.fields, |(ident, typ)| (ident, self.resolve_type(typ)));
self.resolving_ids.remove(&struct_id);
this.resolving_ids.insert(struct_id);
let fields = vecmap(unresolved.fields, |(ident, typ)| (ident, this.resolve_type(typ)));
this.resolving_ids.remove(&struct_id);

(generics, fields)
(generics, fields)
})
}

fn elaborate_global(&mut self, global: UnresolvedGlobal) {
Expand Down Expand Up @@ -1256,11 +1267,16 @@ impl<'context> Elaborator<'context> {
self.define_function_metas_for_functions(function_set);
}

for ((_typ, local_module), function_sets) in impls {
for ((self_type, local_module), function_sets) in impls {
self.local_module = *local_module;

for (_generics, _, function_set) in function_sets {
for (generics, _, function_set) in function_sets {
self.add_generics(generics);
let self_type = self.resolve_type(self_type.clone());
function_set.self_type = Some(self_type.clone());
self.self_type = Some(self_type);
self.define_function_metas_for_functions(function_set);
self.generics.clear();
}
}

Expand All @@ -1269,11 +1285,12 @@ impl<'context> Elaborator<'context> {
self.local_module = trait_impl.module_id;

let unresolved_type = &trait_impl.object_type;
let old_generics_length = self.generics.len();
self.add_generics(&trait_impl.generics);

let self_type = self.resolve_type(unresolved_type.clone());

self.self_type = Some(self_type.clone());
trait_impl.methods.self_type = Some(self_type);

let impl_id = self.interner.next_trait_impl_id();
self.current_trait_impl = Some(impl_id);
Expand All @@ -1282,7 +1299,7 @@ impl<'context> Elaborator<'context> {

trait_impl.resolved_object_type = self.self_type.take();
trait_impl.impl_id = self.current_trait_impl.take();
self.generics.truncate(old_generics_length);
self.generics.clear();
}
}

Expand All @@ -1291,9 +1308,9 @@ impl<'context> Elaborator<'context> {

for (local_module, id, func) in &mut function_set.functions {
self.local_module = *local_module;
let old_generics_length = self.generics.len();
self.define_function_meta(func, *id);
self.generics.truncate(old_generics_length);
self.recover_generics(|this| {
this.define_function_meta(func, *id, false);
});
}
}
}
Loading

0 comments on commit f5d2946

Please sign in to comment.