Skip to content

Commit

Permalink
Fix returning invalid strides and dimensions for rank zero arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Mar 22, 2022
1 parent 9ec102d commit e75190b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

- Unreleased
- Fix returning invalid slices from `PyArray::{strides,shape}` for rank zero arrays. ([#???](https://github.com/PyO3/rust-numpy/pull/???))

- v0.16.2
- Fix build on platforms where `c_char` is `u8` like Linux/AArch64. ([#296](https://github.com/PyO3/rust-numpy/pull/296))
Expand Down
9 changes: 9 additions & 0 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use pyo3::{
Python, ToPyObject,
};

use crate::cold;
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::{Element, PyArrayDescr};
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
Expand Down Expand Up @@ -314,6 +315,10 @@ impl<T, D> PyArray<T, D> {
// C API: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_STRIDES
pub fn strides(&self) -> &[isize] {
let n = self.ndim();
if n == 0 {
cold();
return &[];
}
let ptr = self.as_array_ptr();
unsafe {
let p = (*ptr).strides;
Expand All @@ -335,6 +340,10 @@ impl<T, D> PyArray<T, D> {
// C API: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DIMS
pub fn shape(&self) -> &[usize] {
let n = self.ndim();
if n == 0 {
cold();
return &[];
}
let ptr = self.as_array_ptr();
unsafe {
let p = (*ptr).dimensions as *mut usize;
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ mod doctest {
doc_comment!(include_str!("../README.md"), readme);
}

#[cold]
fn cold() {}

/// Create a [`PyArray`] with one, two or three dimensions.
///
/// This macro is backed by [`ndarray::array`].
Expand Down
11 changes: 11 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ fn tuple_as_dim() {
});
}

#[test]
fn rank_zero_array_has_invalid_strides_dimensions() {
Python::with_gil(|py| {
let arr = PyArray::<f64, _>::zeros(py, (), false);

assert_eq!(arr.ndim(), 0);
assert_eq!(arr.strides(), &[]);
assert_eq!(arr.shape(), &[]);
})
}

#[test]
fn zeros() {
Python::with_gil(|py| {
Expand Down

0 comments on commit e75190b

Please sign in to comment.