Skip to content

Commit

Permalink
WIP: Add dynamic borrow checking for dereferencing NumPy arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Feb 18, 2022
1 parent 61882e3 commit 542b587
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 46 deletions.
25 changes: 11 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,40 +44,37 @@ numpy = "0.15"
```

```rust
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, IxDyn};
use numpy::{IntoPyArray, PyArrayDyn, PyArrayRef, PyArrayRefMut};
use pyo3::prelude::{pymodule, PyModule, PyResult, Python};

#[pymodule]
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// immutable example
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * &x + &y
fn axpy(a: f64, x: &ArrayViewD<'_, f64>, y: &ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * x + y
}

// mutable example (no return)
fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) {
x *= a;
fn mult(a: f64, x: &mut ArrayViewMutD<'_, f64>) {
*x *= a;
}

// wrapper of `axpy`
#[pyfn(m, "axpy")]
fn axpy_py<'py>(
py: Python<'py>,
a: f64,
x: PyReadonlyArrayDyn<f64>,
y: PyReadonlyArrayDyn<f64>,
x: PyArrayRef<f64, IxDyn>,
y: PyArrayRef<f64, IxDyn>,
) -> &'py PyArrayDyn<f64> {
let x = x.as_array();
let y = y.as_array();
axpy(a, x, y).into_pyarray(py)
axpy(a, &x, &y).into_pyarray(py)
}

// wrapper of `mult`
#[pyfn(m, "mult")]
fn mult_py(_py: Python<'_>, a: f64, x: &PyArrayDyn<f64>) -> PyResult<()> {
let x = unsafe { x.as_array_mut() };
mult(a, x);
fn mult_py(_py: Python<'_>, a: f64, mut x: PyArrayRefMut<f64, IxDyn>) -> PyResult<()> {
mult(a, &mut x);
Ok(())
}

Expand Down
5 changes: 2 additions & 3 deletions examples/linalg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use ndarray_linalg::solve::Inverse;
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use numpy::{IntoPyArray, Ix2, PyArray2, PyArrayRef};
use pyo3::{exceptions::PyRuntimeError, pymodule, types::PyModule, PyErr, PyResult, Python};

#[pymodule]
fn rust_linalg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[pyfn(m)]
fn inv<'py>(py: Python<'py>, x: PyReadonlyArray2<'py, f64>) -> PyResult<&'py PyArray2<f64>> {
let x = x.as_array();
fn inv<'py>(py: Python<'py>, x: PyArrayRef<'py, f64, Ix2>) -> PyResult<&'py PyArray2<f64>> {
let y = x
.inv()
.map_err(|e| PyErr::new::<PyRuntimeError, _>(format!("[rust_linalg] {}", e)))?;
Expand Down
13 changes: 6 additions & 7 deletions examples/parallel/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
// We need to link `blas_src` directly, c.f. https://github.com/rust-ndarray/ndarray#how-to-enable-blas-integration
extern crate blas_src;

use ndarray::Zip;
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use numpy::ndarray::{ArrayView1, Zip};
use numpy::{IntoPyArray, Ix1, Ix2, PyArray1, PyArrayRef};
use pyo3::{pymodule, types::PyModule, PyResult, Python};

#[pymodule]
fn rust_parallel(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[pyfn(m)]
fn rows_dot<'py>(
py: Python<'py>,
x: PyReadonlyArray2<'py, f64>,
y: PyReadonlyArray1<'py, f64>,
x: PyArrayRef<'py, f64, Ix2>,
y: PyArrayRef<'py, f64, Ix1>,
) -> &'py PyArray1<f64> {
let x = x.as_array();
let y = y.as_array();
let z = Zip::from(x.rows()).par_map_collect(|row| row.dot(&y));
let y: &ArrayView1<f64> = &y;
let z = Zip::from(x.rows()).par_map_collect(|row| row.dot(y));
z.into_pyarray(py)
}
Ok(())
Expand Down
18 changes: 9 additions & 9 deletions examples/simple-extension/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ use pyo3::{
#[pymodule]
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// immutable example
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * &x + &y
fn axpy(a: f64, x: &ArrayViewD<'_, f64>, y: &ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * x + y
}

// mutable example (no return)
fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) {
x *= a;
fn mult(a: f64, x: &mut ArrayViewMutD<'_, f64>) {
*x *= a;
}

// complex example
fn conj(x: ArrayViewD<'_, Complex64>) -> ArrayD<Complex64> {
fn conj(x: &ArrayViewD<'_, Complex64>) -> ArrayD<Complex64> {
x.map(|c| c.conj())
}

Expand All @@ -34,16 +34,16 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
) -> &'py PyArrayDyn<f64> {
let x = x.as_array();
let y = y.as_array();
let z = axpy(a, x, y);
let z = axpy(a, &x, &y);
z.into_pyarray(py)
}

// wrapper of `mult`
#[pyfn(m)]
#[pyo3(name = "mult")]
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
let x = unsafe { x.as_array_mut() };
mult(a, x);
let mut x = x.as_array_mut();
mult(a, &mut x);
}

// wrapper of `conj`
Expand All @@ -53,7 +53,7 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
py: Python<'py>,
x: PyReadonlyArrayDyn<'_, Complex64>,
) -> &'py PyArrayDyn<Complex64> {
conj(x.as_array()).into_pyarray(py)
conj(&x.as_array()).into_pyarray(py)
}

#[pyfn(m)]
Expand Down
32 changes: 20 additions & 12 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use pyo3::{
Python, ToPyObject,
};

use crate::borrow::{PyArrayRef, PyArrayRefMut};
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::Element;
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
Expand Down Expand Up @@ -825,27 +826,34 @@ impl<T: Element, D: Dimension> PyArray<T, D> {

/// Get the immutable view of the internal data of `PyArray`, as
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
///
/// Please consider the use of safe alternatives
/// ([`PyReadonlyArray::as_array`](../struct.PyReadonlyArray.html#method.as_array)
/// or [`to_array`](#method.to_array)) instead of this.
pub fn as_array(&self) -> PyArrayRef<'_, T, D> {
PyArrayRef::try_new(self).expect("NumPy array already borrowed")
}

/// Get the immutable view of the internal data of `PyArray`, as
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
pub fn as_array_mut(&self) -> PyArrayRefMut<'_, T, D> {
PyArrayRefMut::try_new(self).expect("NumPy array already borrowed")
}

/// Returns the internal array as [`ArrayView`]. See also [`as_array_unchecked`].
///
/// # Safety
/// If the internal array is not readonly and can be mutated from Python code,
/// holding the `ArrayView` might cause undefined behavior.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
///
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
pub unsafe fn as_array_unchecked(&self) -> ArrayView<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayView::from_shape_ptr(shape, ptr);
inverted_axes.invert(&mut res);
res
}

/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array`](#method.as_array).
/// Returns the internal array as [`ArrayViewMut`]. See also [`as_array_unchecked`].
///
/// # Safety
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
/// it might cause undefined behavior.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
///
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
pub unsafe fn as_array_mut_unchecked(&self) -> ArrayViewMut<'_, T, D> {
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
inverted_axes.invert(&mut res);
Expand Down Expand Up @@ -884,7 +892,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
/// });
/// ```
pub fn to_owned_array(&self) -> Array<T, D> {
unsafe { self.as_array() }.to_owned()
unsafe { self.as_array_unchecked() }.to_owned()
}
}

Expand Down
170 changes: 170 additions & 0 deletions src/borrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
use std::cell::UnsafeCell;
use std::collections::hash_map::{Entry, HashMap};
use std::ops::{Deref, DerefMut};

use ndarray::{ArrayView, ArrayViewMut, Dimension};
use pyo3::{FromPyObject, PyAny, PyResult};

use crate::array::PyArray;
use crate::dtype::Element;

thread_local! {
static BORROW_FLAGS: UnsafeCell<HashMap<usize, isize>> = UnsafeCell::new(HashMap::new());
}

pub struct PyArrayRef<'a, T, D> {
array: &'a PyArray<T, D>,
view: ArrayView<'a, T, D>,
}

impl<'a, T, D> Deref for PyArrayRef<'a, T, D> {
type Target = ArrayView<'a, T, D>;

fn deref(&self) -> &Self::Target {
&self.view
}
}

impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyArrayRef<'py, T, D> {
fn extract(obj: &'py PyAny) -> PyResult<Self> {
let array: &'py PyArray<T, D> = obj.extract()?;
Ok(array.as_array())
}
}

impl<'a, T, D> PyArrayRef<'a, T, D>
where
T: Element,
D: Dimension,
{
pub(crate) fn try_new(array: &'a PyArray<T, D>) -> Option<Self> {
let address = array as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let readers = entry.into_mut();

let new_readers = readers.wrapping_add(1);

if new_readers <= 0 {
cold();
return None;
}

*readers = new_readers;
}
Entry::Vacant(entry) => {
entry.insert(1);
}
}

// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
let view = unsafe { array.as_array_unchecked() };

Some(Self { array, view })
})
}
}

impl<'a, T, D> Drop for PyArrayRef<'a, T, D> {
fn drop(&mut self) {
let address = self.array as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

let readers = borrow_flags.get_mut(&address).unwrap();

*readers -= 1;

if *readers == 0 {
borrow_flags.remove(&address).unwrap();
}
});
}
}

pub struct PyArrayRefMut<'a, T, D> {
array: &'a PyArray<T, D>,
view: ArrayViewMut<'a, T, D>,
}

impl<'a, T, D> Deref for PyArrayRefMut<'a, T, D> {
type Target = ArrayViewMut<'a, T, D>;

fn deref(&self) -> &Self::Target {
&self.view
}
}

impl<'a, T, D> DerefMut for PyArrayRefMut<'a, T, D> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.view
}
}

impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyArrayRefMut<'py, T, D> {
fn extract(obj: &'py PyAny) -> PyResult<Self> {
let array: &'py PyArray<T, D> = obj.extract()?;
Ok(array.as_array_mut())
}
}

impl<'a, T, D> PyArrayRefMut<'a, T, D>
where
T: Element,
D: Dimension,
{
pub(crate) fn try_new(array: &'a PyArray<T, D>) -> Option<Self> {
let address = array as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let writers = entry.into_mut();

if *writers != 0 {
cold();
return None;
}

*writers = -1;
}
Entry::Vacant(entry) => {
entry.insert(-1);
}
}

// SAFETY: Thread-local borrow flags ensure aliasing discipline on this thread,
// and `PyArray` is neither `Send` nor `Sync`
let view = unsafe { array.as_array_mut_unchecked() };

Some(Self { array, view })
})
}
}

impl<'a, T, D> Drop for PyArrayRefMut<'a, T, D> {
fn drop(&mut self) {
let address = self.array as *const PyArray<T, D> as usize;

BORROW_FLAGS.with(|borrow_flags| {
// SAFETY: Called on a thread local variable in a leaf function.
let borrow_flags = unsafe { &mut *borrow_flags.get() };

borrow_flags.remove(&address).unwrap();
});
}
}
#[cold]
#[inline(always)]
fn cold() {}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#![allow(clippy::needless_lifetimes)] // We often want to make the GIL lifetime explicit.

pub mod array;
mod borrow;
pub mod convert;
mod dtype;
mod error;
Expand All @@ -46,6 +47,7 @@ pub use crate::array::{
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
PyArray6, PyArrayDyn,
};
pub use crate::borrow::{PyArrayRef, PyArrayRefMut};
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr};
pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
Expand Down
Loading

0 comments on commit 542b587

Please sign in to comment.