diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 5118bf5c3b7ab..dbf86f5cf747d 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -214,6 +214,7 @@ language_item_table! { FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None; Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0); + FusedIterator, sym::fused_iterator, fused_iterator_trait, Target::Trait, GenericRequirement::Exact(0); Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0); AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0); diff --git a/compiler/rustc_middle/src/traits/select.rs b/compiler/rustc_middle/src/traits/select.rs index 8e9751f45294c..095626e272dfb 100644 --- a/compiler/rustc_middle/src/traits/select.rs +++ b/compiler/rustc_middle/src/traits/select.rs @@ -156,6 +156,10 @@ pub enum SelectionCandidate<'tcx> { /// generated for a `gen` construct. IteratorCandidate, + /// Implementation of an `FusedIterator` trait by one of the coroutine types + /// generated for a `gen` construct. + FusedIteratorCandidate, + /// Implementation of an `AsyncIterator` trait by one of the coroutine types /// generated for a `async gen` construct. AsyncIteratorCandidate, diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index 4748e961019e1..37950a6f7c64b 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -624,7 +624,9 @@ impl<'tcx> Instance<'tcx> { hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) ); hir::LangItem::FuturePoll - } else if Some(trait_id) == lang_items.iterator_trait() { + } else if Some(trait_id) == lang_items.iterator_trait() + || Some(trait_id) == lang_items.fused_iterator_trait() + { assert_matches!( coroutine_kind, hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 8b911a41a112f..c28c577d78014 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -207,6 +207,7 @@ symbols! { FromResidual, FsOpenOptions, FsPermissions, + FusedIterator, Future, FutureOutput, GlobalAlloc, @@ -885,6 +886,7 @@ symbols! { fsub_algebraic, fsub_fast, fundamental, + fused_iterator, future, future_trait, gdb_script_file, diff --git a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs index 9f33dce2a6dfe..b98bfeb0a5b87 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs @@ -215,6 +215,14 @@ pub(super) trait GoalKind<'tcx>: goal: Goal<'tcx, Self>, ) -> QueryResult<'tcx>; + /// A coroutine (that comes from a `gen` desugaring) is known to implement + /// `FusedIterator`, where `O` is given by the generator's yield type + /// that was computed during type-checking. + fn consider_builtin_fused_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx>; + fn consider_builtin_async_iterator_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, @@ -497,6 +505,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> { G::consider_builtin_future_candidate(self, goal) } else if lang_items.iterator_trait() == Some(trait_def_id) { G::consider_builtin_iterator_candidate(self, goal) + } else if lang_items.fused_iterator_trait() == Some(trait_def_id) { + G::consider_builtin_fused_iterator_candidate(self, goal) } else if lang_items.async_iterator_trait() == Some(trait_def_id) { G::consider_builtin_async_iterator_candidate(self, goal) } else if lang_items.coroutine_trait() == Some(trait_def_id) { diff --git a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs index 85bb6338daff9..5d6570c7f82e5 100644 --- a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs @@ -647,6 +647,41 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> { ) } + fn consider_builtin_fused_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + let self_ty = goal.predicate.self_ty(); + let ty::Coroutine(def_id, args) = *self_ty.kind() else { + return Err(NoSolution); + }; + + // Coroutines are not Iterators unless they come from `gen` desugaring + let tcx = ecx.tcx(); + if !tcx.coroutine_is_gen(def_id) { + return Err(NoSolution); + } + + let Some(iterator_trait) = tcx.lang_items().iterator_trait() else { + return Err(NoSolution); + }; + + let term = args.as_coroutine().yield_ty().into(); + + Self::consider_implied_clause( + ecx, + goal, + ty::ProjectionPredicate { + projection_ty: ty::AliasTy::new(ecx.tcx(), iterator_trait, [self_ty]), + term, + } + .to_predicate(tcx), + // Technically, we need to check that the iterator type is Sized, + // but that's already proven by the generator being WF. + [], + ) + } + fn consider_builtin_async_iterator_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index c252ad76dfe1d..0858d05f9246c 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -456,6 +456,13 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) } + fn consider_builtin_fused_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + Self::consider_builtin_iterator_candidate(ecx, goal) + } + fn consider_builtin_async_iterator_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 49091e53be713..2be11a1944435 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -118,6 +118,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { self.assemble_future_candidates(obligation, &mut candidates); } else if lang_items.iterator_trait() == Some(def_id) { self.assemble_iterator_candidates(obligation, &mut candidates); + } else if lang_items.fused_iterator_trait() == Some(def_id) { + self.assemble_fused_iterator_candidates(obligation, &mut candidates); } else if lang_items.async_iterator_trait() == Some(def_id) { self.assemble_async_iterator_candidates(obligation, &mut candidates); } else if lang_items.async_fn_kind_helper() == Some(def_id) { @@ -313,6 +315,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } } + fn assemble_fused_iterator_candidates( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + candidates: &mut SelectionCandidateSet<'tcx>, + ) { + let self_ty = obligation.self_ty().skip_binder(); + if let ty::Coroutine(did, ..) = self_ty.kind() { + // gen constructs get lowered to a special kind of coroutine that + // should directly `impl FusedIterator`. + if self.tcx().coroutine_is_gen(*did) { + debug!(?self_ty, ?obligation, "assemble_fused_iterator_candidates",); + + candidates.vec.push(FusedIteratorCandidate); + } + } + } + fn assemble_async_iterator_candidates( &mut self, obligation: &PolyTraitObligation<'tcx>, diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 51fc223a5d1b3..c8d54546df2aa 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -107,6 +107,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator) } + FusedIteratorCandidate => { + let vtable_iterator = self.confirm_fused_iterator_candidate(obligation)?; + ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator) + } + AsyncIteratorCandidate => { let vtable_iterator = self.confirm_async_iterator_candidate(obligation)?; ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator) @@ -838,6 +843,35 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { Ok(nested) } + fn confirm_fused_iterator_candidate( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + ) -> Result>, SelectionError<'tcx>> { + // Okay to skip binder because the args on coroutine types never + // touch bound regions, they just capture the in-scope + // type/region parameters. + let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); + let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else { + bug!("closure candidate for non-closure {:?}", obligation); + }; + + debug!(?obligation, ?coroutine_def_id, ?args, "confirm_fused_iterator_candidate"); + + let gen_sig = args.as_coroutine().sig(); + + let (trait_ref, _) = super::util::fused_iterator_trait_ref_and_outputs( + self.tcx(), + obligation.predicate.def_id(), + obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(), + gen_sig, + ); + + let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?; + debug!(?trait_ref, ?nested, "fused iterator candidate obligations"); + + Ok(nested) + } + fn confirm_async_iterator_candidate( &mut self, obligation: &PolyTraitObligation<'tcx>, diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 53aadfb8a44d8..ea10fa434ce25 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -1855,6 +1855,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | FusedIteratorCandidate | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate @@ -1887,6 +1888,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | FusedIteratorCandidate | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate @@ -1925,6 +1927,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | FusedIteratorCandidate | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate @@ -1943,6 +1946,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | FusedIteratorCandidate | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate @@ -2053,6 +2057,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | FusedIteratorCandidate | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate @@ -2067,6 +2072,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | FusedIteratorCandidate | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate diff --git a/compiler/rustc_trait_selection/src/traits/util.rs b/compiler/rustc_trait_selection/src/traits/util.rs index 3f433a9e919a3..16cf5279f384b 100644 --- a/compiler/rustc_trait_selection/src/traits/util.rs +++ b/compiler/rustc_trait_selection/src/traits/util.rs @@ -323,6 +323,17 @@ pub fn iterator_trait_ref_and_outputs<'tcx>( (trait_ref, sig.yield_ty) } +pub fn fused_iterator_trait_ref_and_outputs<'tcx>( + tcx: TyCtxt<'tcx>, + fused_iterator_def_id: DefId, + self_ty: Ty<'tcx>, + sig: ty::GenSig<'tcx>, +) -> (ty::TraitRef<'tcx>, Ty<'tcx>) { + assert!(!self_ty.has_escaping_bound_vars()); + let trait_ref = ty::TraitRef::new(tcx, fused_iterator_def_id, [self_ty]); + (trait_ref, sig.yield_ty) +} + pub fn async_iterator_trait_ref_and_outputs<'tcx>( tcx: TyCtxt<'tcx>, async_iterator_def_id: DefId, diff --git a/library/core/src/iter/traits/marker.rs b/library/core/src/iter/traits/marker.rs index 8bdbca120d7f9..ad4d63d83b5be 100644 --- a/library/core/src/iter/traits/marker.rs +++ b/library/core/src/iter/traits/marker.rs @@ -28,6 +28,7 @@ pub unsafe trait TrustedFused {} #[rustc_unsafe_specialization_marker] // FIXME: this should be a #[marker] and have another blanket impl for T: TrustedFused // but that ICEs iter::Fuse specializations. +#[cfg_attr(not(bootstrap), lang = "fused_iterator")] pub trait FusedIterator: Iterator {} #[stable(feature = "fused", since = "1.26.0")] diff --git a/tests/ui/coroutine/gen_block_is_fused_iter.rs b/tests/ui/coroutine/gen_block_is_fused_iter.rs new file mode 100644 index 0000000000000..f3e19a7f54f03 --- /dev/null +++ b/tests/ui/coroutine/gen_block_is_fused_iter.rs @@ -0,0 +1,21 @@ +//@ revisions: next old +//@compile-flags: --edition 2024 -Zunstable-options +//@[next] compile-flags: -Znext-solver +//@ check-pass +#![feature(gen_blocks)] + +use std::iter::FusedIterator; + +fn foo() -> impl FusedIterator { + gen { yield 42 } +} + +fn bar() -> impl FusedIterator { + gen { yield 42 } +} + +fn baz() -> impl FusedIterator + Iterator { + gen { yield 42 } +} + +fn main() {}