diff --git a/Cargo.toml b/Cargo.toml index c6131625e..a22fc0b20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ either = { version = "1.0", default-features = false } [dev-dependencies] rand = "0.7" criterion = "=0" # TODO how could this work with our minimum supported rust version? +paste = "1.0.0" # Used in test_std to instantiate generic tests [dev-dependencies.quickcheck] version = "0.9" diff --git a/src/k_smallest.rs b/src/k_smallest.rs new file mode 100644 index 000000000..d58ec70d0 --- /dev/null +++ b/src/k_smallest.rs @@ -0,0 +1,20 @@ +use alloc::collections::BinaryHeap; +use core::cmp::Ord; + +pub(crate) fn k_smallest<T: Ord, I: Iterator<Item = T>>(mut iter: I, k: usize) -> BinaryHeap<T> { + if k == 0 { return BinaryHeap::new(); } + + let mut heap = iter.by_ref().take(k).collect::<BinaryHeap<_>>(); + + for i in iter { + debug_assert_eq!(heap.len(), k); + // Equivalent to heap.push(min(i, heap.pop())) but more efficient. + // This should be done with a single `.peek_mut().unwrap()` but + // `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior. + if *heap.peek().unwrap() > i { + *heap.peek_mut().unwrap() = i; + } + } + + heap +} diff --git a/src/lib.rs b/src/lib.rs index f5b3617b2..d214910d0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -196,6 +196,8 @@ mod group_map; mod groupbylazy; mod intersperse; #[cfg(feature = "use_alloc")] +mod k_smallest; +#[cfg(feature = "use_alloc")] mod kmerge_impl; #[cfg(feature = "use_alloc")] mod lazy_buffer; @@ -2419,6 +2421,43 @@ pub trait Itertools : Iterator { v.into_iter() } + /// Sort the k smallest elements into a new iterator, in ascending order. + /// + /// **Note:** This consumes the entire iterator, and returns the result + /// as a new iterator that owns its elements. If the input contains + /// fewer than k elements, the result is equivalent to `self.sorted()`. + /// + /// This is guaranteed to use `k * sizeof(Self::Item) + O(1)` memory + /// and `O(n log k)` time, with `n` the number of elements in the input.
+ /// + /// The sorted iterator, if directly collected to a `Vec`, is converted + /// without any extra copying or allocation cost. + /// + /// **Note:** This is functionally equivalent to `self.sorted().take(k)` + /// but much more efficient. + /// + /// ``` + /// use itertools::Itertools; + /// + /// // A random permutation of 0..15 + /// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5]; + /// + /// let five_smallest = numbers + /// .into_iter() + /// .k_smallest(5); + /// + /// itertools::assert_equal(five_smallest, 0..5); + /// ``` + #[cfg(feature = "use_alloc")] + fn k_smallest(self, k: usize) -> VecIntoIter<Self::Item> + where Self: Sized, + Self::Item: Ord + { + crate::k_smallest::k_smallest(self, k) + .into_sorted_vec() + .into_iter() + } + /// Collect all iterator elements into one of two /// partitions. Unlike `Iterator::partition`, each partition may /// have a distinct type. diff --git a/tests/test_std.rs b/tests/test_std.rs index e0468db05..d1ff815da 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -1,4 +1,9 @@ +use paste; use permutohedron; +use quickcheck as qc; +use rand::{distributions::{Distribution, Standard}, Rng, SeedableRng, rngs::StdRng}; +use rand::{seq::SliceRandom, thread_rng}; +use std::{cmp::min, fmt::Debug, marker::PhantomData}; use itertools as it; use crate::it::Itertools; use crate::it::ExactlyOneError; @@ -374,6 +379,88 @@ fn sorted_by() { it::assert_equal(v, vec![4, 3, 2, 1, 0]); } +qc::quickcheck! { + fn k_smallest_range(n: u64, m: u16, k: u16) -> () { + // u16 is used to constrain k and m to 0..2¹⁶, + // otherwise the test could use too much memory.
+ let (k, m) = (k as u64, m as u64); + + // Generate a random permutation of n..n+m + let i = { + let mut v: Vec<u64> = (n..n.saturating_add(m)).collect(); + v.shuffle(&mut thread_rng()); + v.into_iter() + }; + + // Check that taking the k smallest elements yields n..n+min(k, m) + it::assert_equal( + i.k_smallest(k as usize), + n..n.saturating_add(min(k, m)) + ); + } +} + +#[derive(Clone, Debug)] +struct RandIter<T: 'static + Clone + Send, R: 'static + Clone + Rng + SeedableRng + Send = StdRng> { + idx: usize, + len: usize, + rng: R, + _t: PhantomData<T> +} + +impl<T: Clone + Send, R: Clone + Rng + SeedableRng + Send> Iterator for RandIter<T, R> +where Standard: Distribution<T> { + type Item = T; + fn next(&mut self) -> Option<T> { + if self.idx == self.len { + None + } else { + self.idx += 1; + Some(self.rng.gen()) + } + } +} + +impl<T: Clone + Send, R: Clone + Rng + SeedableRng + Send> qc::Arbitrary for RandIter<T, R> { + fn arbitrary<G: qc::Gen>(g: &mut G) -> Self { + RandIter { + idx: 0, + len: g.size(), + rng: R::seed_from_u64(g.next_u64()), + _t : PhantomData{}, + } + } +} + +// Check that taking the k smallest is the same as +// sorting then taking the first k elements +fn k_smallest_sort<I>(i: I, k: u16) -> () +where + I: Iterator + Clone, + I::Item: Ord + Debug, +{ + let j = i.clone(); + let k = k as usize; + it::assert_equal( + i.k_smallest(k), + j.sorted().take(k) + ) +} + +macro_rules! generic_test { + ($f:ident, $($t:ty),+) => { + $(paste::item! { + qc::quickcheck! { + fn [< $f _ $t >](i: RandIter<$t>, k: u16) -> () { + $f(i, k) + } + } + })+ + }; +} + +generic_test!(k_smallest_sort, u8, u16, u32, u64, i8, i16, i32, i64); + #[test] fn sorted_by_key() { let sc = [3, 4, 1, 2].iter().cloned().sorted_by_key(|&x| x); @@ -407,7 +494,6 @@ fn test_multipeek() { assert_eq!(mp.next(), Some(5)); assert_eq!(mp.next(), None); assert_eq!(mp.peek(), None); - } #[test]