Skip to content

Commit

Permalink
Implement FusedIterator for gen block
Browse files Browse the repository at this point in the history
  • Loading branch information
ShoyuVanilla committed Mar 21, 2024
1 parent 03994e4 commit 7bc4d25
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 1 deletion.
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ language_item_table! {
FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None;

Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0);
FusedIterator, sym::fused_iterator, fused_iterator_trait, Target::Trait, GenericRequirement::Exact(0);
Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0);
AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0);

Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/traits/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ pub enum SelectionCandidate<'tcx> {
/// generated for a `gen` construct.
IteratorCandidate,

/// Implementation of an `FusedIterator` trait by one of the coroutine types
/// generated for a `gen` construct.
FusedIteratorCandidate,

/// Implementation of an `AsyncIterator` trait by one of the coroutine types
/// generated for a `async gen` construct.
AsyncIteratorCandidate,
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_middle/src/ty/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,9 @@ impl<'tcx> Instance<'tcx> {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)
);
hir::LangItem::FuturePoll
} else if Some(trait_id) == lang_items.iterator_trait() {
} else if Some(trait_id) == lang_items.iterator_trait()
|| Some(trait_id) == lang_items.fused_iterator_trait()
{
assert_matches!(
coroutine_kind,
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ symbols! {
FromResidual,
FsOpenOptions,
FsPermissions,
FusedIterator,
Future,
FutureOutput,
GlobalAlloc,
Expand Down Expand Up @@ -885,6 +886,7 @@ symbols! {
fsub_algebraic,
fsub_fast,
fundamental,
fused_iterator,
future,
future_trait,
gdb_script_file,
Expand Down
10 changes: 10 additions & 0 deletions compiler/rustc_trait_selection/src/solve/assembly/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,14 @@ pub(super) trait GoalKind<'tcx>:
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx>;

/// A coroutine (that comes from a `gen` desugaring) is known to implement
/// `FusedIterator<Item = O>`, where `O` is given by the generator's yield type
/// that was computed during type-checking.
fn consider_builtin_fused_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx>;

fn consider_builtin_async_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
Expand Down Expand Up @@ -497,6 +505,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
G::consider_builtin_future_candidate(self, goal)
} else if lang_items.iterator_trait() == Some(trait_def_id) {
G::consider_builtin_iterator_candidate(self, goal)
} else if lang_items.fused_iterator_trait() == Some(trait_def_id) {
G::consider_builtin_fused_iterator_candidate(self, goal)
} else if lang_items.async_iterator_trait() == Some(trait_def_id) {
G::consider_builtin_async_iterator_candidate(self, goal)
} else if lang_items.coroutine_trait() == Some(trait_def_id) {
Expand Down
35 changes: 35 additions & 0 deletions compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,41 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> {
)
}

fn consider_builtin_fused_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx> {
let self_ty = goal.predicate.self_ty();
let ty::Coroutine(def_id, args) = *self_ty.kind() else {
return Err(NoSolution);
};

// Coroutines are not Iterators unless they come from `gen` desugaring
let tcx = ecx.tcx();
if !tcx.coroutine_is_gen(def_id) {
return Err(NoSolution);
}

let Some(iterator_trait) = tcx.lang_items().iterator_trait() else {
return Err(NoSolution);
};

let term = args.as_coroutine().yield_ty().into();

Self::consider_implied_clause(
ecx,
goal,
ty::ProjectionPredicate {
projection_ty: ty::AliasTy::new(ecx.tcx(), iterator_trait, [self_ty]),
term,
}
.to_predicate(tcx),
// Technically, we need to check that the iterator type is Sized,
// but that's already proven by the generator being WF.
[],
)
}

fn consider_builtin_async_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_trait_selection/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
}

fn consider_builtin_fused_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
) -> QueryResult<'tcx> {
Self::consider_builtin_iterator_candidate(ecx, goal)
}

fn consider_builtin_async_iterator_candidate(
ecx: &mut EvalCtxt<'_, 'tcx>,
goal: Goal<'tcx, Self>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
self.assemble_future_candidates(obligation, &mut candidates);
} else if lang_items.iterator_trait() == Some(def_id) {
self.assemble_iterator_candidates(obligation, &mut candidates);
} else if lang_items.fused_iterator_trait() == Some(def_id) {
self.assemble_fused_iterator_candidates(obligation, &mut candidates);
} else if lang_items.async_iterator_trait() == Some(def_id) {
self.assemble_async_iterator_candidates(obligation, &mut candidates);
} else if lang_items.async_fn_kind_helper() == Some(def_id) {
Expand Down Expand Up @@ -313,6 +315,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
}

fn assemble_fused_iterator_candidates(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
candidates: &mut SelectionCandidateSet<'tcx>,
) {
let self_ty = obligation.self_ty().skip_binder();
if let ty::Coroutine(did, ..) = self_ty.kind() {
// gen constructs get lowered to a special kind of coroutine that
// should directly `impl FusedIterator`.
if self.tcx().coroutine_is_gen(*did) {
debug!(?self_ty, ?obligation, "assemble_fused_iterator_candidates",);

candidates.vec.push(FusedIteratorCandidate);
}
}
}

fn assemble_async_iterator_candidates(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
Expand Down
34 changes: 34 additions & 0 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
}

FusedIteratorCandidate => {
let vtable_iterator = self.confirm_fused_iterator_candidate(obligation)?;
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
}

AsyncIteratorCandidate => {
let vtable_iterator = self.confirm_async_iterator_candidate(obligation)?;
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
Expand Down Expand Up @@ -838,6 +843,35 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
Ok(nested)
}

fn confirm_fused_iterator_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};

debug!(?obligation, ?coroutine_def_id, ?args, "confirm_fused_iterator_candidate");

let gen_sig = args.as_coroutine().sig();

let (trait_ref, _) = super::util::fused_iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
gen_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
debug!(?trait_ref, ?nested, "fused iterator candidate obligations");

Ok(nested)
}

fn confirm_async_iterator_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FusedIteratorCandidate
| AsyncIteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
Expand Down Expand Up @@ -1887,6 +1888,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FusedIteratorCandidate
| AsyncIteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
Expand Down Expand Up @@ -1925,6 +1927,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FusedIteratorCandidate
| AsyncIteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
Expand All @@ -1943,6 +1946,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FusedIteratorCandidate
| AsyncIteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
Expand Down Expand Up @@ -2053,6 +2057,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FusedIteratorCandidate
| AsyncIteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
Expand All @@ -2067,6 +2072,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
| CoroutineCandidate
| FutureCandidate
| IteratorCandidate
| FusedIteratorCandidate
| AsyncIteratorCandidate
| FnPointerCandidate { .. }
| BuiltinObjectCandidate
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_trait_selection/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ pub fn iterator_trait_ref_and_outputs<'tcx>(
(trait_ref, sig.yield_ty)
}

pub fn fused_iterator_trait_ref_and_outputs<'tcx>(
tcx: TyCtxt<'tcx>,
fused_iterator_def_id: DefId,
self_ty: Ty<'tcx>,
sig: ty::GenSig<'tcx>,
) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
assert!(!self_ty.has_escaping_bound_vars());
let trait_ref = ty::TraitRef::new(tcx, fused_iterator_def_id, [self_ty]);
(trait_ref, sig.yield_ty)
}

pub fn async_iterator_trait_ref_and_outputs<'tcx>(
tcx: TyCtxt<'tcx>,
async_iterator_def_id: DefId,
Expand Down
1 change: 1 addition & 0 deletions library/core/src/iter/traits/marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub unsafe trait TrustedFused {}
#[rustc_unsafe_specialization_marker]
// FIXME: this should be a #[marker] and have another blanket impl for T: TrustedFused
// but that ICEs iter::Fuse specializations.
#[cfg_attr(not(bootstrap), lang = "fused_iterator")]
pub trait FusedIterator: Iterator {}

#[stable(feature = "fused", since = "1.26.0")]
Expand Down
21 changes: 21 additions & 0 deletions tests/ui/coroutine/gen_block_is_fused_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//@ revisions: next old
//@compile-flags: --edition 2024 -Zunstable-options
//@[next] compile-flags: -Znext-solver
//@ check-pass
#![feature(gen_blocks)]

use std::iter::FusedIterator;

fn foo() -> impl FusedIterator {
gen { yield 42 }
}

fn bar() -> impl FusedIterator<Item = u16> {
gen { yield 42 }
}

fn baz() -> impl FusedIterator + Iterator<Item = i64> {
gen { yield 42 }
}

fn main() {}

0 comments on commit 7bc4d25

Please sign in to comment.