Skip to content

Commit

Permalink
Fix incorrect impl of Iterator::size_hint and add impl of ExactSizeIt…
Browse files Browse the repository at this point in the history
…erator the NumPy iterator wrappers.
  • Loading branch information
adamreichold committed Mar 3, 2022
1 parent 72d967e commit 3363c8b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 13 deletions.
42 changes: 31 additions & 11 deletions src/npyiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
pub struct NpySingleIter<'py, T, I> {
iterator: ptr::NonNull<NpyIter>,
iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int,
empty: bool,
iter_size: npy_intp,
dataptr: *mut *mut c_char,
return_type: PhantomData<T>,
Expand Down Expand Up @@ -324,7 +323,6 @@ impl<'py, T, I> NpySingleIter<'py, T, I> {
iterator,
iternext,
iter_size,
empty: iter_size == 0,
dataptr,
return_type: PhantomData,
mode: PhantomData,
Expand All @@ -334,15 +332,18 @@ impl<'py, T, I> NpySingleIter<'py, T, I> {
}

fn iternext(&mut self) -> Option<*mut T> {
if self.empty {
if self.iter_size == 0 {
None
} else {
// Note: This pointer is correct and doesn't need to be updated,
// note that we're derefencing a **char into a *char casting to a *T
// and then transforming that into a reference, the value that dataptr
// points to is being updated by iternext to point to the next value.
let ret = unsafe { *self.dataptr as *mut T };
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
debug_assert_ne!(self.iter_size, 0);
self.iter_size -= 1;
debug_assert!(self.iter_size > 0 || empty);
Some(ret)
}
}
Expand All @@ -368,7 +369,13 @@ impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, Readonly> {
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.iter_size as usize, Some(self.iter_size as usize))
(self.len(), Some(self.len()))
}
}

impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, Readonly> {
fn len(&self) -> usize {
self.iter_size as usize
}
}

Expand All @@ -380,7 +387,13 @@ impl<'py, T: 'py> Iterator for NpySingleIter<'py, T, ReadWrite> {
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.iter_size as usize, Some(self.iter_size as usize))
(self.len(), Some(self.len()))
}
}

impl<'py, T: 'py> ExactSizeIterator for NpySingleIter<'py, T, ReadWrite> {
fn len(&self) -> usize {
self.iter_size as usize
}
}

Expand Down Expand Up @@ -528,7 +541,6 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> {
iterator: ptr::NonNull<NpyIter>,
iternext: unsafe extern "C" fn(*mut NpyIter) -> c_int,
empty: bool,
iter_size: npy_intp,
dataptr: *mut *mut c_char,
marker: PhantomData<(T, S)>,
Expand Down Expand Up @@ -568,7 +580,6 @@ impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> {
iterator,
iternext,
iter_size,
empty: iter_size == 0,
dataptr,
marker: PhantomData,
arrays,
Expand Down Expand Up @@ -596,7 +607,7 @@ macro_rules! impl_multi_iter {
type Item = ($($ty,)+);

fn next(&mut self) -> Option<Self::Item> {
if self.empty {
if self.iter_size == 0 {
None
} else {
// Note: This pointer is correct and doesn't need to be updated,
Expand All @@ -605,13 +616,22 @@ macro_rules! impl_multi_iter {
// points to is being updated by iternext to point to the next value.
let ($($ptr,)+) = unsafe { $expand::<T>(self.dataptr) };
let retval = Some(unsafe { $deref });
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
let empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
debug_assert_ne!(self.iter_size, 0);
self.iter_size -= 1;
debug_assert!(self.iter_size > 0 || empty);
retval
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.iter_size as usize, Some(self.iter_size as usize))
(self.len(), Some(self.len()))
}
}

impl<'py, T: 'py> ExactSizeIterator for NpyMultiIter<'py, T, $structure> {
fn len(&self) -> usize {
self.iter_size as usize
}
}
};
Expand Down
49 changes: 47 additions & 2 deletions tests/iter.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#![allow(deprecated)]

use ndarray::array;
use numpy::{NpyMultiIterBuilder, NpySingleIterBuilder, PyArray};
use pyo3::PyResult;
use numpy::{pyarray, NpyMultiIterBuilder, NpySingleIterBuilder, PyArray};
use pyo3::{PyResult, Python};

macro_rules! assert_approx_eq {
($x: expr, $y: expr) => {
Expand Down Expand Up @@ -96,3 +96,48 @@ fn multiiter_rw() -> PyResult<()> {
Ok(())
})
}

#[test]
fn single_iter_size_hint_len() {
Python::with_gil(|py| {
let arr = pyarray![py, [0, 1], [2, 3], [4, 5]];

let mut iter = NpySingleIterBuilder::readonly(arr.readonly())
.build()
.unwrap();

for len in (1..=6).rev() {
assert_eq!(iter.len(), len);
assert_eq!(iter.size_hint(), (len, Some(len)));
assert!(iter.next().is_some());
}

assert_eq!(iter.len(), 0);
assert_eq!(iter.size_hint(), (0, Some(0)));
assert!(iter.next().is_none());
});
}

#[test]
fn multi_iter_size_hint_len() {
Python::with_gil(|py| {
let arr1 = pyarray![py, [0, 1], [2, 3], [4, 5]];
let arr2 = pyarray![py, [0, 0], [0, 0], [0, 0]];

let mut iter = NpyMultiIterBuilder::new()
.add_readonly(arr1.readonly())
.add_readonly(arr2.readonly())
.build()
.unwrap();

for len in (1..=6).rev() {
assert_eq!(iter.len(), len);
assert_eq!(iter.size_hint(), (len, Some(len)));
assert!(iter.next().is_some());
}

assert_eq!(iter.len(), 0);
assert_eq!(iter.size_hint(), (0, Some(0)));
assert!(iter.next().is_none());
});
}

0 comments on commit 3363c8b

Please sign in to comment.