Skip to content

Commit

Permalink
Add minimal support for BFloat16 dtype.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Jun 21, 2023
1 parent 9f10b61 commit f2f8fbc
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- Unreleased
- Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381))

- v0.19.0
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))
Expand Down
20 changes: 19 additions & 1 deletion src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::os::raw::{
use std::ptr;

#[cfg(feature = "half")]
use half::f16;
use half::{bf16, f16};
use num_traits::{Bounded, Zero};
use pyo3::{
exceptions::{PyIndexError, PyValueError},
Expand All @@ -15,6 +15,8 @@ use pyo3::{
AsPyPointer, FromPyObject, FromPyPointer, IntoPyPointer, PyAny, PyNativeType, PyObject,
PyResult, PyTypeInfo, Python, ToPyObject,
};
#[cfg(feature = "half")]
use pyo3::{sync::GILOnceCell, IntoPy, Py};

use crate::npyffi::{
NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES,
Expand Down Expand Up @@ -477,6 +479,22 @@ impl_element_scalar!(f64 => NPY_DOUBLE);
#[cfg(feature = "half")]
impl_element_scalar!(f16 => NPY_HALF);

#[cfg(feature = "half")]
unsafe impl Element for bf16 {
const IS_COPY: bool = true;

fn get_dtype(py: Python) -> &PyArrayDescr {
static DTYPE: GILOnceCell<Py<PyArrayDescr>> = GILOnceCell::new();

DTYPE
.get_or_init(py, || {
PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.").into_py(py)
})
.clone()
.into_ref(py)
}
}

impl_element_scalar!(Complex32 => NPY_CFLOAT,
#[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
impl_element_scalar!(Complex64 => NPY_CDOUBLE,
Expand Down

0 comments on commit f2f8fbc

Please sign in to comment.