Skip to content

Commit

Permalink
Merge pull request #222 from adamreichold/sync-api-globals
Browse files Browse the repository at this point in the history
Make API globals thread safe using atomics
  • Loading branch information
davidhewitt authored Nov 25, 2021
2 parents 131bb5f + a0ecc3f commit 6d6084f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
34 changes: 21 additions & 13 deletions src/npyffi/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
use libc::FILE;
use pyo3::ffi::{self, PyObject, PyTypeObject};
use std::os::raw::*;
use std::{cell::Cell, ptr};
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};

use crate::npyffi::*;

Expand All @@ -12,7 +13,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
/// pointer to [Numpy Array API](https://numpy.org/doc/stable/reference/c-api/array.html).
///
/// You can acceess raw c APIs via this variable and its Deref implementation.
/// You can acceess raw C APIs via this variable.
///
/// See [PyArrayAPI](struct.PyArrayAPI.html) for what methods you can use via this variable.
///
Expand All @@ -31,28 +32,35 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new();

/// See [PY_ARRAY_API] for more.
pub struct PyArrayAPI {
api: Cell<*const *const c_void>,
api: AtomicPtr<*const c_void>,
}

impl PyArrayAPI {
const fn new() -> Self {
Self {
api: Cell::new(ptr::null_mut()),
api: AtomicPtr::new(null_mut()),
}
}
fn get(&self, offset: isize) -> *const *const c_void {
if self.api.get().is_null() {
Python::with_gil(|py| {
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.set(api);
});
#[cold]
fn init(&self) -> *const *const c_void {
Python::with_gil(|py| {
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
if api.is_null() {
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.store(api as *mut _, Ordering::Release);
}
api
})
}
unsafe fn get(&self, offset: isize) -> *const *const c_void {
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
if api.is_null() {
api = self.init();
}
unsafe { self.api.get().offset(offset) }
api.offset(offset)
}
}

unsafe impl Sync for PyArrayAPI {}

impl PyArrayAPI {
impl_api![0; PyArray_GetNDArrayCVersion() -> c_uint];
impl_api![40; PyArray_SetNumericOps(dict: *mut PyObject) -> c_int];
Expand Down
32 changes: 20 additions & 12 deletions src/npyffi/ufunc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Low-Level binding for [UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html)

use std::os::raw::*;
use std::{cell::Cell, ptr};
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};

use pyo3::ffi::PyObject;
use pyo3::Python;
Expand All @@ -18,28 +19,35 @@ const CAPSULE_NAME: &str = "_UFUNC_API";
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();

pub struct PyUFuncAPI {
api: Cell<*const *const c_void>,
api: AtomicPtr<*const c_void>,
}

impl PyUFuncAPI {
const fn new() -> Self {
Self {
api: Cell::new(ptr::null_mut()),
api: AtomicPtr::new(null_mut()),
}
}
fn get(&self, offset: isize) -> *const *const c_void {
if self.api.get().is_null() {
Python::with_gil(|py| {
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.set(api);
});
#[cold]
fn init(&self) -> *const *const c_void {
Python::with_gil(|py| {
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
if api.is_null() {
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.store(api as *mut _, Ordering::Release);
}
api
})
}
unsafe fn get(&self, offset: isize) -> *const *const c_void {
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
if api.is_null() {
api = self.init();
}
unsafe { self.api.get().offset(offset) }
api.offset(offset)
}
}

unsafe impl Sync for PyUFuncAPI {}

impl PyUFuncAPI {
impl_api![1; PyUFunc_FromFuncAndData(func: *mut PyUFuncGenericFunction, data: *mut *mut c_void, types: *mut c_char, ntypes: c_int, nin: c_int, nout: c_int, identity: c_int, name: *const c_char, doc: *const c_char, unused: c_int) -> *mut PyObject];
impl_api![2; PyUFunc_RegisterLoopForType(ufunc: *mut PyUFuncObject, usertype: c_int, function: PyUFuncGenericFunction, arg_types: *mut c_int, data: *mut c_void) -> c_int];
Expand Down

0 comments on commit 6d6084f

Please sign in to comment.