Skip to content

Commit

Permalink
Add PyArrayDescr::new() constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
aldanor committed Jan 24, 2022
1 parent 7cc945c commit ccfb471
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::mem::size_of;
use std::os::raw::{
c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
};
use std::ptr;

use num_traits::{Bounded, Zero};
use pyo3::{
Expand Down Expand Up @@ -57,6 +58,23 @@ pub fn dtype<T: Element>(py: Python) -> &PyArrayDescr {
}

impl PyArrayDescr {
/// Creates a new dtype object from an arbitrary object.
///
/// Equivalent to invoking the constructor of `np.dtype`.
pub fn new<'py, T: ToPyObject + ?Sized>(py: Python<'py>, obj: &T) -> PyResult<&'py Self> {
Self::new_impl(py, obj.to_object(py))
}

#[inline]
fn new_impl<'py>(py: Python<'py>, obj: PyObject) -> PyResult<&'py Self> {
let mut descr: *mut PyArray_Descr = ptr::null_mut();
unsafe {
// None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
PY_ARRAY_API.PyArray_DescrConverter2(obj.as_ptr(), &mut descr as *mut _);
py.from_owned_ptr_or_err(descr as _)
}
}

/// Returns `self` as `*mut PyArray_Descr`.
pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
self.as_ptr() as _
Expand Down Expand Up @@ -423,6 +441,19 @@ mod tests {
use super::{dtype, Complex32, Complex64, Element, PyArrayDescr};
use crate::npyffi::{NPY_ALIGNED_STRUCT, NPY_ITEM_HASOBJECT, NPY_NEEDS_PYAPI, NPY_TYPES};

#[test]
fn test_dtype_new() {
pyo3::Python::with_gil(|py| {
assert_eq!(PyArrayDescr::new(py, "float64").unwrap(), dtype::<f64>(py));
let d = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
assert_eq!(d.names(), Some(vec!["a", "b"]));
assert!(d.has_object());
assert_eq!(d.get_field("a").unwrap().0, dtype::<PyObject>(py));
assert_eq!(d.get_field("b").unwrap().0, dtype::<bool>(py));
assert!(PyArrayDescr::new(py, &123_usize).is_err());
});
}

#[test]
fn test_dtype_names() {
fn type_name<T: Element>(py: pyo3::Python) -> &str {
Expand Down

0 comments on commit ccfb471

Please sign in to comment.