Skip to content

Commit

Permalink
Support specializtion for RPITITs
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Feb 23, 2023
1 parent b42a811 commit 7ceaf52
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 114 deletions.
28 changes: 21 additions & 7 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub(super) fn compare_impl_method<'tcx>(
compare_generic_param_kinds(tcx, impl_m, trait_m, false)?;
compare_number_of_method_arguments(tcx, impl_m, trait_m)?;
compare_synthetic_generics(tcx, impl_m, trait_m)?;
compare_asyncness(tcx, impl_m, trait_m)?;
compare_asyncness(tcx, impl_m, trait_m, false)?;
compare_method_predicate_entailment(
tcx,
impl_m,
Expand Down Expand Up @@ -191,6 +191,11 @@ fn compare_method_predicate_entailment<'tcx>(
.map(|(predicate, _)| predicate),
);

// Additionally, we are allowed to assume that we can project RPITITs to their
// associated hidden types within method signatures. This is to allow us to support
// specialization with `impl Trait` in traits.
hybrid_preds.predicates.extend(tcx.additional_method_assumptions(impl_m_def_id));

// Construct trait parameter environment and then shift it into the placeholder viewpoint.
// The key step here is to update the caller_bounds's predicates to be
// the new hybrid bounds we computed.
Expand Down Expand Up @@ -526,6 +531,7 @@ fn compare_asyncness<'tcx>(
tcx: TyCtxt<'tcx>,
impl_m: ty::AssocItem,
trait_m: ty::AssocItem,
delay: bool,
) -> Result<(), ErrorGuaranteed> {
if tcx.asyncness(trait_m.def_id) == hir::IsAsync::Async {
match tcx.fn_sig(impl_m.def_id).skip_binder().skip_binder().output().kind() {
Expand All @@ -536,11 +542,14 @@ fn compare_asyncness<'tcx>(
// We don't know if it's ok, but at least it's already an error.
}
_ => {
return Err(tcx.sess.emit_err(crate::errors::AsyncTraitImplShouldBeAsync {
span: tcx.def_span(impl_m.def_id),
method_name: trait_m.name,
trait_item_span: tcx.hir().span_if_local(trait_m.def_id),
}));
return Err(tcx
.sess
.create_err(crate::errors::AsyncTraitImplShouldBeAsync {
span: tcx.def_span(impl_m.def_id),
method_name: trait_m.name,
trait_item_span: tcx.hir().span_if_local(trait_m.def_id),
})
.emit_unless(delay));
}
};
}
Expand Down Expand Up @@ -590,10 +599,15 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
let trait_m = tcx.opt_associated_item(impl_m.trait_item_def_id.unwrap()).unwrap();
let impl_trait_ref =
tcx.impl_trait_ref(impl_m.impl_container(tcx).unwrap()).unwrap().subst_identity();
let param_env = tcx.param_env(def_id);

// We use the RPITIT values computed in this method to construct the param-env,
// so to avoid cycles, we do computations in this function without assuming anything
// about RPITIT projection.
let param_env = tcx.param_env_no_assumptions(def_id);

// First, check a few of the same things as `compare_impl_method`,
// just so we don't ICE during substitution later.
compare_asyncness(tcx, impl_m, trait_m, true)?;
compare_number_of_generics(tcx, impl_m, trait_m, true)?;
compare_generic_param_kinds(tcx, impl_m, trait_m, true)?;
check_region_bounds_on_impl_item(tcx, impl_m, trait_m, true)?;
Expand Down
30 changes: 1 addition & 29 deletions compiler/rustc_metadata/src/rmeta/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1101,34 +1101,6 @@ fn should_encode_const(def_kind: DefKind) -> bool {
}
}

fn should_encode_trait_impl_trait_tys(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
if tcx.def_kind(def_id) != DefKind::AssocFn {
return false;
}

let Some(item) = tcx.opt_associated_item(def_id) else { return false; };
if item.container != ty::AssocItemContainer::ImplContainer {
return false;
}

let Some(trait_item_def_id) = item.trait_item_def_id else { return false; };

// FIXME(RPITIT): This does a somewhat manual walk through the signature
// of the trait fn to look for any RPITITs, but that's kinda doing a lot
// of work. We can probably remove this when we refactor RPITITs to be
// associated types.
tcx.fn_sig(trait_item_def_id).subst_identity().skip_binder().output().walk().any(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Alias(ty::Projection, data) = ty.kind()
&& tcx.def_kind(data.def_id) == DefKind::ImplTraitPlaceholder
{
true
} else {
false
}
})
}

// Return `false` to avoid encoding impl trait in trait, while we don't use the query.
fn should_encode_fn_impl_trait_in_trait<'tcx>(_tcx: TyCtxt<'tcx>, _def_id: DefId) -> bool {
false
Expand Down Expand Up @@ -1211,7 +1183,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
if let DefKind::Enum | DefKind::Struct | DefKind::Union = def_kind {
self.encode_info_for_adt(def_id);
}
if should_encode_trait_impl_trait_tys(tcx, def_id)
if tcx.impl_method_has_trait_impl_trait_tys(def_id)
&& let Ok(table) = self.tcx.collect_return_position_impl_trait_in_trait_tys(def_id)
{
record!(self.tables.trait_impl_trait_tys[def_id] <- table);
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,14 @@ rustc_queries! {
desc { |tcx| "computing normalized predicates of `{}`", tcx.def_path_str(def_id) }
}

query param_env_no_assumptions(def_id: DefId) -> ty::ParamEnv<'tcx> {
desc { |tcx| "computing normalized predicates of `{}`", tcx.def_path_str(def_id) }
}

query additional_method_assumptions(def_id: DefId) -> &'tcx ty::List<ty::Predicate<'tcx>> {
desc { |tcx| "computing additional predicate assumptions for the body of `{}`", tcx.def_path_str(def_id) }
}

/// Like `param_env`, but returns the `ParamEnv` in `Reveal::All` mode.
/// Prefer this over `tcx.param_env(def_id).with_reveal_all_normalized(tcx)`,
/// as this method is more efficient.
Expand Down
28 changes: 28 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2544,6 +2544,34 @@ impl<'tcx> TyCtxt<'tcx> {
}
def_id
}

pub fn impl_method_has_trait_impl_trait_tys(self, def_id: DefId) -> bool {
if self.def_kind(def_id) != DefKind::AssocFn {
return false;
}

let Some(item) = self.opt_associated_item(def_id) else { return false; };
if item.container != ty::AssocItemContainer::ImplContainer {
return false;
}

let Some(trait_item_def_id) = item.trait_item_def_id else { return false; };

// FIXME(RPITIT): This does a somewhat manual walk through the signature
// of the trait fn to look for any RPITITs, but that's kinda doing a lot
// of work. We can probably remove this when we refactor RPITITs to be
// associated types.
self.fn_sig(trait_item_def_id).subst_identity().skip_binder().output().walk().any(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Alias(ty::Projection, data) = ty.kind()
&& self.def_kind(data.def_id) == DefKind::ImplTraitPlaceholder
{
true
} else {
false
}
})
}
}

/// Yields the parent function's `LocalDefId` if `def_id` is an `impl Trait` definition.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ impl<'tcx> InternalSubsts<'tcx> {
target_substs: SubstsRef<'tcx>,
) -> SubstsRef<'tcx> {
let defs = tcx.generics_of(source_ancestor);
tcx.mk_substs(target_substs.iter().chain(self.iter().skip(defs.params.len())))
tcx.mk_substs(target_substs.iter().chain(self.iter().skip(defs.count())))
}

pub fn truncate_to(&self, tcx: TyCtxt<'tcx>, generics: &ty::Generics) -> SubstsRef<'tcx> {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1246,13 +1246,13 @@ fn project<'cx, 'tcx>(

let mut candidates = ProjectionCandidateSet::None;

assemble_candidate_for_impl_trait_in_trait(selcx, obligation, &mut candidates);

// Make sure that the following procedures are kept in order. ParamEnv
// needs to be first because it has highest priority, and Select checks
// the return value of push_candidate which assumes it's ran at last.
assemble_candidates_from_param_env(selcx, obligation, &mut candidates);

assemble_candidate_for_impl_trait_in_trait(selcx, obligation, &mut candidates);

assemble_candidates_from_trait_def(selcx, obligation, &mut candidates);

assemble_candidates_from_object_ty(selcx, obligation, &mut candidates);
Expand Down
89 changes: 73 additions & 16 deletions compiler/rustc_ty_utils/src/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn adt_sized_constraint(tcx: TyCtxt<'_>, def_id: DefId) -> &[Ty<'_>] {
}

/// See `ParamEnv` struct definition for details.
fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
fn param_env(tcx: TyCtxt<'_>, def_id: DefId, add_assumptions: bool) -> ty::ParamEnv<'_> {
// Compute the bounds on Self and the type parameters.
let ty::InstantiatedPredicates { mut predicates, .. } =
tcx.predicates_of(def_id).instantiate_identity(tcx);
Expand All @@ -138,17 +138,8 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
predicates.extend(environment);
}

if tcx.def_kind(def_id) == DefKind::AssocFn
&& tcx.associated_item(def_id).container == ty::AssocItemContainer::TraitContainer
{
let sig = tcx.fn_sig(def_id).subst_identity();
sig.visit_with(&mut ImplTraitInTraitFinder {
tcx,
fn_def_id: def_id,
bound_vars: sig.bound_vars(),
predicates: &mut predicates,
seen: FxHashSet::default(),
});
if add_assumptions && tcx.def_kind(def_id) == DefKind::AssocFn {
predicates.extend(tcx.additional_method_assumptions(def_id))
}

let local_did = def_id.as_local();
Expand Down Expand Up @@ -237,19 +228,83 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
traits::normalize_param_env_or_error(tcx, unnormalized_env, cause)
}

fn additional_method_assumptions<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: DefId,
) -> &'tcx ty::List<Predicate<'tcx>> {
let assoc_item = tcx.associated_item(def_id);
let mut predicates = vec![];

match assoc_item.container {
ty::AssocItemContainer::TraitContainer => {
let sig = tcx.fn_sig(def_id).subst_identity();
sig.visit_with(&mut ImplTraitInTraitFinder {
tcx,
fn_def_id: def_id,
bound_vars: sig.bound_vars(),
predicates: &mut predicates,
seen: FxHashSet::default(),
hidden_ty: |alias_ty| tcx.mk_alias(ty::Opaque, alias_ty),
});
}
ty::AssocItemContainer::ImplContainer => {
if tcx.impl_method_has_trait_impl_trait_tys(def_id)
&& let Ok(table)
= tcx.collect_return_position_impl_trait_in_trait_tys(def_id)
{
let impl_def_id = assoc_item.container_id(tcx);
let trait_to_impl_substs =
tcx.impl_trait_ref(impl_def_id).unwrap().subst_identity().substs;
// Create mapping from impl to placeholder.
let impl_to_placeholder_substs = ty::InternalSubsts::identity_for_item(tcx, def_id);
// Create mapping from trait to placeholder.
let trait_to_placeholder_substs =
impl_to_placeholder_substs.rebase_onto(tcx, impl_def_id, trait_to_impl_substs);

let trait_fn_def_id = assoc_item.trait_item_def_id.unwrap();
let trait_fn_sig =
tcx.fn_sig(trait_fn_def_id).subst(tcx, trait_to_placeholder_substs);
trait_fn_sig.visit_with(&mut ImplTraitInTraitFinder {
tcx,
fn_def_id: trait_fn_def_id,
bound_vars: trait_fn_sig.bound_vars(),
predicates: &mut predicates,
seen: FxHashSet::default(),
hidden_ty: |alias_ty| {
EarlyBinder(*table.get(&alias_ty.def_id).unwrap()).subst(
tcx,
alias_ty.substs.rebase_onto(
tcx,
trait_fn_def_id,
impl_to_placeholder_substs,
),
)
},
});
}
}
}

tcx.intern_predicates(&predicates)
}

/// Walk through a function type, gathering all RPITITs and installing a
/// `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))` predicate into the
/// predicates list. This allows us to observe that an RPITIT projects to
/// its corresponding opaque within the body of a default-body trait method.
struct ImplTraitInTraitFinder<'a, 'tcx> {
struct ImplTraitInTraitFinder<'a, 'tcx, F: Fn(ty::AliasTy<'tcx>) -> Ty<'tcx>> {
tcx: TyCtxt<'tcx>,
predicates: &'a mut Vec<Predicate<'tcx>>,
fn_def_id: DefId,
bound_vars: &'tcx ty::List<ty::BoundVariableKind>,
seen: FxHashSet<DefId>,
hidden_ty: F,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
impl<'tcx, F> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx, F>
where
F: Fn(ty::AliasTy<'tcx>) -> Ty<'tcx>,
{
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
if let ty::Alias(ty::Projection, alias_ty) = *ty.kind()
&& self.tcx.def_kind(alias_ty.def_id) == DefKind::ImplTraitPlaceholder
Expand All @@ -260,7 +315,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
ty::Binder::bind_with_vars(
ty::ProjectionPredicate {
projection_ty: alias_ty,
term: self.tcx.mk_alias(ty::Opaque, alias_ty).into(),
term: (self.hidden_ty)(alias_ty).into(),
},
self.bound_vars,
)
Expand Down Expand Up @@ -514,7 +569,9 @@ pub fn provide(providers: &mut ty::query::Providers) {
*providers = ty::query::Providers {
asyncness,
adt_sized_constraint,
param_env,
param_env: |tcx, def_id| param_env(tcx, def_id, true),
param_env_no_assumptions: |tcx, def_id| param_env(tcx, def_id, false),
additional_method_assumptions,
param_env_reveal_all_normalized,
instance_def_size_estimate,
issue33140_self_ty,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// edition: 2021
// known-bug: #108309
// check-pass

#![feature(async_fn_in_trait)]
//~^ WARN the feature `async_fn_in_trait` is incomplete
#![feature(min_specialization)]

struct MyStruct;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,5 @@ LL | #![feature(async_fn_in_trait)]
= note: see issue #91611 <https://github.com/rust-lang/rust/issues/91611> for more information
= note: `#[warn(incomplete_features)]` on by default

error[E0053]: method `foo` has an incompatible type for trait
--> $DIR/dont-project-to-specializable-projection.rs:14:35
|
LL | default async fn foo(_: T) -> &'static str {
| ^^^^^^^^^^^^ expected associated type, found future
|
note: type in trait
--> $DIR/dont-project-to-specializable-projection.rs:10:27
|
LL | async fn foo(_: T) -> &'static str;
| ^^^^^^^^^^^^
= note: expected signature `fn(_) -> impl Future<Output = &'static str>`
found signature `fn(_) -> impl Future<Output = &'static str>`

error: aborting due to previous error; 1 warning emitted
warning: 1 warning emitted

For more information about this error, try `rustc --explain E0053`.
4 changes: 2 additions & 2 deletions tests/ui/impl-trait/in-trait/method-signature-matches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ trait TooMuch {

impl TooMuch for () {
fn calm_down_please(_: (), _: (), _: ()) {}
//~^ ERROR method `calm_down_please` has 3 parameters but the declaration in trait `TooMuch::calm_down_please` has 0
//~^ ERROR method `calm_down_please` has an incompatible type for trait
}

trait TooLittle {
Expand All @@ -36,7 +36,7 @@ trait TooLittle {

impl TooLittle for () {
fn come_on_a_little_more_effort() {}
//~^ ERROR method `come_on_a_little_more_effort` has 0 parameters but the declaration in trait `TooLittle::come_on_a_little_more_effort` has 3
//~^ ERROR method `come_on_a_little_more_effort` has an incompatible type for trait
}

trait Lifetimes {
Expand Down
Loading

0 comments on commit 7ceaf52

Please sign in to comment.