Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Add dynamic borrow checking for dereferencing NumPy arrays. #274

Merged
merged 1 commit into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

- Unreleased
- Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274))
- Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292))

- v0.16.2
Expand Down
48 changes: 48 additions & 0 deletions benches/borrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#![feature(test)]

extern crate test;
use test::{black_box, Bencher};

use numpy::PyArray;
use pyo3::Python;

#[bench]
fn initial_shared_borrow(bencher: &mut Bencher) {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);

bencher.iter(|| {
let array = black_box(array);

let _shared = array.readonly();
});
});
}

#[bench]
fn additional_shared_borrow(bencher: &mut Bencher) {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);

let _shared = (0..128).map(|_| array.readonly()).collect::<Vec<_>>();

bencher.iter(|| {
let array = black_box(array);

let _shared = array.readonly();
});
});
}

#[bench]
fn exclusive_borrow(bencher: &mut Bencher) {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);

bencher.iter(|| {
let array = black_box(array);

let _exclusive = array.readwrite();
});
});
}
8 changes: 5 additions & 3 deletions examples/simple/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
use numpy::{
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn,
};
use pyo3::{
pymodule,
types::{PyDict, PyModule},
Expand Down Expand Up @@ -41,8 +43,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// wrapper of `mult`
#[pyfn(m)]
#[pyo3(name = "mult")]
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
let x = unsafe { x.as_array_mut() };
fn mult_py(a: f64, mut x: PyReadwriteArrayDyn<f64>) {
let x = x.as_array_mut();
mult(a, x);
}

Expand Down
119 changes: 67 additions & 52 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ use pyo3::{
Python, ToPyObject,
};

use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
use crate::cold;
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::{Element, PyArrayDescr};
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
#[allow(deprecated)]
use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite};
use crate::readonly::PyReadonlyArray;
use crate::slice_container::PySliceContainer;

/// A safe, static-typed interface for
Expand Down Expand Up @@ -195,18 +193,8 @@ impl<T, D> PyArray<T, D> {
}

#[inline(always)]
fn check_flag(&self, flag: c_int) -> bool {
unsafe { *self.as_array_ptr() }.flags & flag == flag
}

#[inline(always)]
pub(crate) fn get_flag(&self) -> c_int {
unsafe { *self.as_array_ptr() }.flags
}

/// Returns a temporally unwriteable reference of the array.
pub fn readonly(&self) -> PyReadonlyArray<T, D> {
self.into()
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
unsafe { *self.as_array_ptr() }.flags & flags != 0
}

/// Returns `true` if the internal data of the array is C-style contiguous
Expand All @@ -228,18 +216,17 @@ impl<T, D> PyArray<T, D> {
/// });
/// ```
pub fn is_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is Fortran-style contiguous.
pub fn is_fortran_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is C-style contiguous.
pub fn is_c_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
}

/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
Expand Down Expand Up @@ -681,27 +668,61 @@ impl<T: Element, D: Dimension> PyArray<T, D> {

/// Get the immutable reference of the specified element, with checking the passed index is valid.
///
/// Please consider the use of safe alternatives
/// ([`PyReadonlyArray::get`](../struct.PyReadonlyArray.html#method.get)
/// or [`get_owned`](#method.get_owned)) instead of this.
/// Consider using safe alternatives like [`PyReadonlyArray::get`].
///
/// # Example
///
/// ```
/// use numpy::PyArray;
/// pyo3::Python::with_gil(|py| {
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap();
/// assert_eq!(*unsafe { arr.get([1, 0, 3]) }.unwrap(), 11);
/// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 11);
/// });
/// ```
///
/// # Safety
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the slice might cause undefined behavior.
///
/// Calling this method is undefined behaviour if the underlying array
/// is aliased mutably by other instances of `PyArray`
/// or concurrently modified by Python or other native code.
#[inline(always)]
pub unsafe fn get(&self, index: impl NpyIndex<Dim = D>) -> Option<&T> {
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
Some(&*self.data().offset(offset))
}

/// Same as [`get`][Self::get], but returns `Option<&mut T>`.
///
/// Consider using safe alternatives like [`PyReadwriteArray::get_mut`].
///
/// # Example
///
/// ```
/// use numpy::PyArray;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap();
/// unsafe {
/// *arr.get_mut([1, 0, 3]).unwrap() = 42;
/// }
/// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 42);
/// });
/// ```
///
/// # Safety
///
/// Calling this method is undefined behaviour if the underlying array
/// is aliased immutably by mutably by other instances of `PyArray`
/// or concurrently modified by Python or other native code.
#[inline(always)]
pub unsafe fn get_mut(&self, index: impl NpyIndex<Dim = D>) -> Option<&mut T> {
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
Some(&mut *self.data().offset(offset))
}

/// Get the immutable reference of the specified element, without checking the
/// passed index is valid.
///
Expand Down Expand Up @@ -824,28 +845,37 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
ToPyArray::to_pyarray(arr, py)
}

/// Get the immutable view of the internal data of `PyArray`, as
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
/// Get an immutable borrow of the NumPy array
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
PyReadonlyArray::try_new(self).unwrap()
}

/// Get a mutable borrow of the NumPy array
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
PyReadwriteArray::try_new(self).unwrap()
}

/// Returns the internal array as [`ArrayView`].
///
/// Please consider the use of safe alternatives
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
/// or [`to_array`](#method.to_array)) instead of this.
/// See also [`PyReadonlyArray::as_array`].
///
/// # Safety
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the `ArrayView` might cause undefined behavior.
///
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayView::from_shape_ptr(shape, ptr);
inverted_axes.invert(&mut res);
res
}

/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
/// Returns the internal array as [`ArrayViewMut`].
///
/// See also [`PyReadwriteArray::as_array_mut`].
///
/// # Safety
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
/// it might cause undefined behavior.
///
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
Expand Down Expand Up @@ -921,7 +951,7 @@ impl<D: Dimension> PyArray<PyObject, D> {
///
/// let pyarray = PyArray::from_owned_object_array(py, array);
///
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
/// assert!(pyarray.readonly().as_array().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
/// });
/// ```
pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array<Py<T>, D>) -> &'py Self {
Expand Down Expand Up @@ -1043,21 +1073,6 @@ impl<T: Element> PyArray<T, Ix1> {
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
}

/// Iterates all elements of this array.
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
///
/// # Safety
///
/// The iterator will produce mutable references into the array which must not be
/// aliased by other references for the life time of the iterator.
#[deprecated(
note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter_mut` instead."
)]
#[allow(deprecated)]
pub unsafe fn iter<'py>(&'py self) -> PyResult<NpySingleIter<'py, T, ReadWrite>> {
NpySingleIterBuilder::readwrite(self).build()
}

fn resize_<D: IntoDimension>(
&self,
py: Python,
Expand Down
Loading