Skip to content

Commit

Permalink
Merge #473
Browse files Browse the repository at this point in the history
473: Add helper method for taking the k smallest elements in an iterator r=jswrenn a=nbraud



Co-authored-by: nicoo <[email protected]>
Co-authored-by: Giacomo Stevanato <[email protected]>
  • Loading branch information
3 people authored Dec 14, 2020
2 parents 130ffd3 + f28ffd0 commit 00756e0
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ either = { version = "1.0", default-features = false }
[dev-dependencies]
rand = "0.7"
criterion = "=0" # TODO how could this work with our minimum supported rust version?
paste = "1.0.0" # Used in test_std to instanciate generic tests

[dev-dependencies.quickcheck]
version = "0.9"
Expand Down
20 changes: 20 additions & 0 deletions src/k_smallest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use alloc::collections::BinaryHeap;
use core::cmp::Ord;

pub(crate) fn k_smallest<T: Ord, I: Iterator<Item = T>>(mut iter: I, k: usize) -> BinaryHeap<T> {
if k == 0 { return BinaryHeap::new(); }

let mut heap = iter.by_ref().take(k).collect::<BinaryHeap<_>>();

for i in iter {
debug_assert_eq!(heap.len(), k);
// Equivalent to heap.push(min(i, heap.pop())) but more efficient.
// This should be done with a single `.peek_mut().unwrap()` but
// `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
if *heap.peek().unwrap() > i {
*heap.peek_mut().unwrap() = i;
}
}

heap
}
39 changes: 39 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ mod group_map;
mod groupbylazy;
mod intersperse;
#[cfg(feature = "use_alloc")]
mod k_smallest;
#[cfg(feature = "use_alloc")]
mod kmerge_impl;
#[cfg(feature = "use_alloc")]
mod lazy_buffer;
Expand Down Expand Up @@ -2419,6 +2421,43 @@ pub trait Itertools : Iterator {
v.into_iter()
}

/// Sort the k smallest elements into a new iterator, in ascending order.
///
/// **Note:** This consumes the entire iterator, and returns the result
/// as a new iterator that owns its elements. If the input contains
/// less than k elements, the result is equivalent to `self.sorted()`.
///
/// This is guaranteed to use `k * sizeof(Self::Item) + O(1)` memory
/// and `O(n log k)` time, with `n` the number of elements in the input.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// **Note:** This is functionally-equivalent to `self.sorted().take(k)`
/// but much more efficient.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_smallest = numbers
/// .into_iter()
/// .k_smallest(5);
///
/// itertools::assert_equal(five_smallest, 0..5);
/// ```
#[cfg(feature = "use_alloc")]
fn k_smallest(self, k: usize) -> VecIntoIter<Self::Item>
where Self: Sized,
Self::Item: Ord
{
crate::k_smallest::k_smallest(self, k)
.into_sorted_vec()
.into_iter()
}

/// Collect all iterator elements into one of two
/// partitions. Unlike `Iterator::partition`, each partition may
/// have a distinct type.
Expand Down
88 changes: 87 additions & 1 deletion tests/test_std.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use paste;
use permutohedron;
use quickcheck as qc;
use rand::{distributions::{Distribution, Standard}, Rng, SeedableRng, rngs::StdRng};
use rand::{seq::SliceRandom, thread_rng};
use std::{cmp::min, fmt::Debug, marker::PhantomData};
use itertools as it;
use crate::it::Itertools;
use crate::it::ExactlyOneError;
Expand Down Expand Up @@ -374,6 +379,88 @@ fn sorted_by() {
it::assert_equal(v, vec![4, 3, 2, 1, 0]);
}

qc::quickcheck! {
fn k_smallest_range(n: u64, m: u16, k: u16) -> () {
// u16 is used to constrain k and m to 0..2¹⁶,
// otherwise the test could use too much memory.
let (k, m) = (k as u64, m as u64);

// Generate a random permutation of n..n+m
let i = {
let mut v: Vec<u64> = (n..n.saturating_add(m)).collect();
v.shuffle(&mut thread_rng());
v.into_iter()
};

// Check that taking the k smallest elements yields n..n+min(k, m)
it::assert_equal(
i.k_smallest(k as usize),
n..n.saturating_add(min(k, m))
);
}
}

#[derive(Clone, Debug)]
struct RandIter<T: 'static + Clone + Send, R: 'static + Clone + Rng + SeedableRng + Send = StdRng> {
idx: usize,
len: usize,
rng: R,
_t: PhantomData<T>
}

impl<T: Clone + Send, R: Clone + Rng + SeedableRng + Send> Iterator for RandIter<T, R>
where Standard: Distribution<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if self.idx == self.len {
None
} else {
self.idx += 1;
Some(self.rng.gen())
}
}
}

impl<T: Clone + Send, R: Clone + Rng + SeedableRng + Send> qc::Arbitrary for RandIter<T, R> {
fn arbitrary<G: qc::Gen>(g: &mut G) -> Self {
RandIter {
idx: 0,
len: g.size(),
rng: R::seed_from_u64(g.next_u64()),
_t : PhantomData{},
}
}
}

// Check that taking the k smallest is the same as
// sorting then taking the k first elements
fn k_smallest_sort<I>(i: I, k: u16) -> ()
where
I: Iterator + Clone,
I::Item: Ord + Debug,
{
let j = i.clone();
let k = k as usize;
it::assert_equal(
i.k_smallest(k),
j.sorted().take(k)
)
}

macro_rules! generic_test {
($f:ident, $($t:ty),+) => {
$(paste::item! {
qc::quickcheck! {
fn [< $f _ $t >](i: RandIter<$t>, k: u16) -> () {
$f(i, k)
}
}
})+
};
}

generic_test!(k_smallest_sort, u8, u16, u32, u64, i8, i16, i32, i64);

#[test]
fn sorted_by_key() {
let sc = [3, 4, 1, 2].iter().cloned().sorted_by_key(|&x| x);
Expand Down Expand Up @@ -407,7 +494,6 @@ fn test_multipeek() {
assert_eq!(mp.next(), Some(5));
assert_eq!(mp.next(), None);
assert_eq!(mp.peek(), None);

}

#[test]
Expand Down

0 comments on commit 00756e0

Please sign in to comment.