diff --git a/CHANGELOG.md b/CHANGELOG.md index 6044b090f..a883e770e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog - Unreleased + - Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274)) - Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292)) - v0.16.2 diff --git a/benches/borrow.rs b/benches/borrow.rs new file mode 100644 index 000000000..a091ddabf --- /dev/null +++ b/benches/borrow.rs @@ -0,0 +1,48 @@ +#![feature(test)] + +extern crate test; +use test::{black_box, Bencher}; + +use numpy::PyArray; +use pyo3::Python; + +#[bench] +fn initial_shared_borrow(bencher: &mut Bencher) { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + bencher.iter(|| { + let array = black_box(array); + + let _shared = array.readonly(); + }); + }); +} + +#[bench] +fn additional_shared_borrow(bencher: &mut Bencher) { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let _shared = (0..128).map(|_| array.readonly()).collect::>(); + + bencher.iter(|| { + let array = black_box(array); + + let _shared = array.readonly(); + }); + }); +} + +#[bench] +fn exclusive_borrow(bencher: &mut Bencher) { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + bencher.iter(|| { + let array = black_box(array); + + let _exclusive = array.readwrite(); + }); + }); +} diff --git a/examples/simple/src/lib.rs b/examples/simple/src/lib.rs index 2c1e2751d..c381f9d04 100644 --- a/examples/simple/src/lib.rs +++ b/examples/simple/src/lib.rs @@ -1,5 +1,7 @@ use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD}; -use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn}; +use numpy::{ + Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn, +}; use pyo3::{ pymodule, types::{PyDict, 
PyModule}, @@ -41,8 +43,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { // wrapper of `mult` #[pyfn(m)] #[pyo3(name = "mult")] - fn mult_py(a: f64, x: &PyArrayDyn) { - let x = unsafe { x.as_array_mut() }; + fn mult_py(a: f64, mut x: PyReadwriteArrayDyn) { + let x = x.as_array_mut(); mult(a, x); } diff --git a/src/array.rs b/src/array.rs index eccf0e5f1..4ee5279ca 100644 --- a/src/array.rs +++ b/src/array.rs @@ -19,14 +19,12 @@ use pyo3::{ Python, ToPyObject, }; +use crate::borrow::{PyReadonlyArray, PyReadwriteArray}; use crate::cold; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; use crate::dtype::{Element, PyArrayDescr}; use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API}; -#[allow(deprecated)] -use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite}; -use crate::readonly::PyReadonlyArray; use crate::slice_container::PySliceContainer; /// A safe, static-typed interface for @@ -195,18 +193,8 @@ impl PyArray { } #[inline(always)] - fn check_flag(&self, flag: c_int) -> bool { - unsafe { *self.as_array_ptr() }.flags & flag == flag - } - - #[inline(always)] - pub(crate) fn get_flag(&self) -> c_int { - unsafe { *self.as_array_ptr() }.flags - } - - /// Returns a temporally unwriteable reference of the array. - pub fn readonly(&self) -> PyReadonlyArray { - self.into() + pub(crate) fn check_flags(&self, flags: c_int) -> bool { + unsafe { *self.as_array_ptr() }.flags & flags != 0 } /// Returns `true` if the internal data of the array is C-style contiguous @@ -228,18 +216,17 @@ impl PyArray { /// }); /// ``` pub fn is_contiguous(&self) -> bool { - self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS) - | self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS) + self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) } /// Returns `true` if the internal data of the array is Fortran-style contiguous. 
pub fn is_fortran_contiguous(&self) -> bool { - self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS) + self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS) } /// Returns `true` if the internal data of the array is C-style contiguous. pub fn is_c_contiguous(&self) -> bool { - self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS) + self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS) } /// Get `Py` from `&PyArray`, which is the owned wrapper of PyObject. @@ -681,27 +668,61 @@ impl PyArray { /// Get the immutable reference of the specified element, with checking the passed index is valid. /// - /// Please consider the use of safe alternatives - /// ([`PyReadonlyArray::get`](../struct.PyReadonlyArray.html#method.get) - /// or [`get_owned`](#method.get_owned)) instead of this. + /// Consider using safe alternatives like [`PyReadonlyArray::get`]. + /// /// # Example + /// /// ``` /// use numpy::PyArray; - /// pyo3::Python::with_gil(|py| { + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { /// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap(); - /// assert_eq!(*unsafe { arr.get([1, 0, 3]) }.unwrap(), 11); + /// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 11); /// }); /// ``` /// /// # Safety - /// If the internal array is not readonly and can be mutated from Python code, - /// holding the slice might cause undefined behavior. + /// + /// Calling this method is undefined behaviour if the underlying array + /// is aliased mutably by other instances of `PyArray` + /// or concurrently modified by Python or other native code. #[inline(always)] pub unsafe fn get(&self, index: impl NpyIndex) -> Option<&T> { let offset = index.get_checked::(self.shape(), self.strides())?; Some(&*self.data().offset(offset)) } + /// Same as [`get`][Self::get], but returns `Option<&mut T>`. + /// + /// Consider using safe alternatives like [`PyReadwriteArray::get_mut`]. 
+ /// + /// # Example + /// + /// ``` + /// use numpy::PyArray; + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { + /// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap(); + /// unsafe { + /// *arr.get_mut([1, 0, 3]).unwrap() = 42; + /// } + /// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 42); + /// }); + /// ``` + /// + /// # Safety + /// + /// Calling this method is undefined behaviour if the underlying array + /// is aliased immutably or mutably by other instances of `PyArray` + /// or concurrently modified by Python or other native code. + #[inline(always)] + pub unsafe fn get_mut(&self, index: impl NpyIndex) -> Option<&mut T> { + let offset = index.get_checked::(self.shape(), self.strides())?; + Some(&mut *self.data().offset(offset)) + } + /// Get the immutable reference of the specified element, without checking the /// passed index is valid. /// @@ -824,16 +845,23 @@ impl PyArray { ToPyArray::to_pyarray(arr, py) } - /// Get the immutable view of the internal data of `PyArray`, as - /// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). + /// Get an immutable borrow of the NumPy array + pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> { + PyReadonlyArray::try_new(self).unwrap() + } + + /// Get a mutable borrow of the NumPy array + pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> { + PyReadwriteArray::try_new(self).unwrap() + } + + /// Returns the internal array as [`ArrayView`]. /// - /// Please consider the use of safe alternatives - /// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array) - /// or [`to_array`](#method.to_array)) instead of this. + /// See also [`PyReadonlyArray::as_array`]. /// /// # Safety - /// If the internal array is not readonly and can be mutated from Python code, - /// holding the `ArrayView` might cause undefined behavior. + /// + /// The existence of an exclusive reference to the internal data, e.g.
`&mut [T]` or `ArrayViewMut`, implies undefined behavior. pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> { let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); let mut res = ArrayView::from_shape_ptr(shape, ptr); @@ -841,11 +869,13 @@ impl PyArray { res } - /// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array). + /// Returns the internal array as [`ArrayViewMut`]. + /// + /// See also [`PyReadwriteArray::as_array_mut`]. /// /// # Safety - /// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`), - /// it might cause undefined behavior. + /// + /// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior. pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> { let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr(); let mut res = ArrayViewMut::from_shape_ptr(shape, ptr); @@ -921,7 +951,7 @@ impl PyArray { /// /// let pyarray = PyArray::from_owned_object_array(py, array); /// - /// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance_of::().unwrap()); + /// assert!(pyarray.readonly().as_array().get(0).unwrap().as_ref(py).is_instance_of::().unwrap()); /// }); /// ``` pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array, D>) -> &'py Self { @@ -1043,21 +1073,6 @@ impl PyArray { self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER) } - /// Iterates all elements of this array. - /// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more. - /// - /// # Safety - /// - /// The iterator will produce mutable references into the array which must not be - /// aliased by other references for the life time of the iterator. - #[deprecated( - note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter_mut` instead." 
- )] - #[allow(deprecated)] - pub unsafe fn iter<'py>(&'py self) -> PyResult> { - NpySingleIterBuilder::readwrite(self).build() - } - fn resize_( &self, py: Python, diff --git a/src/borrow.rs b/src/borrow.rs new file mode 100644 index 000000000..76a13e364 --- /dev/null +++ b/src/borrow.rs @@ -0,0 +1,589 @@ +//! Types to safely create references into NumPy arrays +//! +//! It is assumed that unchecked code - which includes unsafe Rust and Python - is validated by its author +//! which together with the dynamic borrow checking performed by this crate ensures that +//! safe Rust code cannot cause undefined behaviour by creating references into NumPy arrays. +//! +//! With these borrows established, [references to individual elements][PyReadonlyArray::get] or [reference-based views of whole array][PyReadonlyArray::as_array] +//! can be created safely. These are then the starting point for algorithms iterating over and operating on the elements of the array. +//! +//! # Examples +//! +//! The first example shows that dynamic borrow checking works to constrain +//! both what safe Rust code can invoke and how it is invoked. +//! +//! ```rust +//! # use std::panic::{catch_unwind, AssertUnwindSafe}; +//! # +//! use numpy::PyArray1; +//! use ndarray::Zip; +//! use pyo3::Python; +//! +//! fn add(x: &PyArray1, y: &PyArray1, z: &PyArray1) { +//! let x1 = x.readonly(); +//! let y1 = y.readonly(); +//! let mut z1 = z.readwrite(); +//! +//! let x2 = x1.as_array(); +//! let y2 = y1.as_array(); +//! let z2 = z1.as_array_mut(); +//! +//! Zip::from(x2) +//! .and(y2) +//! .and(z2) +//! .for_each(|x3, y3, z3| *z3 = x3 + y3); +//! +//! // Will fail at runtime due to conflict with `x1`. +//! let res = catch_unwind(AssertUnwindSafe(|| { +//! let _x4 = x.readwrite(); +//! })); +//! assert!(res.is_err()); +//! } +//! +//! Python::with_gil(|py| { +//! let x = PyArray1::::zeros(py, 42, false); +//! let y = PyArray1::::zeros(py, 42, false); +//! let z = PyArray1::::zeros(py, 42, false); +//! 
+//! // Will work as the three arrays are distinct. +//! add(x, y, z); +//! +//! // Will work as `x1` and `y1` are compatible borrows. +//! add(x, x, z); +//! +//! // Will fail at runtime due to conflict between `y1` and `z1`. +//! let res = catch_unwind(AssertUnwindSafe(|| { +//! add(x, y, y); +//! })); +//! assert!(res.is_err()); +//! }); +//! ``` +//! +//! The second example shows that non-overlapping and interleaved views which do not alias +//! are currently not supported due to over-approximating which borrows are in conflict. +//! +//! ```rust +//! # use std::panic::{catch_unwind, AssertUnwindSafe}; +//! # +//! use numpy::PyArray1; +//! use pyo3::{types::IntoPyDict, Python}; +//! +//! Python::with_gil(|py| { +//! let array = PyArray1::arange(py, 0.0, 10.0, 1.0); +//! let locals = [("array", array)].into_py_dict(py); +//! +//! let view1 = py.eval("array[:5]", None, Some(locals)).unwrap().downcast::>().unwrap(); +//! let view2 = py.eval("array[5:]", None, Some(locals)).unwrap().downcast::>().unwrap(); +//! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::>().unwrap(); +//! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::>().unwrap(); +//! +//! // Will fail at runtime even though `view1` and `view2` +//! // do not overlap as they are based on the same array. +//! let res = catch_unwind(AssertUnwindSafe(|| { +//! let _view1 = view1.readwrite(); +//! let _view2 = view2.readwrite(); +//! })); +//! assert!(res.is_err()); +//! +//! // Will fail at runtime even though `view3` and `view4` +//! // interleave as they are based on the same array. +//! let res = catch_unwind(AssertUnwindSafe(|| { +//! let _view3 = view3.readwrite(); +//! let _view4 = view4.readwrite(); +//! })); +//! assert!(res.is_err()); +//! }); +//! ``` +//! +//! # Rationale +//! +//! Rust references require aliasing discipline to be maintained, i.e. there must always +//! 
exist only a single mutable (aka exclusive) reference or multiple immutable (aka shared) references +//! for each object, otherwise the program contains undefined behaviour. +//! +//! The aim of this module is to ensure that safe Rust code is unable to violate these requirements on its own. +//! We cannot prevent unchecked code - this includes unsafe Rust, Python or other native code like C or Fortran - +//! from violating them. Therefore the responsibility to avoid this lies with the author of that code instead of the compiler. +//! However, assuming that the unchecked code is correct, we can ensure that safe Rust is unable to introduce mistakes +//! into an otherwise correct program by dynamically checking which arrays are currently borrowed and in what manner. +//! +//! This means that we follow the [base object chain][base] of each array to the original allocation backing it and +//! track which parts of that allocation are covered by the array and thereby ensure that only a single read-write array +//! or multiple read-only arrays overlapping with that region are borrowed at any time. +//! +//! In contrast to Rust references, the mere existence of Python references or raw pointers is not an issue +//! because these values are not assumed to follow aliasing discipline by the Rust compiler. +//! +//! This cannot prevent unchecked code from concurrently modifying an array via callbacks or using multiple threads, +//! but that would lead to incorrect results even if the code that is interfered with is implemented in another language +//! which does not require aliasing discipline. +//! +//! Concerning multi-threading in particular: While the GIL needs to be acquired to create borrows, they are not bound to the GIL +//! and will stay active after the GIL is released, for example by calling [`allow_threads`][pyo3::Python::allow_threads]. +//! Borrows also do not provide synchronization, i.e. multiple threads borrowing the same array will lead to runtime panics, +//! 
it will not block those threads until already active borrows are released. +//! +//! In summary, this crate takes the position that all unchecked code - unsafe Rust, Python, C, Fortran, etc. - must be checked for correctness by its author. +//! Safe Rust code can then rely on this correctness, but should not be able to introduce memory safety issues on its own. Additionally, dynamic borrow checking +//! can catch _some_ mistakes introduced by unchecked code, e.g. Python calling a function with the same array as an input and as an output argument. +//! +//! # Limitations +//! +//! Note that the current implementation of this is an over-approximation: It will consider all borrows potentially conflicting +//! if the initial arrays have the same object at the end of their [base object chain][base]. +//! For example, creating two views of the same underlying array by slicing will always yield potentially conflicting borrows +//! even if the slice indices are chosen so that the two views do not actually share any elements by splitting the array into +//! non-overlapping parts or by interleaving along one of its axes. +//! +//! This does limit the set of programs that can be written using safe Rust in a way similar to rustc itself +//! which ensures that all accepted programs are memory safe but does not necessarily accept all memory safe programs. +//! The plan is to refine this checking to correctly handle more involved cases like non-overlapping and interleaved +//! views into the same array and until then the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch. +//! +//! 
[base]: https://numpy.org/doc/stable/reference/c-api/types-and-structures.html#c.NPY_AO.base +#![deny(missing_docs)] + +use std::cell::UnsafeCell; +use std::collections::hash_map::{Entry, HashMap}; +use std::ops::Deref; + +use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; +use pyo3::{FromPyObject, PyAny, PyResult}; + +use crate::array::PyArray; +use crate::cold; +use crate::convert::NpyIndex; +use crate::dtype::Element; +use crate::error::{BorrowError, NotContiguousError}; +use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE}; + +struct BorrowFlags(UnsafeCell>>); + +unsafe impl Sync for BorrowFlags {} + +impl BorrowFlags { + const fn new() -> Self { + Self(UnsafeCell::new(None)) + } + + #[allow(clippy::mut_from_ref)] + unsafe fn get(&self) -> &mut HashMap { + (*self.0.get()).get_or_insert_with(HashMap::new) + } +} + +static BORROW_FLAGS: BorrowFlags = BorrowFlags::new(); + +/// Read-only borrow of an array. +/// +/// An instance of this type ensures that there are no instances of [`PyReadwriteArray`], +/// i.e. that only shared references into the interior of the array can be created safely. +/// +/// See the [module-level documentation](self) for more. +pub struct PyReadonlyArray<'py, T, D>(&'py PyArray); + +/// Read-only borrow of a one-dimensional array. +pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>; + +/// Read-only borrow of a two-dimensional array. +pub type PyReadonlyArray2<'py, T> = PyReadonlyArray<'py, T, Ix2>; + +/// Read-only borrow of a three-dimensional array. +pub type PyReadonlyArray3<'py, T> = PyReadonlyArray<'py, T, Ix3>; + +/// Read-only borrow of a four-dimensional array. +pub type PyReadonlyArray4<'py, T> = PyReadonlyArray<'py, T, Ix4>; + +/// Read-only borrow of a five-dimensional array. +pub type PyReadonlyArray5<'py, T> = PyReadonlyArray<'py, T, Ix5>; + +/// Read-only borrow of a six-dimensional array. 
+pub type PyReadonlyArray6<'py, T> = PyReadonlyArray<'py, T, Ix6>; + +/// Read-only borrow of an array whose dimensionality is determined at runtime. +pub type PyReadonlyArrayDyn<'py, T> = PyReadonlyArray<'py, T, IxDyn>; + +impl<'py, T, D> Deref for PyReadonlyArray<'py, T, D> { + type Target = PyArray; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadonlyArray<'py, T, D> { + fn extract(obj: &'py PyAny) -> PyResult { + let array: &'py PyArray = obj.extract()?; + Ok(array.readonly()) + } +} + +impl<'py, T, D> PyReadonlyArray<'py, T, D> +where + T: Element, + D: Dimension, +{ + pub(crate) fn try_new(array: &'py PyArray) -> Result { + let address = base_address(array); + + // SAFETY: Access to a `&'py PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + match borrow_flags.entry(address) { + Entry::Occupied(entry) => { + let readers = entry.into_mut(); + + let new_readers = readers.wrapping_add(1); + + if new_readers <= 0 { + cold(); + return Err(BorrowError::AlreadyBorrowed); + } + + *readers = new_readers; + } + Entry::Vacant(entry) => { + entry.insert(1); + } + } + + Ok(Self(array)) + } + + /// Provides an immutable array view of the interior of the NumPy array. + #[inline(always)] + pub fn as_array(&self) -> ArrayView { + // SAFETY: Global borrow flags ensure aliasing discipline. + unsafe { self.0.as_array() } + } + + /// Provide an immutable slice view of the interior of the NumPy array if it is contiguous. + #[inline(always)] + pub fn as_slice(&self) -> Result<&[T], NotContiguousError> { + // SAFETY: Global borrow flags ensure aliasing discipline. + unsafe { self.0.as_slice() } + } + + /// Provide an immutable reference to an element of the NumPy array if the index is within bounds. 
+ #[inline(always)] + pub fn get(&self, index: I) -> Option<&T> + where + I: NpyIndex, + { + unsafe { self.0.get(index) } + } +} + +impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> { + fn drop(&mut self) { + let address = base_address(self.0); + + // SAFETY: Access to a `&'py PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + let readers = borrow_flags.get_mut(&address).unwrap(); + + *readers -= 1; + + if *readers == 0 { + borrow_flags.remove(&address).unwrap(); + } + } +} + +/// Read-write borrow of an array. +/// +/// An instance of this type ensures that there are no instances of [`PyReadonlyArray`] and no other instances of [`PyReadwriteArray`], +/// i.e. that only a single exclusive reference into the interior of the array can be created safely. +/// +/// See the [module-level documentation](self) for more. +pub struct PyReadwriteArray<'py, T, D>(&'py PyArray); + +/// Read-write borrow of a one-dimensional array. +pub type PyReadwriteArray1<'py, T> = PyReadwriteArray<'py, T, Ix1>; + +/// Read-write borrow of a two-dimensional array. +pub type PyReadwriteArray2<'py, T> = PyReadwriteArray<'py, T, Ix2>; + +/// Read-write borrow of a three-dimensional array. +pub type PyReadwriteArray3<'py, T> = PyReadwriteArray<'py, T, Ix3>; + +/// Read-write borrow of a four-dimensional array. +pub type PyReadwriteArray4<'py, T> = PyReadwriteArray<'py, T, Ix4>; + +/// Read-write borrow of a five-dimensional array. +pub type PyReadwriteArray5<'py, T> = PyReadwriteArray<'py, T, Ix5>; + +/// Read-write borrow of a six-dimensional array. +pub type PyReadwriteArray6<'py, T> = PyReadwriteArray<'py, T, Ix6>; + +/// Read-write borrow of an array whose dimensionality is determined at runtime. 
+pub type PyReadwriteArrayDyn<'py, T> = PyReadwriteArray<'py, T, IxDyn>; + +impl<'py, T, D> Deref for PyReadwriteArray<'py, T, D> { + type Target = PyReadonlyArray<'py, T, D>; + + fn deref(&self) -> &Self::Target { + // SAFETY: Exclusive references decay implicitly into shared references. + unsafe { &*(self as *const Self as *const Self::Target) } + } +} + +impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadwriteArray<'py, T, D> { + fn extract(obj: &'py PyAny) -> PyResult { + let array: &'py PyArray = obj.extract()?; + Ok(array.readwrite()) + } +} + +impl<'py, T, D> PyReadwriteArray<'py, T, D> +where + T: Element, + D: Dimension, +{ + pub(crate) fn try_new(array: &'py PyArray) -> Result { + if !array.check_flags(NPY_ARRAY_WRITEABLE) { + return Err(BorrowError::NotWriteable); + } + + let address = base_address(array); + + // SAFETY: Access to a `&'py PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + match borrow_flags.entry(address) { + Entry::Occupied(entry) => { + let writers = entry.into_mut(); + + if *writers != 0 { + cold(); + return Err(BorrowError::AlreadyBorrowed); + } + + *writers = -1; + } + Entry::Vacant(entry) => { + entry.insert(-1); + } + } + + Ok(Self(array)) + } + + /// Provides a mutable array view of the interior of the NumPy array. + #[inline(always)] + pub fn as_array_mut(&mut self) -> ArrayViewMut { + // SAFETY: Global borrow flags ensure aliasing discipline. + unsafe { self.0.as_array_mut() } + } + + /// Provide a mutable slice view of the interior of the NumPy array if it is contiguous. + #[inline(always)] + pub fn as_slice_mut(&mut self) -> Result<&mut [T], NotContiguousError> { + // SAFETY: Global borrow flags ensure aliasing discipline. + unsafe { self.0.as_slice_mut() } + } + + /// Provide a mutable reference to an element of the NumPy array if the index is within bounds.
+ #[inline(always)] + pub fn get_mut(&mut self, index: I) -> Option<&mut T> + where + I: NpyIndex, + { + unsafe { self.0.get_mut(index) } + } +} + +impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> { + fn drop(&mut self) { + let address = base_address(self.0); + + // SAFETY: Access to a `&'py PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + borrow_flags.remove(&address).unwrap(); + } +} + +// FIXME(adamreichold): This is a coarse approximation and needs to be refined, +// i.e. borrows of non-overlapping views into the same base should not be considered conflicting. +fn base_address(array: &PyArray) -> usize { + let py = array.py(); + let mut array = array.as_array_ptr(); + + loop { + let base = unsafe { (*array).base }; + + if base.is_null() { + return array as usize; + } else if unsafe { npyffi::PyArray_Check(py, base) } != 0 { + array = base as *mut PyArrayObject; + } else { + return base as usize; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use ndarray::Array; + use pyo3::{types::IntoPyDict, Python}; + + use crate::array::{PyArray1, PyArray2}; + use crate::convert::IntoPyArray; + + #[test] + fn without_base_object() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let base = unsafe { (*array.as_array_ptr()).base }; + assert!(base.is_null()); + + let base_address = base_address(array); + assert_eq!(base_address, array as *const _ as usize); + }); + } + + #[test] + fn with_base_object() { + Python::with_gil(|py| { + let array = Array::::zeros((1, 2, 3)).into_pyarray(py); + + let base = unsafe { (*array.as_array_ptr()).base }; + assert!(!base.is_null()); + + let base_address = base_address(array); + assert_ne!(base_address, array as *const _ as usize); + assert_eq!(base_address, base as usize); + }); + } + + #[test] + fn view_without_base_object() { + Python::with_gil(|py| { + let array = 
PyArray::::zeros(py, (1, 2, 3), false); + + let locals = [("array", array)].into_py_dict(py); + let view = py + .eval("array[:,:,0]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_ne!(view as *const _ as usize, array as *const _ as usize); + + let base = unsafe { (*view.as_array_ptr()).base }; + assert_eq!(base as usize, array as *const _ as usize); + + let base_address = base_address(view); + assert_ne!(base_address, view as *const _ as usize); + assert_eq!(base_address, base as usize); + }); + } + + #[test] + fn view_with_base_object() { + Python::with_gil(|py| { + let array = Array::::zeros((1, 2, 3)).into_pyarray(py); + + let locals = [("array", array)].into_py_dict(py); + let view = py + .eval("array[:,:,0]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_ne!(view as *const _ as usize, array as *const _ as usize); + + let base = unsafe { (*view.as_array_ptr()).base }; + assert_eq!(base as usize, array as *const _ as usize); + + let base = unsafe { (*array.as_array_ptr()).base }; + assert!(!base.is_null()); + + let base_address = base_address(view); + assert_ne!(base_address, view as *const _ as usize); + assert_ne!(base_address, array as *const _ as usize); + assert_eq!(base_address, base as usize); + }); + } + + #[test] + fn view_of_view_without_base_object() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let locals = [("array", array)].into_py_dict(py); + let view1 = py + .eval("array[:,:,0]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_ne!(view1 as *const _ as usize, array as *const _ as usize); + + let locals = [("view1", view1)].into_py_dict(py); + let view2 = py + .eval("view1[:,0]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_ne!(view2 as *const _ as usize, array as *const _ as usize); + assert_ne!(view2 as *const _ as usize, view1 as *const _ as usize); + + let base = unsafe { (*view2.as_array_ptr()).base }; + 
assert_eq!(base as usize, array as *const _ as usize); + + let base = unsafe { (*view1.as_array_ptr()).base }; + assert_eq!(base as usize, array as *const _ as usize); + + let base_address = base_address(view2); + assert_ne!(base_address, view2 as *const _ as usize); + assert_ne!(base_address, view1 as *const _ as usize); + assert_eq!(base_address, base as usize); + }); + } + + #[test] + fn view_of_view_with_base_object() { + Python::with_gil(|py| { + let array = Array::::zeros((1, 2, 3)).into_pyarray(py); + + let locals = [("array", array)].into_py_dict(py); + let view1 = py + .eval("array[:,:,0]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_ne!(view1 as *const _ as usize, array as *const _ as usize); + + let locals = [("view1", view1)].into_py_dict(py); + let view2 = py + .eval("view1[:,0]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_ne!(view2 as *const _ as usize, array as *const _ as usize); + assert_ne!(view2 as *const _ as usize, view1 as *const _ as usize); + + let base = unsafe { (*view2.as_array_ptr()).base }; + assert_eq!(base as usize, array as *const _ as usize); + + let base = unsafe { (*view1.as_array_ptr()).base }; + assert_eq!(base as usize, array as *const _ as usize); + + let base = unsafe { (*array.as_array_ptr()).base }; + assert!(!base.is_null()); + + let base_address = base_address(view2); + assert_ne!(base_address, view2 as *const _ as usize); + assert_ne!(base_address, view1 as *const _ as usize); + assert_ne!(base_address, array as *const _ as usize); + assert_eq!(base_address, base as usize); + }); + } +} diff --git a/src/error.rs b/src/error.rs index 19200e213..6844861df 100644 --- a/src/error.rs +++ b/src/error.rs @@ -115,3 +115,22 @@ impl fmt::Display for NotContiguousError { } impl_pyerr!(NotContiguousError); + +/// Indicates why borrowing an array failed.
+#[derive(Debug)] +#[non_exhaustive] +pub enum BorrowError { + AlreadyBorrowed, + NotWriteable, +} + +impl fmt::Display for BorrowError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::AlreadyBorrowed => write!(f, "The given array is already borrowed"), + Self::NotWriteable => write!(f, "The given array is not writeable"), + } + } +} + +impl_pyerr!(BorrowError); diff --git a/src/lib.rs b/src/lib.rs index fa3191703..1a6e8aedf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,12 +36,12 @@ #![allow(clippy::needless_lifetimes)] pub mod array; +pub mod borrow; pub mod convert; mod dtype; mod error; pub mod npyffi; pub mod npyiter; -mod readonly; mod slice_container; mod sum_products; @@ -52,6 +52,12 @@ pub use crate::array::{ get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, PyArray6, PyArrayDyn, }; +pub use crate::borrow::{ + PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4, + PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn, PyReadwriteArray, PyReadwriteArray1, + PyReadwriteArray2, PyReadwriteArray3, PyReadwriteArray4, PyReadwriteArray5, PyReadwriteArray6, + PyReadwriteArrayDyn, +}; pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr}; pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; @@ -60,10 +66,6 @@ pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API}; pub use crate::npyiter::{ IterMode, NpyIterFlag, NpyMultiIter, NpyMultiIterBuilder, NpySingleIter, NpySingleIterBuilder, }; -pub use crate::readonly::{ - PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4, - PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn, -}; pub use crate::sum_products::{dot, einsum_impl, inner}; pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; diff --git a/src/npyiter.rs b/src/npyiter.rs index 
1fc4eb9e9..2dfd809af 100644 --- a/src/npyiter.rs +++ b/src/npyiter.rs @@ -24,19 +24,18 @@ use std::ptr; use ndarray::Dimension; use pyo3::{PyErr, PyNativeType, PyResult, Python}; -use crate::array::{PyArray, PyArrayDyn}; +use crate::array::PyArrayDyn; +use crate::borrow::{PyReadonlyArray, PyReadwriteArray}; use crate::dtype::Element; use crate::npyffi::{ array::PY_ARRAY_API, npy_intp, npy_uint32, - objects::{NpyIter, PyArrayObject}, + objects::NpyIter, types::{NPY_CASTING, NPY_ORDER}, - NPY_ARRAY_WRITEABLE, NPY_ITER_BUFFERED, NPY_ITER_COMMON_DTYPE, NPY_ITER_COPY_IF_OVERLAP, - NPY_ITER_DELAY_BUFALLOC, NPY_ITER_DONT_NEGATE_STRIDES, NPY_ITER_GROWINNER, NPY_ITER_RANGED, - NPY_ITER_READONLY, NPY_ITER_READWRITE, NPY_ITER_REDUCE_OK, NPY_ITER_REFS_OK, - NPY_ITER_ZEROSIZE_OK, + NPY_ITER_BUFFERED, NPY_ITER_COMMON_DTYPE, NPY_ITER_COPY_IF_OVERLAP, NPY_ITER_DELAY_BUFALLOC, + NPY_ITER_DONT_NEGATE_STRIDES, NPY_ITER_GROWINNER, NPY_ITER_RANGED, NPY_ITER_READONLY, + NPY_ITER_READWRITE, NPY_ITER_REDUCE_OK, NPY_ITER_REFS_OK, NPY_ITER_ZEROSIZE_OK, }; -use crate::readonly::PyReadonlyArray; use crate::sealed::Sealed; /// Flags for constructing an iterator. @@ -159,36 +158,26 @@ pub struct NpySingleIterBuilder<'py, T, I: IterMode> { flags: npy_uint32, array: &'py PyArrayDyn, mode: PhantomData, - was_writable: bool, } impl<'py, T: Element> NpySingleIterBuilder<'py, T, Readonly> { - /// Create a new builder for a readonly iterator. - pub fn readonly(array: PyReadonlyArray<'py, T, D>) -> Self { - let (array, was_writable) = array.destruct(); - + /// Makes a new builder for a readonly iterator. + pub fn readonly(array: &'py PyReadonlyArray<'_, T, D>) -> Self { Self { flags: NPY_ITER_READONLY, array: array.to_dyn(), mode: PhantomData, - was_writable, } } } impl<'py, T: Element> NpySingleIterBuilder<'py, T, ReadWrite> { - /// Create a new builder for a writable iterator. 
- /// - /// # Safety - /// - /// The iterator will produce mutable references into the array which must not be - /// aliased by other references for the life time of the iterator. - pub unsafe fn readwrite(array: &'py PyArray) -> Self { + /// Makes a new builder for a writable iterator. + pub fn readwrite(array: &'py mut PyReadwriteArray<'_, T, D>) -> Self { Self { flags: NPY_ITER_READWRITE, array: array.to_dyn(), mode: PhantomData, - was_writable: false, } } } @@ -217,13 +206,7 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> { ) }; - let readonly_array_ptr = if self.was_writable { - Some(array_ptr) - } else { - None - }; - - NpySingleIter::new(iter_ptr, readonly_array_ptr, py) + NpySingleIter::new(iter_ptr, py) } } @@ -242,8 +225,9 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> { /// /// Python::with_gil(|py| { /// let array = PyArray::arange(py, 0, 10, 1); +/// let mut array = array.readwrite(); /// -/// let iter = unsafe { NpySingleIterBuilder::readwrite(array).build().unwrap() }; +/// let iter = NpySingleIterBuilder::readwrite(&mut array).build().unwrap(); /// /// for (i, elem) in iter.enumerate() { /// assert_eq!(*elem, i as i64); @@ -261,8 +245,9 @@ impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> { /// /// Python::with_gil(|py| { /// let array = PyArray::arange(py, 0, 1, 10); +/// let array = array.readonly(); /// -/// let iter = NpySingleIterBuilder::readonly(array.readonly()).build().unwrap(); +/// let iter = NpySingleIterBuilder::readonly(&array).build().unwrap(); /// /// for (i, elem) in iter.enumerate() { /// assert_eq!(*elem, i as i64); @@ -276,16 +261,11 @@ pub struct NpySingleIter<'py, T, I> { dataptr: *mut *mut c_char, return_type: PhantomData, mode: PhantomData, - readonly_array_ptr: Option<*mut PyArrayObject>, py: Python<'py>, } impl<'py, T, I> NpySingleIter<'py, T, I> { - fn new( - iterator: *mut NpyIter, - readonly_array_ptr: Option<*mut PyArrayObject>, - py: Python<'py>, - ) -> 
PyResult { + fn new(iterator: *mut NpyIter, py: Python<'py>) -> PyResult { let mut iterator = match ptr::NonNull::new(iterator) { Some(iter) => iter, None => return Err(PyErr::fetch(py)), @@ -313,7 +293,6 @@ impl<'py, T, I> NpySingleIter<'py, T, I> { dataptr, return_type: PhantomData, mode: PhantomData, - readonly_array_ptr, py, }) } @@ -339,12 +318,6 @@ impl<'py, T, I> NpySingleIter<'py, T, I> { impl<'py, T, I> Drop for NpySingleIter<'py, T, I> { fn drop(&mut self) { let _success = unsafe { PY_ARRAY_API.NpyIter_Deallocate(self.py, self.iterator.as_mut()) }; - - if let Some(ptr) = self.readonly_array_ptr { - unsafe { - (*ptr).flags |= NPY_ARRAY_WRITEABLE; - } - } } } @@ -389,7 +362,6 @@ pub struct NpyMultiIterBuilder<'py, T, S: MultiIterMode> { flags: npy_uint32, arrays: Vec<&'py PyArrayDyn>, structure: PhantomData, - was_writables: Vec, } impl<'py, T: Element> Default for NpyMultiIterBuilder<'py, T, ()> { @@ -405,7 +377,6 @@ impl<'py, T: Element> NpyMultiIterBuilder<'py, T, ()> { flags: 0, arrays: Vec::new(), structure: PhantomData, - was_writables: Vec::new(), } } @@ -421,38 +392,25 @@ impl<'py, T: Element, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> { /// Add a readonly array to the resulting iterator. pub fn add_readonly( mut self, - array: PyReadonlyArray<'py, T, D>, + array: &'py PyReadonlyArray<'_, T, D>, ) -> NpyMultiIterBuilder<'py, T, RO> { - let (array, was_writable) = array.destruct(); - self.arrays.push(array.to_dyn()); - self.was_writables.push(was_writable); - NpyMultiIterBuilder { flags: self.flags, arrays: self.arrays, - was_writables: self.was_writables, structure: PhantomData, } } /// Adds a writable array to the resulting iterator. - /// - /// # Safety - /// - /// The iterator will produce mutable references into the array which must not be - /// aliased by other references for the life time of the iterator. 
- pub unsafe fn add_readwrite( + pub fn add_readwrite( mut self, - array: &'py PyArray, + array: &'py mut PyReadwriteArray<'_, T, D>, ) -> NpyMultiIterBuilder<'py, T, RW> { self.arrays.push(array.to_dyn()); - self.was_writables.push(false); - NpyMultiIterBuilder { flags: self.flags, arrays: self.arrays, - was_writables: self.was_writables, structure: PhantomData, } } @@ -461,12 +419,7 @@ impl<'py, T: Element, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> { impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T, S> { /// Creates an iterator from this builder. pub fn build(self) -> PyResult> { - let Self { - flags, - arrays, - was_writables, - .. - } = self; + let Self { flags, arrays, .. } = self; debug_assert!(arrays.len() <= i32::MAX as usize); debug_assert!(2 <= arrays.len()); @@ -490,7 +443,7 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T ) }; - NpyMultiIter::new(iter_ptr, arrays, was_writables, py) + NpyMultiIter::new(iter_ptr, py) } } @@ -507,17 +460,18 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T /// /// Python::with_gil(|py| { /// let array1 = numpy::PyArray::arange(py, 0, 10, 1); +/// let array1 = array1.readonly(); /// let array2 = numpy::PyArray::arange(py, 10, 20, 1); +/// let mut array2 = array2.readwrite(); /// let array3 = numpy::PyArray::arange(py, 10, 30, 2); +/// let array3 = array3.readonly(); /// -/// let iter = unsafe { -/// NpyMultiIterBuilder::new() -/// .add_readonly(array1.readonly()) -/// .add_readwrite(array2) -/// .add_readonly(array3.readonly()) +/// let iter = NpyMultiIterBuilder::new() +/// .add_readonly(&array1) +/// .add_readwrite(&mut array2) +/// .add_readonly(&array3) /// .build() -/// .unwrap() -/// }; +/// .unwrap(); /// /// for (i, j, k) in iter { /// assert_eq!(*i + *j, *k); @@ -531,18 +485,11 @@ pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> { iter_size: npy_intp, dataptr: *mut *mut c_char, marker: 
PhantomData<(T, S)>, - arrays: Vec<*mut PyArrayObject>, - was_writables: Vec, py: Python<'py>, } impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> { - fn new( - iterator: *mut NpyIter, - arrays: Vec<*mut PyArrayObject>, - was_writables: Vec, - py: Python<'py>, - ) -> PyResult { + fn new(iterator: *mut NpyIter, py: Python<'py>) -> PyResult { let mut iterator = match ptr::NonNull::new(iterator) { Some(ptr) => ptr, None => return Err(PyErr::fetch(py)), @@ -569,8 +516,6 @@ impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> { iter_size, dataptr, marker: PhantomData, - arrays, - was_writables, py, }) } @@ -579,12 +524,6 @@ impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> { impl<'py, T, S: MultiIterModeWithManyArrays> Drop for NpyMultiIter<'py, T, S> { fn drop(&mut self) { let _success = unsafe { PY_ARRAY_API.NpyIter_Deallocate(self.py, self.iterator.as_mut()) }; - - for (array_ptr, &was_writable) in self.arrays.iter().zip(self.was_writables.iter()) { - if was_writable { - unsafe { (**array_ptr).flags |= NPY_ARRAY_WRITEABLE }; - } - } } } diff --git a/src/readonly.rs b/src/readonly.rs deleted file mode 100644 index cce7d9f7a..000000000 --- a/src/readonly.rs +++ /dev/null @@ -1,224 +0,0 @@ -//! Readonly arrays -use ndarray::{ArrayView, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; -use pyo3::{prelude::*, types::PyAny, AsPyPointer}; - -use crate::npyffi::NPY_ARRAY_WRITEABLE; -#[allow(deprecated)] -use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, Readonly}; -use crate::{Element, NotContiguousError, NpyIndex, PyArray}; - -/// Readonly reference of [`PyArray`](../array/struct.PyArray.html). -/// -/// This struct ensures that the internal array is not writeable while holding `PyReadonlyArray`. -/// We use a simple trick for this: modifying the internal flag of the array when creating -/// `PyReadonlyArray` and recover the original flag when it drops. 
-/// -/// So, importantly, it does not recover the original flag when it does not drop -/// (e.g., by the use of `IntoPy::intopy` or `std::mem::forget`) -/// and then the internal array remains readonly. -/// -/// # Example -/// In this example, we get a 'temporal' readonly array and the internal array -/// becomes writeble again after it drops. -/// ``` -/// use numpy::{PyArray, npyffi::NPY_ARRAY_WRITEABLE}; -/// pyo3::Python::with_gil(|py| { -/// let py_array = PyArray::arange(py, 0, 4, 1).reshape([2, 2]).unwrap(); -/// { -/// let readonly = py_array.readonly(); -/// // The internal array is not writeable now. -/// pyo3::py_run!(py, py_array, "assert not py_array.flags['WRITEABLE']"); -/// } -/// // After the `readonly` drops, the internal array gets writeable again. -/// pyo3::py_run!(py, py_array, "assert py_array.flags['WRITEABLE']"); -/// }); -/// ``` -/// However, if we convert the `PyReadonlyArray` directly into `PyObject`, -/// the internal array remains readonly. -/// ``` -/// use numpy::{PyArray, npyffi::NPY_ARRAY_WRITEABLE}; -/// use pyo3::{IntoPy, PyObject, Python}; -/// pyo3::Python::with_gil(|py| { -/// let py_array = PyArray::arange(py, 0, 4, 1).reshape([2, 2]).unwrap(); -/// let obj: PyObject = { -/// let readonly = py_array.readonly(); -/// // The internal array is not writeable now. -/// pyo3::py_run!(py, py_array, "assert not py_array.flags['WRITEABLE']"); -/// readonly.into_py(py) -/// }; -/// // The internal array remains readonly. -/// pyo3::py_run!(py, py_array, "assert py_array.flags['WRITEABLE']"); -/// }); -/// ``` -pub struct PyReadonlyArray<'py, T, D> { - array: &'py PyArray, - was_writeable: bool, -} - -impl<'py, T: Element, D: Dimension> PyReadonlyArray<'py, T, D> { - /// Returns the immutable view of the internal data of `PyArray` as slice. - /// - /// Returns `ErrorKind::NotContiguous` if the internal array is not contiguous. 
- /// # Example - /// ``` - /// use numpy::{PyArray, PyArray1}; - /// use pyo3::types::IntoPyDict; - /// pyo3::Python::with_gil(|py| { - /// let py_array = PyArray::arange(py, 0, 4, 1).reshape([2, 2]).unwrap(); - /// let readonly = py_array.readonly(); - /// assert_eq!(readonly.as_slice().unwrap(), &[0, 1, 2, 3]); - /// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py); - /// let not_contiguous: &PyArray1 = py - /// .eval("np.arange(10, dtype='int32')[::2]", Some(locals), None) - /// .unwrap() - /// .downcast() - /// .unwrap(); - /// assert!(not_contiguous.readonly().as_slice().is_err()); - /// }); - /// ``` - pub fn as_slice(&self) -> Result<&[T], NotContiguousError> { - unsafe { self.array.as_slice() } - } - - /// Get the immutable view of the internal data of `PyArray`, as - /// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). - /// - /// # Example - /// ``` - /// # #[macro_use] extern crate ndarray; - /// use numpy::PyArray; - /// pyo3::Python::with_gil(|py| { - /// let array = PyArray::arange(py, 0, 4, 1).reshape([2, 2]).unwrap(); - /// let readonly = array.readonly(); - /// assert_eq!(readonly.as_array(), array![[0, 1], [2, 3]]); - /// }); - /// ``` - pub fn as_array(&self) -> ArrayView<'_, T, D> { - unsafe { self.array.as_array() } - } - - /// Get an immutable reference of the specified element, with checking the passed index is valid. - /// - /// See [NpyIndex](../convert/trait.NpyIndex.html) for what types you can use as index. - /// - /// If you pass an invalid index to this function, it returns `None`. - /// - /// # Example - /// ``` - /// use numpy::PyArray; - /// pyo3::Python::with_gil(|py| { - /// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap().readonly(); - /// assert_eq!(*arr.get([1, 0, 3]).unwrap(), 11); - /// assert!(arr.get([2, 0, 3]).is_none()); - /// }); - /// ``` - /// - /// For fixed dimension arrays, passing an index with invalid dimension causes compile error. 
- /// ```compile_fail - /// use numpy::PyArray; - /// pyo3::Python::with_gil(|py| { - /// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap().readonly(); - /// let a = arr.get([1, 2]); // Compile Error! - /// }); - /// ``` - /// - /// However, for dinamic arrays, we cannot raise a compile error and just returns `None`. - /// ``` - /// use numpy::PyArray; - /// pyo3::Python::with_gil(|py| { - /// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap().readonly(); - /// let arr = arr.to_dyn().readonly(); - /// assert!(arr.get([1, 2].as_ref()).is_none()); - /// }); - /// ``` - #[inline(always)] - pub fn get(&self, index: impl NpyIndex) -> Option<&T> { - unsafe { self.array.get(index) } - } - - /// Iterates all elements of this array. - /// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more. - #[deprecated( - note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter` instead." - )] - #[allow(deprecated)] - pub fn iter(self) -> PyResult> { - NpySingleIterBuilder::readonly(self).build() - } - - pub(crate) fn destruct(self) -> (&'py PyArray, bool) { - let PyReadonlyArray { - array, - was_writeable, - } = self; - (array, was_writeable) - } -} - -/// One-dimensional readonly array. -pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>; -/// Two-dimensional readonly array. -pub type PyReadonlyArray2<'py, T> = PyReadonlyArray<'py, T, Ix2>; -/// Three-dimensional readonly array. -pub type PyReadonlyArray3<'py, T> = PyReadonlyArray<'py, T, Ix3>; -/// Four-dimensional readonly array. -pub type PyReadonlyArray4<'py, T> = PyReadonlyArray<'py, T, Ix4>; -/// Five-dimensional readonly array. -pub type PyReadonlyArray5<'py, T> = PyReadonlyArray<'py, T, Ix5>; -/// Six-dimensional readonly array. -pub type PyReadonlyArray6<'py, T> = PyReadonlyArray<'py, T, Ix6>; -/// Dynamic-dimensional readonly array. 
-pub type PyReadonlyArrayDyn<'py, T> = PyReadonlyArray<'py, T, IxDyn>; - -impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadonlyArray<'py, T, D> { - fn extract(obj: &'py PyAny) -> PyResult { - let array: &PyArray = obj.extract()?; - Ok(PyReadonlyArray::from(array)) - } -} - -impl<'py, T, D> IntoPy for PyReadonlyArray<'py, T, D> { - fn into_py(self, py: Python<'_>) -> PyObject { - let PyReadonlyArray { array, .. } = self; - unsafe { PyObject::from_borrowed_ptr(py, array.as_ptr()) } - } -} - -impl<'py, T, D> From<&'py PyArray> for PyReadonlyArray<'py, T, D> { - fn from(array: &'py PyArray) -> PyReadonlyArray<'py, T, D> { - let flag = array.get_flag(); - let writeable = flag & NPY_ARRAY_WRITEABLE != 0; - if writeable { - unsafe { - (*array.as_array_ptr()).flags &= !NPY_ARRAY_WRITEABLE; - } - } - Self { - array, - was_writeable: writeable, - } - } -} - -impl<'py, T, D> Drop for PyReadonlyArray<'py, T, D> { - fn drop(&mut self) { - if self.was_writeable { - unsafe { - (*self.array.as_array_ptr()).flags |= NPY_ARRAY_WRITEABLE; - } - } - } -} - -impl<'py, T, D> AsRef> for PyReadonlyArray<'py, T, D> { - fn as_ref(&self) -> &PyArray { - self.array - } -} - -impl<'py, T, D> std::ops::Deref for PyReadonlyArray<'py, T, D> { - type Target = PyArray; - fn deref(&self) -> &PyArray { - self.array - } -} diff --git a/tests/borrow.rs b/tests/borrow.rs new file mode 100644 index 000000000..c090b5852 --- /dev/null +++ b/tests/borrow.rs @@ -0,0 +1,237 @@ +use std::thread::spawn; + +use numpy::{ + npyffi::NPY_ARRAY_WRITEABLE, PyArray, PyArray1, PyArray2, PyReadonlyArray3, PyReadwriteArray3, +}; +use pyo3::{py_run, pyclass, pymethods, types::IntoPyDict, Py, PyAny, Python}; + +#[test] +fn distinct_borrows() { + Python::with_gil(|py| { + let array1 = PyArray::::zeros(py, (1, 2, 3), false); + let array2 = PyArray::::zeros(py, (1, 2, 3), false); + + let exclusive1 = array1.readwrite(); + let exclusive2 = array2.readwrite(); + + assert_eq!(exclusive2.shape(), [1, 2, 3]); + 
assert_eq!(exclusive1.shape(), [1, 2, 3]); + }); +} + +#[test] +fn multiple_shared_borrows() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let shared1 = array.readonly(); + let shared2 = array.readonly(); + + assert_eq!(shared2.shape(), [1, 2, 3]); + assert_eq!(shared1.shape(), [1, 2, 3]); + }); +} + +#[test] +#[should_panic(expected = "AlreadyBorrowed")] +fn exclusive_and_shared_borrows() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let _exclusive = array.readwrite(); + let _shared = array.readonly(); + }); +} + +#[test] +#[should_panic(expected = "AlreadyBorrowed")] +fn multiple_exclusive_borrows() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let _exclusive1 = array.readwrite(); + let _exclusive2 = array.readwrite(); + }); +} + +#[test] +#[should_panic(expected = "NotWriteable")] +fn exclusive_borrow_requires_writeable() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + unsafe { + (*array.as_array_ptr()).flags &= !NPY_ARRAY_WRITEABLE; + } + + let _exclusive = array.readwrite(); + }); +} + +#[test] +#[should_panic(expected = "Unwrapped panic from Python code")] +fn borrows_span_frames() { + #[pyclass] + struct Borrower; + + #[pymethods] + impl Borrower { + fn shared(&self, _array: PyReadonlyArray3) {} + + fn exclusive(&self, _array: PyReadwriteArray3) {} + } + + Python::with_gil(|py| { + let borrower = Py::new(py, Borrower).unwrap(); + + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let _exclusive = array.readwrite(); + + py_run!(py, borrower array, "borrower.exclusive(array)"); + }); +} + +#[test] +fn borrows_span_threads() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let _exclusive = array.readwrite(); + + let array = array.to_owned(); + + py.allow_threads(move || { + let thread = spawn(move || { + Python::with_gil(|py| { + let array = array.as_ref(py); + + 
let _exclusive = array.readwrite(); + }); + }); + + assert!(thread.join().is_err()); + }); + }); +} + +#[test] +#[should_panic(expected = "AlreadyBorrowed")] +fn overlapping_views_conflict() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + let locals = [("array", array)].into_py_dict(py); + + let view1 = py + .eval("array[0,0,0:2]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view1.shape(), [2]); + + let view2 = py + .eval("array[0,0,1:3]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view2.shape(), [2]); + + let _exclusive1 = view1.readwrite(); + let _exclusive2 = view2.readwrite(); + }); +} + +#[test] +#[should_panic(expected = "AlreadyBorrowed")] +fn non_overlapping_views_conflict() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + let locals = [("array", array)].into_py_dict(py); + + let view1 = py + .eval("array[0,0,0:1]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view1.shape(), [1]); + + let view2 = py + .eval("array[0,0,2:3]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view2.shape(), [1]); + + let _exclusive1 = view1.readwrite(); + let _exclusive2 = view2.readwrite(); + }); +} + +#[test] +#[should_panic(expected = "AlreadyBorrowed")] +fn interleaved_views_conflict() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + let locals = [("array", array)].into_py_dict(py); + + let view1 = py + .eval("array[:,:,1]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view1.shape(), [1, 2]); + + let view2 = py + .eval("array[:,:,2]", None, Some(locals)) + .unwrap() + .downcast::>() + .unwrap(); + assert_eq!(view2.shape(), [1, 2]); + + let _exclusive1 = view1.readwrite(); + let _exclusive2 = view2.readwrite(); + }); +} + +#[test] +fn extract_readonly() { + Python::with_gil(|py| { + let ob: &PyAny = PyArray::::zeros(py, (1, 
2, 3), false); + ob.extract::>().unwrap(); + }); +} + +#[test] +fn extract_readwrite() { + Python::with_gil(|py| { + let ob: &PyAny = PyArray::::zeros(py, (1, 2, 3), false); + ob.extract::>().unwrap(); + }); +} + +#[test] +fn readonly_as_array_slice_get() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + let array = array.readonly(); + + assert_eq!(array.as_array().shape(), [1, 2, 3]); + assert_eq!(array.as_slice().unwrap().len(), 2 * 3); + assert_eq!(*array.get([0, 1, 2]).unwrap(), 0.0); + }); +} + +#[test] +fn readwrite_as_array_slice() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + let mut array = array.readwrite(); + + assert_eq!(array.as_array().shape(), [1, 2, 3]); + assert_eq!(array.as_array_mut().shape(), [1, 2, 3]); + assert_eq!(*array.get([0, 1, 2]).unwrap(), 0.0); + assert_eq!(array.as_slice().unwrap().len(), 2 * 3); + assert_eq!(array.as_slice_mut().unwrap().len(), 2 * 3); + assert_eq!(*array.get_mut([0, 1, 2]).unwrap(), 0.0); + }); +} diff --git a/tests/iter.rs b/tests/iter.rs index faa7a3c70..6be9c8e4a 100644 --- a/tests/iter.rs +++ b/tests/iter.rs @@ -10,10 +10,9 @@ use pyo3::Python; fn readonly_iter() { Python::with_gil(|py| { let arr = pyarray![py, [0, 1], [2, 3], [4, 5]]; + let ro_arr = arr.readonly(); - let iter = NpySingleIterBuilder::readonly(arr.readonly()) - .build() - .unwrap(); + let iter = NpySingleIterBuilder::readonly(&ro_arr).build().unwrap(); assert_eq!(iter.sum::(), 15); }); @@ -23,16 +22,20 @@ fn readonly_iter() { fn mutable_iter() { Python::with_gil(|py| { let arr = pyarray![py, [0, 1], [2, 3], [4, 5]]; + let mut rw_arr = arr.readwrite(); - let iter = unsafe { NpySingleIterBuilder::readwrite(arr).build().unwrap() }; + let iter = NpySingleIterBuilder::readwrite(&mut rw_arr) + .build() + .unwrap(); for elem in iter { *elem *= 2; } - let iter = NpySingleIterBuilder::readonly(arr.readonly()) - .build() - .unwrap(); + drop(rw_arr); + let ro_arr = arr.readonly(); + + 
let iter = NpySingleIterBuilder::readonly(&ro_arr).build().unwrap(); assert_eq!(iter.sum::(), 30); }); @@ -42,11 +45,13 @@ fn mutable_iter() { fn multiiter_rr() { Python::with_gil(|py| { let arr1 = pyarray![py, [0, 1], [2, 3], [4, 5]]; + let ro_arr1 = arr1.readonly(); let arr2 = pyarray![py, [6, 7], [8, 9], [10, 11]]; + let ro_arr2 = arr2.readonly(); let iter = NpyMultiIterBuilder::new() - .add_readonly(arr1.readonly()) - .add_readonly(arr2.readonly()) + .add_readonly(&ro_arr1) + .add_readonly(&ro_arr2) .build() .unwrap(); @@ -58,23 +63,26 @@ fn multiiter_rr() { fn multiiter_rw() { Python::with_gil(|py| { let arr1 = pyarray![py, [0, 1], [2, 3], [4, 5]]; + let ro_arr1 = arr1.readonly(); let arr2 = pyarray![py, [0, 0], [0, 0], [0, 0]]; + let mut rw_arr2 = arr2.readwrite(); - let iter = unsafe { - NpyMultiIterBuilder::new() - .add_readonly(arr1.readonly()) - .add_readwrite(arr2) - .build() - .unwrap() - }; + let iter = NpyMultiIterBuilder::new() + .add_readonly(&ro_arr1) + .add_readwrite(&mut rw_arr2) + .build() + .unwrap(); for (x, y) in iter { *y = *x * 2; } + drop(rw_arr2); + let ro_arr2 = arr2.readonly(); + let iter = NpyMultiIterBuilder::new() - .add_readonly(arr1.readonly()) - .add_readonly(arr2.readonly()) + .add_readonly(&ro_arr1) + .add_readonly(&ro_arr2) .build() .unwrap(); @@ -88,10 +96,9 @@ fn multiiter_rw() { fn single_iter_size_hint_len() { Python::with_gil(|py| { let arr = pyarray![py, [0, 1], [2, 3], [4, 5]]; + let ro_arr = arr.readonly(); - let mut iter = NpySingleIterBuilder::readonly(arr.readonly()) - .build() - .unwrap(); + let mut iter = NpySingleIterBuilder::readonly(&ro_arr).build().unwrap(); for len in (1..=6).rev() { assert_eq!(iter.len(), len); @@ -109,11 +116,13 @@ fn single_iter_size_hint_len() { fn multi_iter_size_hint_len() { Python::with_gil(|py| { let arr1 = pyarray![py, [0, 1], [2, 3], [4, 5]]; + let ro_arr1 = arr1.readonly(); let arr2 = pyarray![py, [0, 0], [0, 0], [0, 0]]; + let ro_arr2 = arr2.readonly(); let mut iter = 
NpyMultiIterBuilder::new() - .add_readonly(arr1.readonly()) - .add_readonly(arr2.readonly()) + .add_readonly(&ro_arr1) + .add_readonly(&ro_arr2) .build() .unwrap();