diff --git a/library/core/src/iter/adapters/flatten.rs b/library/core/src/iter/adapters/flatten.rs index 99344a88efc3f..145c9d3dacc84 100644 --- a/library/core/src/iter/adapters/flatten.rs +++ b/library/core/src/iter/adapters/flatten.rs @@ -3,7 +3,7 @@ use crate::iter::{ Cloned, Copied, Filter, FilterMap, Fuse, FusedIterator, InPlaceIterable, Map, TrustedFused, TrustedLen, }; -use crate::iter::{Once, OnceWith}; +use crate::iter::{Empty, Once, OnceWith}; use crate::num::NonZero; use crate::ops::{ControlFlow, Try}; use crate::result; @@ -593,6 +593,7 @@ where } } +// See also the `OneShot` specialization below. impl Iterator for FlattenCompat where I: Iterator>, @@ -601,7 +602,7 @@ where type Item = U::Item; #[inline] - fn next(&mut self) -> Option { + default fn next(&mut self) -> Option { loop { if let elt @ Some(_) = and_then_or_clear(&mut self.frontiter, Iterator::next) { return elt; @@ -614,7 +615,7 @@ where } #[inline] - fn size_hint(&self) -> (usize, Option) { + default fn size_hint(&self) -> (usize, Option) { let (flo, fhi) = self.frontiter.as_ref().map_or((0, Some(0)), U::size_hint); let (blo, bhi) = self.backiter.as_ref().map_or((0, Some(0)), U::size_hint); let lo = flo.saturating_add(blo); @@ -636,7 +637,7 @@ where } #[inline] - fn try_fold(&mut self, init: Acc, fold: Fold) -> R + default fn try_fold(&mut self, init: Acc, fold: Fold) -> R where Self: Sized, Fold: FnMut(Acc, Self::Item) -> R, @@ -653,7 +654,7 @@ where } #[inline] - fn fold(self, init: Acc, fold: Fold) -> Acc + default fn fold(self, init: Acc, fold: Fold) -> Acc where Fold: FnMut(Acc, Self::Item) -> Acc, { @@ -669,7 +670,7 @@ where #[inline] #[rustc_inherit_overflow_checks] - fn advance_by(&mut self, n: usize) -> Result<(), NonZero> { + default fn advance_by(&mut self, n: usize) -> Result<(), NonZero> { #[inline] #[rustc_inherit_overflow_checks] fn advance(n: usize, iter: &mut U) -> ControlFlow<(), usize> { @@ -686,7 +687,7 @@ where } #[inline] - fn count(self) -> usize { + default fn count(self) -> usize { #[inline] #[rustc_inherit_overflow_checks] fn count(acc: usize, iter: U) -> usize { @@ -697,7 +698,7 @@ where } #[inline] - fn last(self) -> Option { + default fn last(self) -> Option { #[inline] fn last(last: Option, iter: U) -> Option { iter.last().or(last) @@ -707,13 +708,14 @@ where } } +// See also the `OneShot` specialization below. impl DoubleEndedIterator for FlattenCompat where I: DoubleEndedIterator>, U: DoubleEndedIterator, { #[inline] - fn next_back(&mut self) -> Option { + default fn next_back(&mut self) -> Option { loop { if let elt @ Some(_) = and_then_or_clear(&mut self.backiter, |b| b.next_back()) { return elt; @@ -726,7 +728,7 @@ where } #[inline] - fn try_rfold(&mut self, init: Acc, fold: Fold) -> R + default fn try_rfold(&mut self, init: Acc, fold: Fold) -> R where Self: Sized, Fold: FnMut(Acc, Self::Item) -> R, @@ -743,7 +745,7 @@ where } #[inline] - fn rfold(self, init: Acc, fold: Fold) -> Acc + default fn rfold(self, init: Acc, fold: Fold) -> Acc where Fold: FnMut(Acc, Self::Item) -> Acc, { @@ -759,7 +761,7 @@ where #[inline] #[rustc_inherit_overflow_checks] - fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero> { + default fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero> { #[inline] #[rustc_inherit_overflow_checks] fn advance(n: usize, iter: &mut U) -> ControlFlow<(), usize> { @@ -841,3 +843,198 @@ fn and_then_or_clear(opt: &mut Option, f: impl FnOnce(&mut T) -> Option } x } + +/// Specialization trait for iterator types that never return more than one item. +/// +/// Note that we still have to deal with the possibility that the iterator was +/// already exhausted before it came into our control. +#[rustc_specialization_trait] +trait OneShot {} + +// These all have exactly one item, if not already consumed. +impl OneShot for Once {} +impl OneShot for OnceWith {} +impl OneShot for array::IntoIter {} +impl OneShot for option::IntoIter {} +impl OneShot for option::Iter<'_, T> {} +impl OneShot for option::IterMut<'_, T> {} +impl OneShot for result::IntoIter {} +impl OneShot for result::Iter<'_, T> {} +impl OneShot for result::IterMut<'_, T> {} + +// These are always empty, which is fine to optimize too. +impl OneShot for Empty {} +impl OneShot for array::IntoIter {} + +// These adaptors never increase the number of items. +// (There are more possible, but for now this matches BoundedSize above.) +impl OneShot for Cloned {} +impl OneShot for Copied {} +impl OneShot for Filter {} +impl OneShot for FilterMap {} +impl OneShot for Map {} + +// Blanket impls pass this property through as well +// (but we can't do `Box` unless we expose this trait to alloc) +impl OneShot for &mut I {} + +#[inline] +fn into_item(inner: I) -> Option +where + I: IntoIterator, +{ + inner.into_iter().next() +} + +#[inline] +fn flatten_one, Acc>( + mut fold: impl FnMut(Acc, I::Item) -> Acc, +) -> impl FnMut(Acc, I) -> Acc { + move |acc, inner| match inner.into_iter().next() { + Some(item) => fold(acc, item), + None => acc, + } +} + +#[inline] +fn try_flatten_one, Acc, R: Try>( + mut fold: impl FnMut(Acc, I::Item) -> R, +) -> impl FnMut(Acc, I) -> R { + move |acc, inner| match inner.into_iter().next() { + Some(item) => fold(acc, item), + None => try { acc }, + } +} + +#[inline] +fn advance_by_one(n: NonZero, inner: I) -> Option> +where + I: IntoIterator, +{ + match inner.into_iter().next() { + Some(_) => NonZero::new(n.get() - 1), + None => Some(n), + } +} + +// Specialization: When the inner iterator `U` never returns more than one item, the `frontiter` and +// `backiter` states are a waste, because they'll always have already consumed their item. So in +// this impl, we completely ignore them and just focus on `self.iter`, and we only call the inner +// `U::next()` one time. +// +// It's mostly fine if we accidentally mix this with the more generic impls, e.g. by forgetting to +// specialize one of the methods. If the other impl did set the front or back, we wouldn't see it +// here, but it would be empty anyway; and if the other impl looked for a front or back that we +// didn't bother setting, it would just see `None` (or a previous empty) and move on. +// +// An exception to that is `advance_by(0)` and `advance_back_by(0)`, where the generic impls may set +// `frontiter` or `backiter` without consuming the item, so we **must** override those. +impl Iterator for FlattenCompat +where + I: Iterator>, + U: Iterator + OneShot, +{ + #[inline] + fn next(&mut self) -> Option { + while let Some(inner) = self.iter.next() { + if let item @ Some(_) = inner.into_iter().next() { + return item; + } + } + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let (lower, upper) = self.iter.size_hint(); + match ::size() { + Some(0) => (0, Some(0)), + Some(1) => (lower, upper), + _ => (0, upper), + } + } + + #[inline] + fn try_fold(&mut self, init: Acc, fold: Fold) -> R + where + Self: Sized, + Fold: FnMut(Acc, Self::Item) -> R, + R: Try, + { + self.iter.try_fold(init, try_flatten_one(fold)) + } + + #[inline] + fn fold(self, init: Acc, fold: Fold) -> Acc + where + Fold: FnMut(Acc, Self::Item) -> Acc, + { + self.iter.fold(init, flatten_one(fold)) + } + + #[inline] + fn advance_by(&mut self, n: usize) -> Result<(), NonZero> { + if let Some(n) = NonZero::new(n) { + self.iter.try_fold(n, advance_by_one).map_or(Ok(()), Err) + } else { + // Just advance the outer iterator + self.iter.advance_by(0) + } + } + + #[inline] + fn count(self) -> usize { + self.iter.filter_map(into_item).count() + } + + #[inline] + fn last(self) -> Option { + self.iter.filter_map(into_item).last() + } +} + +// Note: We don't actually care about `U: DoubleEndedIterator`, since forward and backward are the +// same for a one-shot iterator, but we have to keep that to match the default specialization. +impl DoubleEndedIterator for FlattenCompat +where + I: DoubleEndedIterator>, + U: DoubleEndedIterator + OneShot, +{ + #[inline] + fn next_back(&mut self) -> Option { + while let Some(inner) = self.iter.next_back() { + if let item @ Some(_) = inner.into_iter().next() { + return item; + } + } + None + } + + #[inline] + fn try_rfold(&mut self, init: Acc, fold: Fold) -> R + where + Self: Sized, + Fold: FnMut(Acc, Self::Item) -> R, + R: Try, + { + self.iter.try_rfold(init, try_flatten_one(fold)) + } + + #[inline] + fn rfold(self, init: Acc, fold: Fold) -> Acc + where + Fold: FnMut(Acc, Self::Item) -> Acc, + { + self.iter.rfold(init, flatten_one(fold)) + } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero> { + if let Some(n) = NonZero::new(n) { + self.iter.try_rfold(n, advance_by_one).map_or(Ok(()), Err) + } else { + // Just advance the outer iterator + self.iter.advance_back_by(0) + } + } +} diff --git a/library/core/tests/iter/adapters/flatten.rs b/library/core/tests/iter/adapters/flatten.rs index 2af7e0c388a3b..1f953f2aa0110 100644 --- a/library/core/tests/iter/adapters/flatten.rs +++ b/library/core/tests/iter/adapters/flatten.rs @@ -212,3 +212,69 @@ fn test_flatten_last() { assert_eq!(it.advance_by(3), Ok(())); // 22..22 assert_eq!(it.clone().last(), None); } + +#[test] +fn test_flatten_one_shot() { + // This could be `filter_map`, but people often do flatten options. + let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7)); + assert_eq!(it.size_hint(), (0, Some(10))); + assert_eq!(it.clone().count(), 8); + assert_eq!(it.clone().last(), NonZero::new(2)); + + // sum -> fold + let sum: i8 = it.clone().map(|n| n.get()).sum(); + assert_eq!(sum, 24); + + // the product overflows at 6, remaining are 7,8,9 -> 1,2 + let one = NonZero::new(1i8).unwrap(); + let product = it.try_fold(one, |acc, x| acc.checked_mul(x)); + assert_eq!(product, None); + assert_eq!(it.size_hint(), (0, Some(3))); + assert_eq!(it.clone().count(), 2); + + assert_eq!(it.advance_by(0), Ok(())); + assert_eq!(it.clone().next(), NonZero::new(1)); + assert_eq!(it.advance_by(1), Ok(())); + assert_eq!(it.clone().next(), NonZero::new(2)); + assert_eq!(it.advance_by(100), Err(NonZero::new(99).unwrap())); + assert_eq!(it.next(), None); +} + +#[test] +fn test_flatten_one_shot_rev() { + let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7)).rev(); + assert_eq!(it.size_hint(), (0, Some(10))); + assert_eq!(it.clone().count(), 8); + assert_eq!(it.clone().last(), NonZero::new(1)); + + // sum -> Rev fold -> rfold + let sum: i8 = it.clone().map(|n| n.get()).sum(); + assert_eq!(sum, 24); + + // Rev try_fold -> try_rfold + // the product overflows at 4, remaining are 3,2,1,0 -> 3,2,1 + let one = NonZero::new(1i8).unwrap(); + let product = it.try_fold(one, |acc, x| acc.checked_mul(x)); + assert_eq!(product, None); + assert_eq!(it.size_hint(), (0, Some(4))); + assert_eq!(it.clone().count(), 3); + + // Rev advance_by -> advance_back_by + assert_eq!(it.advance_by(0), Ok(())); + assert_eq!(it.clone().next(), NonZero::new(3)); + assert_eq!(it.advance_by(1), Ok(())); + assert_eq!(it.clone().next(), NonZero::new(2)); + assert_eq!(it.advance_by(100), Err(NonZero::new(98).unwrap())); + assert_eq!(it.next(), None); +} + +#[test] +fn test_flatten_one_shot_arrays() { + let it = (0..10).flat_map(|i| [i]); + assert_eq!(it.size_hint(), (10, Some(10))); + assert_eq!(it.sum::(), 45); + + let mut it = (0..10).flat_map(|_| -> [i32; 0] { [] }); + assert_eq!(it.size_hint(), (0, Some(0))); + assert_eq!(it.next(), None); +}