Make half::f16 ndarrays work as well
jleibs committed May 9, 2023
1 parent 33eb13c commit 04ff2c6
Showing 1 changed file with 74 additions and 9 deletions.
crates/re_log_types/src/component_types/tensor.rs (74 additions, 9 deletions)
@@ -507,11 +507,6 @@ pub enum TensorCastError {

#[error("ndarray Array is not contiguous and in standard order")]
NotContiguousStdOrder,

#[error(
"tensors do not currently support f16 data (https://github.com/rerun-io/rerun/issues/854)"
)]
F16NotSupported,
}

macro_rules! tensor_type {
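With the `F16NotSupported` variant removed, f16 casts surface the same errors as every other element type, so callers only need to match on the remaining variants. A minimal sketch of caller-side handling, assuming the types are reachable as `re_log_types::component_types::{Tensor, TensorCastError}` (the helper name and message strings below are illustrative, not part of this commit):

use re_log_types::component_types::{Tensor, TensorCastError};

// Hypothetical helper: report why an f16 view of a Tensor could not be produced.
fn describe_f16_cast(tensor: &Tensor) -> String {
    let result: Result<ndarray::ArrayViewD<'_, half::f16>, _> = tensor.try_into();
    match result {
        Ok(view) => format!("ok: {} f16 elements", view.len()),
        Err(TensorCastError::TypeMismatch) => "tensor does not hold F16 data".to_owned(),
        Err(TensorCastError::BadTensorShape { source }) => format!("bad shape: {source}"),
        Err(other) => format!("cast failed: {other}"),
    }
}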
@@ -600,15 +595,85 @@ tensor_type!(i16, I16);
 tensor_type!(i32, I32);
 tensor_type!(i64, I64);
 
+tensor_type!(arrow2::types::f16, F16);
 tensor_type!(f32, F32);
 tensor_type!(f64, F64);
 
-// TODO(#854) Switch back to `tensor_type!` once we have F16 tensors
+// Manual expansion of tensor_type! macro for `half::f16` types.
+// ==========================================
+// TODO(jleibs): would be nice to support this with the macro definition as well
+// but the bytemuck casts add a bit of complexity here.
 impl<'a> TryFrom<&'a Tensor> for ::ndarray::ArrayViewD<'a, half::f16> {
     type Error = TensorCastError;
 
-    fn try_from(_: &'a Tensor) -> Result<Self, Self::Error> {
-        Err(TensorCastError::F16NotSupported)
+    fn try_from(value: &'a Tensor) -> Result<Self, Self::Error> {
+        let shape: Vec<_> = value.shape.iter().map(|d| d.size as usize).collect();
+        if let TensorData::F16(data) = &value.data {
+            ndarray::ArrayViewD::from_shape(shape, bytemuck::cast_slice(data.as_slice()))
+                .map_err(|err| TensorCastError::BadTensorShape { source: err })
+        } else {
+            Err(TensorCastError::TypeMismatch)
+        }
     }
 }
+impl<'a, D: ::ndarray::Dimension> TryFrom<::ndarray::ArrayView<'a, half::f16, D>> for Tensor {
+    type Error = TensorCastError;
+    fn try_from(view: ::ndarray::ArrayView<'a, half::f16, D>) -> Result<Self, Self::Error> {
+        let shape = view
+            .shape()
+            .iter()
+            .map(|dim| TensorDimension {
+                size: *dim as u64,
+                name: None,
+            })
+            .collect();
+        match view.to_slice() {
+            Some(slice) => Ok(Tensor {
+                tensor_id: TensorId::random(),
+                shape,
+                data: TensorData::F16(Vec::from(bytemuck::cast_slice(slice)).into()),
+                meaning: TensorDataMeaning::Unknown,
+                meter: None,
+            }),
+            None => Ok(Tensor {
+                tensor_id: TensorId::random(),
+                shape,
+                data: TensorData::F16(
+                    view.iter()
+                        .map(|f| arrow2::types::f16::from_bits(f.to_bits()))
+                        .collect::<Vec<_>>()
+                        .into(),
+                ),
+                meaning: TensorDataMeaning::Unknown,
+                meter: None,
+            }),
+        }
+    }
+}
+impl<D: ::ndarray::Dimension> TryFrom<::ndarray::Array<half::f16, D>> for Tensor {
+    type Error = TensorCastError;
+    fn try_from(value: ndarray::Array<half::f16, D>) -> Result<Self, Self::Error> {
+        let shape = value
+            .shape()
+            .iter()
+            .map(|dim| TensorDimension {
+                size: *dim as u64,
+                name: None,
+            })
+            .collect();
+        value
+            .is_standard_layout()
+            .then(|| Tensor {
+                tensor_id: TensorId::random(),
+                shape,
+                data: TensorData::F16(
+                    bytemuck::cast_slice(value.into_raw_vec().as_slice())
+                        .to_vec()
+                        .into(),
+                ),
+                meaning: TensorDataMeaning::Unknown,
+                meter: None,
+            })
+            .ok_or(TensorCastError::NotContiguousStdOrder)
+    }
+}
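Taken together, the three impls give `half::f16` the same ndarray interop as the other element types: an owned array in standard layout is converted with a single `bytemuck` cast of its backing buffer, a non-contiguous view falls back to an element-by-element copy through `arrow2::types::f16`, and a `Tensor` holding `F16` data can be viewed back as an ndarray without copying. A rough round-trip sketch, assuming the same crate paths as above plus `ndarray` and `half` as dependencies (the helper name and the 480x640 shape are made up for illustration):

use ndarray::{Array2, ArrayViewD};
use re_log_types::component_types::{Tensor, TensorCastError};

fn f16_roundtrip() -> Result<(), TensorCastError> {
    let image = Array2::<half::f16>::from_elem((480, 640), half::f16::from_f32(0.0));

    // A transposed view is not in standard layout, so `to_slice()` returns None
    // and the element-wise fallback path is taken.
    let _from_view = Tensor::try_from(image.t())?;

    // An owned, standard-layout array goes through the single bytemuck cast.
    let tensor = Tensor::try_from(image)?;

    // View the Tensor's F16 buffer as an ndarray again (no copy).
    let view: ArrayViewD<'_, half::f16> = (&tensor).try_into()?;
    assert_eq!(view.len(), 480 * 640);
    Ok(())
}

Note the asymmetry between the two Tensor constructors shown in the diff: a non-contiguous view is copied element by element, while an owned array that is not in standard layout is rejected with `NotContiguousStdOrder` rather than silently reordered.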

