Add support for f16 tensors (#1449)
* Add support for f16 tensors
* Make half::f16 ndarrays work as well
jleibs authored and jprochazk committed May 11, 2023
1 parent 885961c commit a557354
Showing 6 changed files with 111 additions and 23 deletions.
1 change: 1 addition & 0 deletions crates/re_data_ui/src/image.rs
@@ -269,6 +269,7 @@ pub fn tensor_summary_ui_grid_contents(
| re_log_types::component_types::TensorData::I16(_)
| re_log_types::component_types::TensorData::I32(_)
| re_log_types::component_types::TensorData::I64(_)
| re_log_types::component_types::TensorData::F16(_)
| re_log_types::component_types::TensorData::F32(_)
| re_log_types::component_types::TensorData::F64(_) => {}
re_log_types::component_types::TensorData::JPEG(jpeg_bytes) => {
104 changes: 94 additions & 10 deletions crates/re_log_types/src/component_types/tensor.rs
@@ -124,6 +124,11 @@ impl ArrowDeserialize for TensorId {
/// false
/// ),
/// Field::new(
/// "F16",
/// DataType::List(Box::new(Field::new("item", DataType::Float16, false))),
/// false
/// ),
/// Field::new(
/// "F32",
/// DataType::List(Box::new(Field::new("item", DataType::Float32, false))),
/// false
@@ -153,8 +158,7 @@ pub enum TensorData {
I32(Buffer<i32>),
I64(Buffer<i64>),
// ---
// TODO(#854): Native F16 support for arrow tensors
//F16(Vec<arrow2::types::f16>),
F16(Buffer<arrow2::types::f16>),
F32(Buffer<f32>),
F64(Buffer<f64>),
JPEG(Buffer<u8>),
@@ -171,6 +175,7 @@ impl TensorData {
Self::I16(_) => TensorDataType::I16,
Self::I32(_) => TensorDataType::I32,
Self::I64(_) => TensorDataType::I64,
Self::F16(_) => TensorDataType::F16,
Self::F32(_) => TensorDataType::F32,
Self::F64(_) => TensorDataType::F64,
}
@@ -186,6 +191,7 @@ impl TensorData {
Self::I16(buf) => buf.len(),
Self::I32(buf) => buf.len(),
Self::I64(buf) => buf.len(),
Self::F16(buf) => buf.len(),
Self::F32(buf) => buf.len(),
Self::F64(buf) => buf.len(),
}
@@ -205,6 +211,7 @@ impl TensorData {
| Self::I16(_)
| Self::I32(_)
| Self::I64(_)
| Self::F16(_)
| Self::F32(_)
| Self::F64(_) => false,

@@ -224,6 +231,7 @@ impl std::fmt::Debug for TensorData {
Self::I16(_) => write!(f, "I16({} bytes)", self.size_in_bytes()),
Self::I32(_) => write!(f, "I32({} bytes)", self.size_in_bytes()),
Self::I64(_) => write!(f, "I64({} bytes)", self.size_in_bytes()),
Self::F16(_) => write!(f, "F16({} bytes)", self.size_in_bytes()),
Self::F32(_) => write!(f, "F32({} bytes)", self.size_in_bytes()),
Self::F64(_) => write!(f, "F64({} bytes)", self.size_in_bytes()),
Self::JPEG(_) => write!(f, "JPEG({} bytes)", self.size_in_bytes()),
@@ -463,6 +471,7 @@ impl Tensor {
TensorData::I16(buf) => Some(TensorElement::I16(buf[offset])),
TensorData::I32(buf) => Some(TensorElement::I32(buf[offset])),
TensorData::I64(buf) => Some(TensorElement::I64(buf[offset])),
TensorData::F16(buf) => Some(TensorElement::F16(buf[offset])),
TensorData::F32(buf) => Some(TensorElement::F32(buf[offset])),
TensorData::F64(buf) => Some(TensorElement::F64(buf[offset])),
TensorData::JPEG(_) => None, // Too expensive to unpack here.
@@ -498,11 +507,6 @@ pub enum TensorCastError {

#[error("ndarray Array is not contiguous and in standard order")]
NotContiguousStdOrder,

#[error(
"tensors do not currently support f16 data (https://github.com/rerun-io/rerun/issues/854)"
)]
F16NotSupported,
}

macro_rules! tensor_type {
@@ -591,15 +595,93 @@ tensor_type!(i16, I16);
tensor_type!(i32, I32);
tensor_type!(i64, I64);

tensor_type!(arrow2::types::f16, F16);
tensor_type!(f32, F32);
tensor_type!(f64, F64);

// TODO(#854) Switch back to `tensor_type!` once we have F16 tensors
// Manual expansion of tensor_type! macro for `half::f16` types. We need to do this
// because arrow uses its own half type. The two use the same underlying representation
// but are still distinct types. `half::f16`, however, is more full-featured and
// generally a better choice to use when converting to ndarray.
// ==========================================
// TODO(jleibs): would be nice to support this with the macro definition as well
// but the bytemuck casts add a bit of complexity here.
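// For instance (illustrative sketch, not from the original commit), a value can be
// moved between the two types losslessly via its raw bits:
//     let h = half::f16::from_f32(1.5);
//     let a = arrow2::types::f16::from_bits(h.to_bits());
//     assert_eq!(a.to_f32(), 1.5);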
impl<'a> TryFrom<&'a Tensor> for ::ndarray::ArrayViewD<'a, half::f16> {
type Error = TensorCastError;

fn try_from(_: &'a Tensor) -> Result<Self, Self::Error> {
Err(TensorCastError::F16NotSupported)
fn try_from(value: &'a Tensor) -> Result<Self, Self::Error> {
let shape: Vec<_> = value.shape.iter().map(|d| d.size as usize).collect();
if let TensorData::F16(data) = &value.data {
ndarray::ArrayViewD::from_shape(shape, bytemuck::cast_slice(data.as_slice()))
.map_err(|err| TensorCastError::BadTensorShape { source: err })
} else {
Err(TensorCastError::TypeMismatch)
}
}
}

impl<'a, D: ::ndarray::Dimension> TryFrom<::ndarray::ArrayView<'a, half::f16, D>> for Tensor {
type Error = TensorCastError;

fn try_from(view: ::ndarray::ArrayView<'a, half::f16, D>) -> Result<Self, Self::Error> {
let shape = view
.shape()
.iter()
.map(|dim| TensorDimension {
size: *dim as u64,
name: None,
})
.collect();
match view.to_slice() {
Some(slice) => Ok(Tensor {
tensor_id: TensorId::random(),
shape,
data: TensorData::F16(Vec::from(bytemuck::cast_slice(slice)).into()),
meaning: TensorDataMeaning::Unknown,
meter: None,
}),
None => Ok(Tensor {
tensor_id: TensorId::random(),
shape,
data: TensorData::F16(
view.iter()
.map(|f| arrow2::types::f16::from_bits(f.to_bits()))
.collect::<Vec<_>>()
.into(),
),
meaning: TensorDataMeaning::Unknown,
meter: None,
}),
}
}
}

impl<D: ::ndarray::Dimension> TryFrom<::ndarray::Array<half::f16, D>> for Tensor {
type Error = TensorCastError;

fn try_from(value: ndarray::Array<half::f16, D>) -> Result<Self, Self::Error> {
let shape = value
.shape()
.iter()
.map(|dim| TensorDimension {
size: *dim as u64,
name: None,
})
.collect();
value
.is_standard_layout()
.then(|| Tensor {
tensor_id: TensorId::random(),
shape,
data: TensorData::F16(
bytemuck::cast_slice(value.into_raw_vec().as_slice())
.to_vec()
.into(),
),
meaning: TensorDataMeaning::Unknown,
meter: None,
})
.ok_or(TensorCastError::NotContiguousStdOrder)
}
}

@@ -883,6 +965,7 @@ impl TryFrom<Tensor> for DecodedTensor {
| TensorData::I16(_)
| TensorData::I32(_)
| TensorData::I64(_)
| TensorData::F16(_)
| TensorData::F32(_)
| TensorData::F64(_) => Ok(Self(tensor)),

@@ -972,6 +1055,7 @@ impl DecodedTensor {
| TensorData::I16(_)
| TensorData::I32(_)
| TensorData::I64(_)
| TensorData::F16(_)
| TensorData::F32(_)
| TensorData::F64(_) => Ok(Self(maybe_encoded_tensor)),

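A minimal usage sketch of the conversions added above (not part of the commit; assumes the `half` and `ndarray` crates and that `Tensor` is imported from `re_log_types::component_types`):

```rust
// Round-trip an f16 ndarray through a Tensor via the new TryFrom impls.
let arr = ndarray::Array2::from_elem((2, 3), half::f16::from_f32(0.25));
let tensor = Tensor::try_from(arr).unwrap(); // stored as TensorData::F16
let view: ndarray::ArrayViewD<'_, half::f16> = (&tensor).try_into().unwrap(); // zero-copy view
assert_eq!(view.shape(), &[2, 3]);
```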
7 changes: 3 additions & 4 deletions crates/re_log_types/src/data.rs
@@ -225,7 +225,6 @@ impl TensorDataTypeTrait for f64 {

/// The data that can be stored in a [`crate::component_types::Tensor`].
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum TensorElement {
/// Unsigned 8 bit integer.
///
@@ -259,7 +258,7 @@ pub enum TensorElement {
///
/// Uses the standard IEEE 754-2008 binary16 format.
/// See <https://en.wikipedia.org/wiki/Half-precision_floating-point_format>.
F16(f16),
F16(arrow2::types::f16),

/// 32-bit floating point number.
F32(f32),
@@ -282,7 +281,7 @@ impl TensorElement {
Self::I32(value) => *value as _,
Self::I64(value) => *value as _,

Self::F16(value) => value.to_f64(),
Self::F16(value) => value.to_f32() as _,
Self::F32(value) => *value as _,
Self::F64(value) => *value,
}
@@ -307,7 +306,7 @@ impl TensorElement {
Self::I32(value) => u16::try_from(*value).ok(),
Self::I64(value) => u16::try_from(*value).ok(),

Self::F16(value) => u16_from_f64(value.to_f64()),
Self::F16(value) => u16_from_f64(value.to_f32() as f64),
Self::F32(value) => u16_from_f64(*value as f64),
Self::F64(value) => u16_from_f64(*value),
}
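For context, `TensorElement::F16` now wraps arrow2's half type, which exposes `to_f32()`/`from_f32()` but no `to_f64()`; that is why the conversions above go through `to_f32()` first. A one-line sketch (illustrative only, not part of the commit):

```rust
// Construct an f16 element; downstream conversions go through to_f32() first.
let elem = TensorElement::F16(arrow2::types::f16::from_f32(0.5));
```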
3 changes: 3 additions & 0 deletions crates/re_viewer/src/ui/view_bar_chart/ui.rs
@@ -84,6 +84,9 @@ pub(crate) fn view_bar_chart(
instance_key,
data.iter().copied().map(|v| v as f64),
),
component_types::TensorData::F16(data) => {
create_bar_chart(ent_path, instance_key, data.iter().map(|f| f.to_f32()))
}
component_types::TensorData::F32(data) => {
create_bar_chart(ent_path, instance_key, data.iter().copied())
}
14 changes: 10 additions & 4 deletions crates/re_viewer_context/src/gpu_bridge/tensor_to_gpu.rs
@@ -279,7 +279,7 @@ fn general_texture_creation_desc_from_tensor<'a>(
TensorData::I32(buf) => (cast_slice_to_cow(buf), TextureFormat::R32Sint),
TensorData::I64(buf) => (narrow_i64_to_f32s(buf), TextureFormat::R32Float), // narrowing to f32!

// TensorData::F16(buf) => (cast_slice_to_cow(buf), TextureFormat::R16Float), TODO(#854)
TensorData::F16(buf) => (cast_slice_to_cow(buf), TextureFormat::R16Float),
TensorData::F32(buf) => (cast_slice_to_cow(buf), TextureFormat::R32Float),
TensorData::F64(buf) => (narrow_f64_to_f32s(buf), TextureFormat::R32Float), // narrowing to f32!

@@ -301,7 +301,7 @@ fn general_texture_creation_desc_from_tensor<'a>(
TensorData::I32(buf) => (cast_slice_to_cow(buf), TextureFormat::Rg32Sint),
TensorData::I64(buf) => (narrow_i64_to_f32s(buf), TextureFormat::Rg32Float), // narrowing to f32!

// TensorData::F16(buf) => (cast_slice_to_cow(buf), TextureFormat::Rg16Float), TODO(#854)
TensorData::F16(buf) => (cast_slice_to_cow(buf), TextureFormat::Rg16Float),
TensorData::F32(buf) => (cast_slice_to_cow(buf), TextureFormat::Rg32Float),
TensorData::F64(buf) => (narrow_f64_to_f32s(buf), TextureFormat::Rg32Float), // narrowing to f32!

@@ -335,7 +335,13 @@ fn general_texture_creation_desc_from_tensor<'a>(
TextureFormat::Rgba32Float,
),

// TensorData::F16(buf) => (pad_and_cast(buf, 1.0), TextureFormat::Rgba16Float), TODO(#854)
TensorData::F16(buf) => (
pad_and_cast(
buf,
re_log_types::external::arrow2::types::f16::from_f32(1.0),
),
TextureFormat::Rgba16Float,
),
TensorData::F32(buf) => (pad_and_cast(buf, 1.0), TextureFormat::Rgba32Float),
TensorData::F64(buf) => (
pad_and_narrow_and_cast(buf, 1.0, |x: f64| x as f32),
Expand All @@ -362,7 +368,7 @@ fn general_texture_creation_desc_from_tensor<'a>(
TensorData::I32(buf) => (cast_slice_to_cow(buf), TextureFormat::Rgba32Sint),
TensorData::I64(buf) => (narrow_i64_to_f32s(buf), TextureFormat::Rgba32Float), // narrowing to f32!

// TensorData::F16(buf) => (cast_slice_to_cow(buf), TextureFormat::Rgba16Float), TODO(#854)
TensorData::F16(buf) => (cast_slice_to_cow(buf), TextureFormat::Rgba16Float),
TensorData::F32(buf) => (cast_slice_to_cow(buf), TextureFormat::Rgba32Float),
TensorData::F64(buf) => (narrow_f64_to_f32s(buf), TextureFormat::Rgba32Float), // narrowing to f32!

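The three-channel case above pads each RGB texel to RGBA before upload, so the alpha fill value must itself be an f16; that is why `1.0` is wrapped in `f16::from_f32` there. A rough sketch of that padding step (hypothetical helper using `half::f16` for brevity, not the crate's actual `pad_and_cast`):

```rust
// Pad tightly packed RGB f16 texels to RGBA, filling alpha with 1.0.
fn pad_rgb_to_rgba_f16(rgb: &[half::f16]) -> Vec<half::f16> {
    let alpha = half::f16::from_f32(1.0);
    rgb.chunks_exact(3)
        .flat_map(|px| [px[0], px[1], px[2], alpha])
        .collect()
}
```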
5 changes: 0 additions & 5 deletions rerun_py/rerun_sdk/rerun/log/tensor.py
@@ -115,11 +115,6 @@ def _log_tensor(
np.float64,
]

# We don't support float16 -- upscale to f32
# TODO(#854): Native F16 support for arrow tensors
if tensor.dtype == np.float16:
tensor = np.asarray(tensor, dtype="float32")

if tensor.dtype not in SUPPORTED_DTYPES:
_send_warning(f"Unsupported dtype: {tensor.dtype}. Expected a numeric type. Skipping this tensor.", 2)
return
