Skip to content

Commit

Permalink
Add dynamic borrow checking for dereferencing NumPy arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Mar 16, 2022
1 parent 983514d commit b945b14
Show file tree
Hide file tree
Showing 10 changed files with 1,024 additions and 398 deletions.
48 changes: 48 additions & 0 deletions benches/borrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#![feature(test)]

extern crate test;
use test::{black_box, Bencher};

use numpy::PyArray;
use pyo3::Python;

#[bench]
fn initial_shared_borrow(bencher: &mut Bencher) {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);

bencher.iter(|| {
let array = black_box(array);

let _shared = array.readonly();
});
});
}

#[bench]
fn additional_shared_borrow(bencher: &mut Bencher) {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);

let _shared = (0..128).map(|_| array.readonly()).collect::<Vec<_>>();

bencher.iter(|| {
let array = black_box(array);

let _shared = array.readonly();
});
});
}

#[bench]
fn exclusive_borrow(bencher: &mut Bencher) {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);

bencher.iter(|| {
let array = black_box(array);

let _exclusive = array.readwrite();
});
});
}
8 changes: 5 additions & 3 deletions examples/simple/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
use numpy::{
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn,
};
use pyo3::{
pymodule,
types::{PyDict, PyModule},
Expand Down Expand Up @@ -41,8 +43,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// wrapper of `mult`
#[pyfn(m)]
#[pyo3(name = "mult")]
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
let x = unsafe { x.as_array_mut() };
fn mult_py(a: f64, mut x: PyReadwriteArrayDyn<f64>) {
let x = x.as_array_mut();
mult(a, x);
}

Expand Down
119 changes: 67 additions & 52 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ use pyo3::{
Python, ToPyObject,
};

use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::{Element, PyArrayDescr};
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
#[allow(deprecated)]
use crate::npyiter::{NpySingleIter, NpySingleIterBuilder, ReadWrite};
use crate::readonly::PyReadonlyArray;
use crate::slice_container::PySliceContainer;

/// A safe, static-typed interface for
Expand Down Expand Up @@ -194,18 +192,8 @@ impl<T, D> PyArray<T, D> {
}

#[inline(always)]
fn check_flag(&self, flag: c_int) -> bool {
unsafe { *self.as_array_ptr() }.flags & flag == flag
}

#[inline(always)]
pub(crate) fn get_flag(&self) -> c_int {
unsafe { *self.as_array_ptr() }.flags
}

/// Returns a temporally unwriteable reference of the array.
pub fn readonly(&self) -> PyReadonlyArray<T, D> {
self.into()
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
unsafe { *self.as_array_ptr() }.flags & flags != 0
}

/// Returns `true` if the internal data of the array is C-style contiguous
Expand All @@ -227,18 +215,17 @@ impl<T, D> PyArray<T, D> {
/// });
/// ```
pub fn is_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is Fortran-style contiguous.
pub fn is_fortran_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
}

/// Returns `true` if the internal data of the array is C-style contiguous.
pub fn is_c_contiguous(&self) -> bool {
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
}

/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
Expand Down Expand Up @@ -684,27 +671,61 @@ impl<T: Element, D: Dimension> PyArray<T, D> {

/// Get the immutable reference of the specified element, with checking the passed index is valid.
///
/// Please consider the use of safe alternatives
/// ([`PyReadonlyArray::get`](../struct.PyReadonlyArray.html#method.get)
/// or [`get_owned`](#method.get_owned)) instead of this.
/// Consider using safe alternatives like [`PyReadonlyArray::get`].
///
/// # Example
///
/// ```
/// use numpy::PyArray;
/// pyo3::Python::with_gil(|py| {
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap();
/// assert_eq!(*unsafe { arr.get([1, 0, 3]) }.unwrap(), 11);
/// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 11);
/// });
/// ```
///
/// # Safety
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the slice might cause undefined behavior.
///
/// Calling this method is undefined behaviour if the underlying array
/// is aliased mutably by other instances of `PyArray`
/// or concurrently modified by Python or other native code.
#[inline(always)]
pub unsafe fn get(&self, index: impl NpyIndex<Dim = D>) -> Option<&T> {
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
Some(&*self.data().offset(offset))
}

/// Same as [`get`][Self::get], but returns `Option<&mut T>`.
///
/// Consider using safe alternatives like [`PyReadwriteArray::get_mut`].
///
/// # Example
///
/// ```
/// use numpy::PyArray;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray::arange(py, 0, 16, 1).reshape([2, 2, 4]).unwrap();
/// unsafe {
/// *arr.get_mut([1, 0, 3]).unwrap() = 42;
/// }
/// assert_eq!(unsafe { *arr.get([1, 0, 3]).unwrap() }, 42);
/// });
/// ```
///
/// # Safety
///
/// Calling this method is undefined behaviour if the underlying array
/// is aliased immutably by mutably by other instances of `PyArray`
/// or concurrently modified by Python or other native code.
#[inline(always)]
pub unsafe fn get_mut(&self, index: impl NpyIndex<Dim = D>) -> Option<&mut T> {
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
Some(&mut *self.data().offset(offset))
}

/// Get the immutable reference of the specified element, without checking the
/// passed index is valid.
///
Expand Down Expand Up @@ -827,28 +848,37 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
ToPyArray::to_pyarray(arr, py)
}

/// Get the immutable view of the internal data of `PyArray`, as
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
/// Get an immutable borrow of the NumPy array
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
PyReadonlyArray::try_new(self).unwrap()
}

/// Get a mutable borrow of the NumPy array
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
PyReadwriteArray::try_new(self).unwrap()
}

/// Returns the internal array as [`ArrayView`].
///
/// Please consider the use of safe alternatives
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
/// or [`to_array`](#method.to_array)) instead of this.
/// See also [`PyReadonlyArray::as_array`].
///
/// # Safety
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the `ArrayView` might cause undefined behavior.
///
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayView::from_shape_ptr(shape, ptr);
inverted_axes.invert(&mut res);
res
}

/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
/// Returns the internal array as [`ArrayViewMut`].
///
/// See also [`PyReadwriteArray::as_array_mut`].
///
/// # Safety
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
/// it might cause undefined behavior.
///
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
Expand Down Expand Up @@ -924,7 +954,7 @@ impl<D: Dimension> PyArray<PyObject, D> {
///
/// let pyarray = PyArray::from_owned_object_array(py, array);
///
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
/// assert!(pyarray.readonly().as_array().get(0).unwrap().as_ref(py).is_instance_of::<CustomElement>().unwrap());
/// });
/// ```
pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array<Py<T>, D>) -> &'py Self {
Expand Down Expand Up @@ -1073,21 +1103,6 @@ impl<T: Element> PyArray<T, Ix1> {
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
}

/// Iterates all elements of this array.
/// See [NpySingleIter](../npyiter/struct.NpySingleIter.html) for more.
///
/// # Safety
///
/// The iterator will produce mutable references into the array which must not be
/// aliased by other references for the life time of the iterator.
#[deprecated(
note = "The wrappers of the array iterator API are deprecated, please use ndarray's `ArrayBase::iter_mut` instead."
)]
#[allow(deprecated)]
pub unsafe fn iter<'py>(&'py self) -> PyResult<NpySingleIter<'py, T, ReadWrite>> {
NpySingleIterBuilder::readwrite(self).build()
}

fn resize_<D: IntoDimension>(
&self,
py: Python,
Expand Down
Loading

0 comments on commit b945b14

Please sign in to comment.