Skip to content

Commit

Permalink
Fix a bug importing declarations in generics (#4752)
Browse files Browse the repository at this point in the history
Imported declarations that are contained within a generic, such as an
`impl forall...` would be given an abstract symbolic constant value
instead of a concrete generic value in some cases where the function was
used in an api file and impl file of the same library. This caused #4679
to fail using `ImplicitAs.Convert` transitively imported from the
prelude.

---------

Co-authored-by: Josh L <[email protected]>
Co-authored-by: Richard Smith <[email protected]>
  • Loading branch information
3 people authored Jan 2, 2025
1 parent c5fd8f4 commit 725e80f
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 154 deletions.
55 changes: 35 additions & 20 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,27 @@ class ImportRefResolver;
struct ResolveResult {
// The new constant value, if known.
SemIR::ConstantId const_id;
// Newly created declaration whose value is being resolved, if any.
SemIR::InstId decl_id = SemIR::InstId::Invalid;
// Whether resolution has been attempted once and needs to be retried.
bool retry = false;

// Produces a resolve result that tries resolving this instruction again. If
// `const_id` is specified, then this is the end of the second phase, and the
// constant value will be passed to the next resolution attempt. Otherwise,
// this is the end of the first phase.
static auto Retry(SemIR::ConstantId const_id = SemIR::ConstantId::Invalid)
static auto Retry(SemIR::ConstantId const_id = SemIR::ConstantId::Invalid,
SemIR::InstId decl_id = SemIR::InstId::Invalid)
-> ResolveResult {
return {.const_id = const_id, .retry = true};
return {.const_id = const_id, .decl_id = decl_id, .retry = true};
}

// Produces a resolve result that provides the given constant value. Requires
// that there is no new work.
static auto Done(SemIR::ConstantId const_id) -> ResolveResult {
return {.const_id = const_id};
static auto Done(SemIR::ConstantId const_id,
SemIR::InstId decl_id = SemIR::InstId::Invalid)
-> ResolveResult {
return {.const_id = const_id, .decl_id = decl_id};
}
};
} // namespace
Expand Down Expand Up @@ -455,7 +460,7 @@ class ImportRefResolver : public ImportContext {

// Step 2: resolve the instruction.
initial_work_ = work_stack_.size();
auto [new_const_id, retry] =
auto [new_const_id, _, retry] =
TryResolveInst(*this, work.inst_id, existing.const_id);
CARBON_CHECK(!HasNewWork() || retry);

Expand Down Expand Up @@ -1309,7 +1314,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
resolver.local_context().MakeImportedLocAndInst<SemIR::AdaptDecl>(
AddImportIRInst(resolver, import_inst_id),
{.adapted_type_inst_id = adapted_type_inst_id}));
return ResolveResult::Done(resolver.local_constant_values().Get(inst_id));
return ResolveResult::Done(resolver.local_constant_values().Get(inst_id),
inst_id);
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -1389,7 +1395,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
resolver.local_context().GetTypeIdForTypeConstant(type_const_id),
.base_type_inst_id = base_type_inst_id,
.index = inst.index}));
return ResolveResult::Done(resolver.local_constant_values().Get(inst_id));
return ResolveResult::Done(resolver.local_constant_values().Get(inst_id),
inst_id);
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -1602,12 +1609,12 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
auto adapt_id = import_class.adapt_id.is_valid()
? GetLocalConstantInstId(resolver, import_class.adapt_id)
: SemIR::InstId::Invalid;
auto& new_class = resolver.local_classes().Get(class_id);

if (resolver.HasNewWork()) {
return ResolveResult::Retry(class_const_id);
return ResolveResult::Retry(class_const_id, new_class.first_decl_id());
}

auto& new_class = resolver.local_classes().Get(class_id);
new_class.parent_scope_id = parent_scope_id;
new_class.implicit_param_patterns_id = GetLocalParamPatternsId(
resolver, import_class.implicit_param_patterns_id);
Expand All @@ -1628,7 +1635,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
complete_type_witness_id, base_id, adapt_id);
}

return ResolveResult::Done(class_const_id);
return ResolveResult::Done(class_const_id, new_class.first_decl_id());
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -1713,7 +1720,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
resolver.local_context().GetTypeIdForTypeConstant(const_id),
.name_id = GetLocalNameId(resolver, inst.name_id),
.index = inst.index}));
return {.const_id = resolver.local_constant_values().Get(inst_id)};
return ResolveResult::Done(resolver.local_constant_values().Get(inst_id),
inst_id);
}

// Make a declaration of a function. This is done as a separate step from
Expand Down Expand Up @@ -1794,13 +1802,14 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
import_function.implicit_param_patterns_id);
LoadLocalPatternConstantIds(resolver, import_function.param_patterns_id);
auto generic_data = GetLocalGenericData(resolver, import_function.generic_id);
auto& new_function = resolver.local_functions().Get(function_id);

if (resolver.HasNewWork()) {
return ResolveResult::Retry(function_const_id);
return ResolveResult::Retry(function_const_id,
new_function.first_decl_id());
}

// Add the function declaration.
auto& new_function = resolver.local_functions().Get(function_id);
new_function.parent_scope_id = parent_scope_id;
new_function.implicit_param_patterns_id = GetLocalParamPatternsId(
resolver, import_function.implicit_param_patterns_id);
Expand All @@ -1815,7 +1824,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
new_function.definition_id = new_function.first_owning_decl_id;
}

return ResolveResult::Done(function_const_id);
return ResolveResult::Done(function_const_id, new_function.first_decl_id());
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -1963,12 +1972,12 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
auto constraint_const_id = GetLocalConstantId(
resolver,
resolver.import_constant_values().Get(import_impl.constraint_id));
auto& new_impl = resolver.local_impls().Get(impl_id);

if (resolver.HasNewWork()) {
return ResolveResult::Retry(impl_const_id);
return ResolveResult::Retry(impl_const_id, new_impl.first_decl_id());
}

auto& new_impl = resolver.local_impls().Get(impl_id);
new_impl.parent_scope_id = parent_scope_id;
new_impl.implicit_param_patterns_id =
GetLocalParamPatternsId(resolver, import_impl.implicit_param_patterns_id);
Expand All @@ -1995,7 +2004,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
resolver.local_impls().GetOrAddLookupBucket(new_impl).push_back(impl_id);
}

return ResolveResult::Done(impl_const_id);
return ResolveResult::Done(impl_const_id, new_impl.first_decl_id());
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -2135,12 +2144,13 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
self_param_id =
GetLocalConstantInstId(resolver, import_interface.self_param_id);
}
auto& new_interface = resolver.local_interfaces().Get(interface_id);

if (resolver.HasNewWork()) {
return ResolveResult::Retry(interface_const_id);
return ResolveResult::Retry(interface_const_id,
new_interface.first_decl_id());
}

auto& new_interface = resolver.local_interfaces().Get(interface_id);
new_interface.parent_scope_id = parent_scope_id;
new_interface.implicit_param_patterns_id = GetLocalParamPatternsId(
resolver, import_interface.implicit_param_patterns_id);
Expand All @@ -2154,7 +2164,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
AddInterfaceDefinition(resolver, import_interface, new_interface,
*self_param_id);
}
return ResolveResult::Done(interface_const_id);
return ResolveResult::Done(interface_const_id, new_interface.first_decl_id());
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -2747,6 +2757,11 @@ static auto TryResolveInst(ImportRefResolver& resolver, SemIR::InstId inst_id,
resolver.local_constant_values().GetInstId(result.const_id),
.generic_id = GetLocalGenericId(resolver, generic_const_id),
.index = symbolic_const.index});
if (result.decl_id.is_valid()) {
// Overwrite the abstract symbolic constant given initially to the
// declaration with its final concrete symbolic value.
resolver.local_constant_values().Set(result.decl_id, result.const_id);
}
}
} else {
// Third phase: perform a consistency check and produce the constant we
Expand Down
34 changes: 16 additions & 18 deletions toolchain/check/testdata/class/generic/import.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,6 @@ class Class(U:! type) {
// CHECK:STDOUT: --- foo.impl.carbon
// CHECK:STDOUT:
// CHECK:STDOUT: constants {
// CHECK:STDOUT: %N: Core.IntLiteral = bind_symbolic_name N, 0 [symbolic]
// CHECK:STDOUT: %Int: type = class_type @Int, @Int(%N) [symbolic]
// CHECK:STDOUT: %Convert.type.2: type = fn_type @Convert.2, @impl.1(%N) [symbolic]
// CHECK:STDOUT: %Convert.2: %Convert.type.2 = struct_value () [symbolic]
// CHECK:STDOUT: %int_32: Core.IntLiteral = int_value 32 [template]
// CHECK:STDOUT: %i32: type = class_type @Int, @Int(%int_32) [template]
// CHECK:STDOUT: %T: type = bind_symbolic_name T, 0 [symbolic]
Expand All @@ -271,14 +267,16 @@ class Class(U:! type) {
// CHECK:STDOUT: %CompleteClass.elem.2: type = unbound_element_type %CompleteClass.2, %i32 [template]
// CHECK:STDOUT: %F.type.3: type = fn_type @F.1, @CompleteClass(%i32) [template]
// CHECK:STDOUT: %F.3: %F.type.3 = struct_value () [template]
// CHECK:STDOUT: %int_1: Core.IntLiteral = int_value 1 [template]
// CHECK:STDOUT: %int_1.1: Core.IntLiteral = int_value 1 [template]
// CHECK:STDOUT: %struct_type.n.2: type = struct_type {.n: Core.IntLiteral} [template]
// CHECK:STDOUT: %Convert.type.9: type = fn_type @Convert.1, @ImplicitAs(%i32) [template]
// CHECK:STDOUT: %impl_witness.19: <witness> = impl_witness (imports.%import_ref.14), @impl.1(%int_32) [template]
// CHECK:STDOUT: %Convert.bound: <bound method> = bound_method %int_1, %Convert.2 [symbolic]
// CHECK:STDOUT: %Convert.specific_fn: <specific function> = specific_function %Convert.bound, @Convert.2(%N) [symbolic]
// CHECK:STDOUT: %int.convert_checked: init %Int = call %Convert.specific_fn(%int_1) [symbolic]
// CHECK:STDOUT: %CompleteClass.val: %CompleteClass.2 = struct_value (%int.convert_checked) [symbolic]
// CHECK:STDOUT: %Convert.type.10: type = fn_type @Convert.2, @impl.1(%int_32) [template]
// CHECK:STDOUT: %Convert.10: %Convert.type.10 = struct_value () [template]
// CHECK:STDOUT: %Convert.bound: <bound method> = bound_method %int_1.1, %Convert.10 [template]
// CHECK:STDOUT: %Convert.specific_fn: <specific function> = specific_function %Convert.bound, @Convert.2(%int_32) [template]
// CHECK:STDOUT: %int_1.2: %i32 = int_value 1 [template]
// CHECK:STDOUT: %CompleteClass.val: %CompleteClass.2 = struct_value (%int_1.2) [template]
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: imports {
Expand Down Expand Up @@ -374,17 +372,17 @@ class Class(U:! type) {
// CHECK:STDOUT:
// CHECK:STDOUT: fn @F.2() -> %return.param_patt: %CompleteClass.2 [from "foo.carbon"] {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %int_1: Core.IntLiteral = int_value 1 [template = constants.%int_1]
// CHECK:STDOUT: %int_1: Core.IntLiteral = int_value 1 [template = constants.%int_1.1]
// CHECK:STDOUT: %.loc9_17.1: %struct_type.n.2 = struct_literal (%int_1)
// CHECK:STDOUT: %impl.elem0: %Convert.type.9 = impl_witness_access constants.%impl_witness.19, element0 [symbolic = constants.%Convert.2]
// CHECK:STDOUT: %Convert.bound: <bound method> = bound_method %int_1, %impl.elem0 [symbolic = constants.%Convert.bound]
// CHECK:STDOUT: %Convert.specific_fn: <specific function> = specific_function %Convert.bound, @Convert.2(constants.%N) [symbolic = constants.%Convert.specific_fn]
// CHECK:STDOUT: %int.convert_checked: init %Int = call %Convert.specific_fn(%int_1) [symbolic = constants.%int.convert_checked]
// CHECK:STDOUT: %.loc9_17.2: init %i32 = converted %int_1, %int.convert_checked [symbolic = constants.%int.convert_checked]
// CHECK:STDOUT: %impl.elem0: %Convert.type.9 = impl_witness_access constants.%impl_witness.19, element0 [template = constants.%Convert.10]
// CHECK:STDOUT: %Convert.bound: <bound method> = bound_method %int_1, %impl.elem0 [template = constants.%Convert.bound]
// CHECK:STDOUT: %Convert.specific_fn: <specific function> = specific_function %Convert.bound, @Convert.2(constants.%int_32) [template = constants.%Convert.specific_fn]
// CHECK:STDOUT: %int.convert_checked: init %i32 = call %Convert.specific_fn(%int_1) [template = constants.%int_1.2]
// CHECK:STDOUT: %.loc9_17.2: init %i32 = converted %int_1, %int.convert_checked [template = constants.%int_1.2]
// CHECK:STDOUT: %.loc9_17.3: ref %i32 = class_element_access %return, element0
// CHECK:STDOUT: %.loc9_17.4: init %i32 = initialize_from %.loc9_17.2 to %.loc9_17.3 [symbolic = constants.%int.convert_checked]
// CHECK:STDOUT: %.loc9_17.5: init %CompleteClass.2 = class_init (%.loc9_17.4), %return [symbolic = constants.%CompleteClass.val]
// CHECK:STDOUT: %.loc9_18: init %CompleteClass.2 = converted %.loc9_17.1, %.loc9_17.5 [symbolic = constants.%CompleteClass.val]
// CHECK:STDOUT: %.loc9_17.4: init %i32 = initialize_from %.loc9_17.2 to %.loc9_17.3 [template = constants.%int_1.2]
// CHECK:STDOUT: %.loc9_17.5: init %CompleteClass.2 = class_init (%.loc9_17.4), %return [template = constants.%CompleteClass.val]
// CHECK:STDOUT: %.loc9_18: init %CompleteClass.2 = converted %.loc9_17.1, %.loc9_17.5 [template = constants.%CompleteClass.val]
// CHECK:STDOUT: return %.loc9_18 to %return
// CHECK:STDOUT: }
// CHECK:STDOUT:
Expand Down
Loading

0 comments on commit 725e80f

Please sign in to comment.