diff --git a/compiler/rustc_hir_analysis/src/collect/item_bounds.rs b/compiler/rustc_hir_analysis/src/collect/item_bounds.rs index c64741625a4ec..6bd251a198daf 100644 --- a/compiler/rustc_hir_analysis/src/collect/item_bounds.rs +++ b/compiler/rustc_hir_analysis/src/collect/item_bounds.rs @@ -1,8 +1,9 @@ -use rustc_data_structures::fx::FxIndexSet; +use rustc_data_structures::fx::{FxIndexMap, FxIndexSet}; use rustc_hir as hir; use rustc_infer::traits::util; +use rustc_middle::ty::fold::shift_vars; use rustc_middle::ty::{ - self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, + self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, }; use rustc_middle::{bug, span_bug}; use rustc_span::Span; @@ -42,14 +43,18 @@ fn associated_type_bounds<'tcx>( let trait_def_id = tcx.local_parent(assoc_item_def_id); let trait_predicates = tcx.trait_explicit_predicates_and_bounds(trait_def_id); - let bounds_from_parent = trait_predicates.predicates.iter().copied().filter(|(pred, _)| { - match pred.kind().skip_binder() { - ty::ClauseKind::Trait(tr) => tr.self_ty() == item_ty, - ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty() == item_ty, - ty::ClauseKind::TypeOutlives(outlives) => outlives.0 == item_ty, - _ => false, - } - }); + let item_trait_ref = ty::TraitRef::identity(tcx, tcx.parent(assoc_item_def_id.to_def_id())); + let bounds_from_parent = + trait_predicates.predicates.iter().copied().filter_map(|(clause, span)| { + remap_gat_vars_and_recurse_into_nested_projections( + tcx, + filter, + item_trait_ref, + assoc_item_def_id, + span, + clause, + ) + }); let all_bounds = tcx.arena.alloc_from_iter(bounds.clauses(tcx).chain(bounds_from_parent)); debug!( @@ -63,6 +68,228 @@ fn associated_type_bounds<'tcx>( all_bounds } +/// The code below is quite involved, so let me explain. +/// +/// We loop here, because we also want to collect vars for nested associated items as +/// well. For example, given a clause like `Self::A::B`, we want to add that to the +/// item bounds for `A`, so that we may use that bound in the case that `Self::A::B` is +/// rigid. +/// +/// Secondly, regarding bound vars, when we see a where clause that mentions a GAT +/// like `for<'a, ...> Self::Assoc<'a, ...>: Bound<'b, ...>`, we want to turn that into +/// an item bound on the GAT, where all of the GAT args are substituted with the GAT's +/// param regions, and then keep all of the other late-bound vars in the bound around. +/// We need to "compress" the binder so that it doesn't mention any of those vars that +/// were mapped to params. +fn remap_gat_vars_and_recurse_into_nested_projections<'tcx>( + tcx: TyCtxt<'tcx>, + filter: PredicateFilter, + item_trait_ref: ty::TraitRef<'tcx>, + assoc_item_def_id: LocalDefId, + span: Span, + clause: ty::Clause<'tcx>, +) -> Option<(ty::Clause<'tcx>, Span)> { + let mut clause_ty = match clause.kind().skip_binder() { + ty::ClauseKind::Trait(tr) => tr.self_ty(), + ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty(), + ty::ClauseKind::TypeOutlives(outlives) => outlives.0, + _ => return None, + }; + + let gat_vars = loop { + if let ty::Alias(ty::Projection, alias_ty) = *clause_ty.kind() { + if alias_ty.trait_ref(tcx) == item_trait_ref + && alias_ty.def_id == assoc_item_def_id.to_def_id() + { + // We have found the GAT in question... + // Return the vars, since we may need to remap them. + break &alias_ty.args[item_trait_ref.args.len()..]; + } else { + // Only collect *self* type bounds if the filter is for self. + match filter { + PredicateFilter::SelfOnly | PredicateFilter::SelfThatDefines(_) => { + return None; + } + PredicateFilter::All | PredicateFilter::SelfAndAssociatedTypeBounds => {} + } + + clause_ty = alias_ty.self_ty(); + continue; + } + } + + return None; + }; + + // Special-case: No GAT vars, no mapping needed. + if gat_vars.is_empty() { + return Some((clause, span)); + } + + // First, check that all of the GAT args are substituted with a unique late-bound arg. + // If we find a duplicate, then it can't be mapped to the definition's params. + let mut mapping = FxIndexMap::default(); + let generics = tcx.generics_of(assoc_item_def_id); + for (param, var) in std::iter::zip(&generics.own_params, gat_vars) { + let existing = match var.unpack() { + ty::GenericArgKind::Lifetime(re) => { + if let ty::RegionKind::ReBound(ty::INNERMOST, bv) = re.kind() { + mapping.insert(bv.var, tcx.mk_param_from_def(param)) + } else { + return None; + } + } + ty::GenericArgKind::Type(ty) => { + if let ty::Bound(ty::INNERMOST, bv) = *ty.kind() { + mapping.insert(bv.var, tcx.mk_param_from_def(param)) + } else { + return None; + } + } + ty::GenericArgKind::Const(ct) => { + if let ty::ConstKind::Bound(ty::INNERMOST, bv) = ct.kind() { + mapping.insert(bv, tcx.mk_param_from_def(param)) + } else { + return None; + } + } + }; + + if existing.is_some() { + return None; + } + } + + // Finally, map all of the args in the GAT to the params we expect, and compress + // the remaining late-bound vars so that they count up from var 0. + let mut folder = + MapAndCompressBoundVars { tcx, binder: ty::INNERMOST, still_bound_vars: vec![], mapping }; + let pred = clause.kind().skip_binder().fold_with(&mut folder); + + Some(( + ty::Binder::bind_with_vars(pred, tcx.mk_bound_variable_kinds(&folder.still_bound_vars)) + .upcast(tcx), + span, + )) +} + +/// Given some where clause like `for<'b, 'c> >::Gat<'b>: Bound<'c>`, +/// the mapping will map `'b` back to the GAT's `'b_identity`. Then we need to compress the +/// remaining bound var `'c` to index 0. +/// +/// This folder gives us: `for<'c> >::Gat<'b_identity>: Bound<'c>`, +/// which is sufficient for an item bound for `Gat`, since all of the GAT's args are identity. +struct MapAndCompressBoundVars<'tcx> { + tcx: TyCtxt<'tcx>, + /// How deep are we? Makes sure we don't touch the vars of nested binders. + binder: ty::DebruijnIndex, + /// List of bound vars that remain unsubstituted because they were not + /// mentioned in the GAT's args. + still_bound_vars: Vec, + /// Subtle invariant: If the `GenericArg` is bound, then it should be + /// stored with the debruijn index of `INNERMOST` so it can be shifted + /// correctly during substitution. + mapping: FxIndexMap>, +} + +impl<'tcx> TypeFolder> for MapAndCompressBoundVars<'tcx> { + fn cx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn fold_binder(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T> + where + ty::Binder<'tcx, T>: TypeSuperFoldable>, + { + self.binder.shift_in(1); + let out = t.super_fold_with(self); + self.binder.shift_out(1); + out + } + + fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> { + if !ty.has_bound_vars() { + return ty; + } + + if let ty::Bound(binder, old_bound) = *ty.kind() + && self.binder == binder + { + let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) { + mapped.expect_ty() + } else { + // If we didn't find a mapped generic, then make a new one. + // Allocate a new var idx, and insert a new bound ty. + let var = ty::BoundVar::from_usize(self.still_bound_vars.len()); + self.still_bound_vars.push(ty::BoundVariableKind::Ty(old_bound.kind)); + let mapped = Ty::new_bound( + self.tcx, + ty::INNERMOST, + ty::BoundTy { var, kind: old_bound.kind }, + ); + self.mapping.insert(old_bound.var, mapped.into()); + mapped + }; + + shift_vars(self.tcx, mapped, self.binder.as_u32()) + } else { + ty.super_fold_with(self) + } + } + + fn fold_region(&mut self, re: ty::Region<'tcx>) -> ty::Region<'tcx> { + if let ty::ReBound(binder, old_bound) = re.kind() + && self.binder == binder + { + let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) { + mapped.expect_region() + } else { + let var = ty::BoundVar::from_usize(self.still_bound_vars.len()); + self.still_bound_vars.push(ty::BoundVariableKind::Region(old_bound.kind)); + let mapped = ty::Region::new_bound( + self.tcx, + ty::INNERMOST, + ty::BoundRegion { var, kind: old_bound.kind }, + ); + self.mapping.insert(old_bound.var, mapped.into()); + mapped + }; + + shift_vars(self.tcx, mapped, self.binder.as_u32()) + } else { + re + } + } + + fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { + if !ct.has_bound_vars() { + return ct; + } + + if let ty::ConstKind::Bound(binder, old_var) = ct.kind() + && self.binder == binder + { + let mapped = if let Some(mapped) = self.mapping.get(&old_var) { + mapped.expect_const() + } else { + let var = ty::BoundVar::from_usize(self.still_bound_vars.len()); + self.still_bound_vars.push(ty::BoundVariableKind::Const); + let mapped = ty::Const::new_bound(self.tcx, ty::INNERMOST, var); + self.mapping.insert(old_var, mapped.into()); + mapped + }; + + shift_vars(self.tcx, mapped, self.binder.as_u32()) + } else { + ct.super_fold_with(self) + } + } + + fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { + if !p.has_bound_vars() { p } else { p.super_fold_with(self) } + } +} + /// Opaque types don't inherit bounds from their parent: for return position /// impl trait it isn't possible to write a suitable predicate on the /// containing function and for type-alias impl trait we don't have a backwards diff --git a/tests/ui/associated-type-bounds/nested-associated-type-bound-incompleteness.rs b/tests/ui/associated-type-bounds/nested-associated-type-bound-incompleteness.rs new file mode 100644 index 0000000000000..eb616631d1dc8 --- /dev/null +++ b/tests/ui/associated-type-bounds/nested-associated-type-bound-incompleteness.rs @@ -0,0 +1,28 @@ +// Demonstrates a mostly-theoretical inference guidance now that we turn the where +// clause on `Trait` into an item bound, given that we prefer item bounds somewhat +// greedily in trait selection. + +trait Bound {} +impl Bound for U {} + +trait Trait +where + <::Assoc as Other>::Assoc: Bound, +{ + type Assoc: Other; +} + +trait Other { + type Assoc; +} + +fn impls_trait, U>() -> Vec { vec![] } + +fn foo() { + let mut vec_u = impls_trait::<<::Assoc as Other>::Assoc, _>(); + vec_u.sort(); + drop::>(vec_u); + //~^ ERROR mismatched types +} + +fn main() {} diff --git a/tests/ui/associated-type-bounds/nested-associated-type-bound-incompleteness.stderr b/tests/ui/associated-type-bounds/nested-associated-type-bound-incompleteness.stderr new file mode 100644 index 0000000000000..c77400f382280 --- /dev/null +++ b/tests/ui/associated-type-bounds/nested-associated-type-bound-incompleteness.stderr @@ -0,0 +1,16 @@ +error[E0308]: mismatched types + --> $DIR/nested-associated-type-bound-incompleteness.rs:24:21 + | +LL | drop::>(vec_u); + | --------------- ^^^^^ expected `Vec`, found `Vec` + | | + | arguments to this function are incorrect + | + = note: expected struct `Vec` + found struct `Vec` +note: function defined here + --> $SRC_DIR/core/src/mem/mod.rs:LL:COL + +error: aborting due to 1 previous error + +For more information about this error, try `rustc --explain E0308`. diff --git a/tests/ui/associated-type-bounds/nested-gat-projection.rs b/tests/ui/associated-type-bounds/nested-gat-projection.rs new file mode 100644 index 0000000000000..ad37da9ed1948 --- /dev/null +++ b/tests/ui/associated-type-bounds/nested-gat-projection.rs @@ -0,0 +1,31 @@ +//@ check-pass + +trait Trait +where + for<'a> Self::Gat<'a>: OtherTrait, + for<'a, 'b, 'c> as OtherTrait>::OtherGat<'b>: HigherRanked<'c>, +{ + type Gat<'a>; +} + +trait OtherTrait { + type OtherGat<'b>; +} + +trait HigherRanked<'c> {} + +fn lower_ranked OtherTrait: HigherRanked<'c>>>() {} + +fn higher_ranked() +where + for<'a> T::Gat<'a>: OtherTrait, + for<'a, 'b, 'c> as OtherTrait>::OtherGat<'b>: HigherRanked<'c>, +{ +} + +fn test() { + lower_ranked::>(); + higher_ranked::(); +} + +fn main() {} diff --git a/tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs b/tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs new file mode 100644 index 0000000000000..864c31893504f --- /dev/null +++ b/tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs @@ -0,0 +1,28 @@ +//@ check-pass +//@ revisions: current next +//@[next] compile-flags: -Znext-solver + +trait Trait +where + Self::Assoc: Clone, +{ + type Assoc; +} + +fn foo(x: &T::Assoc) -> T::Assoc { + x.clone() +} + +trait Trait2 +where + Self::Assoc: Iterator, + ::Item: Clone, +{ + type Assoc; +} + +fn foo2(x: &::Item) -> ::Item { + x.clone() +} + +fn main() {} diff --git a/tests/ui/associated-types/imply-relevant-nested-item-bounds-for-gat.rs b/tests/ui/associated-types/imply-relevant-nested-item-bounds-for-gat.rs new file mode 100644 index 0000000000000..4e3b0b3b148a3 --- /dev/null +++ b/tests/ui/associated-types/imply-relevant-nested-item-bounds-for-gat.rs @@ -0,0 +1,19 @@ +//@ check-pass + +// Test that `for<'a> Self::Gat<'a>: Debug` is implied in the definition of `Foo`, +// just as it would be if it weren't a GAT but just a regular associated type. + +use std::fmt::Debug; + +trait Foo +where + for<'a> Self::Gat<'a>: Debug, +{ + type Gat<'a>; +} + +fn test(x: T::Gat<'static>) { + println!("{:?}", x); +} + +fn main() {} diff --git a/tests/ui/associated-types/imply-relevant-nested-item-bounds.rs b/tests/ui/associated-types/imply-relevant-nested-item-bounds.rs new file mode 100644 index 0000000000000..5a477a5b34941 --- /dev/null +++ b/tests/ui/associated-types/imply-relevant-nested-item-bounds.rs @@ -0,0 +1,23 @@ +//@ check-pass +//@ revisions: current next +//@[next] compile-flags: -Znext-solver + +trait Foo +where + Self::Iterator: Iterator, + ::Item: Bar, +{ + type Iterator; + + fn iter() -> Self::Iterator; +} + +trait Bar { + fn bar(&self); +} + +fn x() { + T::iter().next().unwrap().bar(); +} + +fn main() {}