Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Added support to read and write f16 #1051

Merged
merged 1 commit into from
Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ documentation of each of its APIs.
## Features

* Most feature-complete implementation of Apache Arrow after the reference implementation (C++)
* Float 16 unsupported (not a Rust native type)
* Decimal 256 unsupported (not a Rust native type)
* C data interface supported for all Arrow types (read and write)
* C stream interface supported for all Arrow types (read and write)
Expand Down
3 changes: 2 additions & 1 deletion src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ macro_rules! with_match_primitive_type {(
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, months_days_ns, f16};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Expand All @@ -216,6 +216,7 @@ macro_rules! with_match_primitive_type {(
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float16 => __with_ty__! { f16 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand Down
6 changes: 5 additions & 1 deletion src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
datatypes::*,
error::Error,
trusted_len::TrustedLen,
types::{days_ms, months_days_ns, NativeType},
types::{days_ms, f16, months_days_ns, NativeType},
};

use super::Array;
Expand Down Expand Up @@ -468,6 +468,8 @@ pub type Int128Array = PrimitiveArray<i128>;
pub type DaysMsArray = PrimitiveArray<days_ms>;
/// A type definition [`PrimitiveArray`] for [`months_days_ns`]
pub type MonthsDaysNsArray = PrimitiveArray<months_days_ns>;
/// A type definition [`PrimitiveArray`] for `f16`
pub type Float16Array = PrimitiveArray<f16>;
/// A type definition [`PrimitiveArray`] for `f32`
pub type Float32Array = PrimitiveArray<f32>;
/// A type definition [`PrimitiveArray`] for `f64`
Expand Down Expand Up @@ -495,6 +497,8 @@ pub type Int128Vec = MutablePrimitiveArray<i128>;
pub type DaysMsVec = MutablePrimitiveArray<days_ms>;
/// A type definition [`MutablePrimitiveArray`] for [`months_days_ns`]
pub type MonthsDaysNsVec = MutablePrimitiveArray<months_days_ns>;
/// A type definition [`MutablePrimitiveArray`] for `f16`
pub type Float16Vec = MutablePrimitiveArray<f16>;
/// A type definition [`MutablePrimitiveArray`] for `f32`
pub type Float32Vec = MutablePrimitiveArray<f32>;
/// A type definition [`MutablePrimitiveArray`] for `f64`
Expand Down
2 changes: 1 addition & 1 deletion src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ macro_rules! with_match_negatable {(
Int128 => __with_ty__! { i128 },
DaysMs => __with_ty__! { days_ms },
MonthDayNano => __with_ty__! { months_days_ns },
UInt8 | UInt16 | UInt32 | UInt64=> todo!(),
UInt8 | UInt16 | UInt32 | UInt64 | Float16 => todo!(),
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand Down
7 changes: 7 additions & 0 deletions src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int64, Float64) => true,
(Int64, Decimal(_, _)) => true,

(Float16, Float32) => true,

(Float32, UInt8) => true,
(Float32, UInt16) => true,
(Float32, UInt32) => true,
Expand Down Expand Up @@ -736,6 +738,11 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int64, Float64) => primitive_to_primitive_dyn::<i64, f64>(array, to_type, as_options),
(Int64, Decimal(p, s)) => integer_to_decimal_dyn::<i64>(array, *p, *s),

(Float16, Float32) => {
let from = array.as_any().downcast_ref().unwrap();
Ok(f16_to_f32(from).boxed())
}

(Float32, UInt8) => primitive_to_primitive_dyn::<f32, u8>(array, to_type, options),
(Float32, UInt16) => primitive_to_primitive_dyn::<f32, u16>(array, to_type, options),
(Float32, UInt32) => primitive_to_primitive_dyn::<f32, u32>(array, to_type, options),
Expand Down
7 changes: 6 additions & 1 deletion src/compute/cast/primitive_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use num_traits::{AsPrimitive, Float, ToPrimitive};

use crate::datatypes::IntervalUnit;
use crate::error::Result;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, f16, months_days_ns};
use crate::{
array::*,
bitmap::Bitmap,
Expand Down Expand Up @@ -581,3 +581,8 @@ pub fn months_to_months_days_ns(from: &PrimitiveArray<i32>) -> PrimitiveArray<mo
DataType::Interval(IntervalUnit::MonthDayNano),
)
}

/// Casts f16 into f32
pub fn f16_to_f32(from: &PrimitiveArray<f16>) -> PrimitiveArray<f32> {
unary(from, |x| x.as_f32(), DataType::Float32)
}
7 changes: 5 additions & 2 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ macro_rules! match_eq_ord {(
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float16 => todo!(),
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand All @@ -91,7 +92,7 @@ macro_rules! match_eq {(
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, months_days_ns, f16};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Expand All @@ -104,6 +105,7 @@ macro_rules! match_eq {(
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float16 => __with_ty__! { f16 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand Down Expand Up @@ -487,7 +489,8 @@ fn can_partial_eq(data_type: &DataType) -> bool {
can_partial_eq_and_ord(data_type)
|| matches!(
data_type.to_logical_type(),
DataType::Interval(IntervalUnit::DayTime)
DataType::Float16
| DataType::Interval(IntervalUnit::DayTime)
| DataType::Interval(IntervalUnit::MonthDayNano)
)
}
Expand Down
4 changes: 3 additions & 1 deletion src/compute/comparison/simd/native.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::convert::TryInto;

use super::{set, Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, f16, months_days_ns};

simd8_native_all!(u8);
simd8_native_all!(u16);
Expand All @@ -12,6 +12,8 @@ simd8_native_all!(i16);
simd8_native_all!(i32);
simd8_native_all!(i128);
simd8_native_all!(i64);
simd8_native!(f16);
simd8_native_partial_eq!(f16);
simd8_native_all!(f32);
simd8_native_all!(f64);
simd8_native!(days_ms);
Expand Down
4 changes: 3 additions & 1 deletion src/compute/comparison/simd/packed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::convert::TryInto;
use std::simd::ToBitMask;

use crate::types::simd::*;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, f16, months_days_ns};

use super::*;

Expand Down Expand Up @@ -71,6 +71,8 @@ simd8!(i16, i16x8);
simd8!(i32, i32x8);
simd8!(i64, i64x8);
simd8_native_all!(i128);
simd8_native!(f16);
simd8_native_partial_eq!(f16);
simd8!(f32, f32x8);
simd8!(f64, f64x8);
simd8_native!(days_ms);
Expand Down
3 changes: 2 additions & 1 deletion src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl DataType {
UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
Float16 => unreachable!(),
Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
Expand Down Expand Up @@ -299,6 +299,7 @@ impl From<PrimitiveType> for DataType {
PrimitiveType::UInt32 => DataType::UInt32,
PrimitiveType::UInt64 => DataType::UInt64,
PrimitiveType::Int128 => DataType::Decimal(32, 32),
PrimitiveType::Float16 => DataType::Float16,
PrimitiveType::Float32 => DataType::Float32,
PrimitiveType::Float64 => DataType::Float64,
PrimitiveType::DaysMs => DataType::Interval(IntervalUnit::DayTime),
Expand Down
1 change: 1 addition & 0 deletions src/io/json_integration/read/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ pub fn to_array(
Primitive(PrimitiveType::UInt16) => Ok(Box::new(to_primitive::<u16>(json_col, data_type))),
Primitive(PrimitiveType::UInt32) => Ok(Box::new(to_primitive::<u32>(json_col, data_type))),
Primitive(PrimitiveType::UInt64) => Ok(Box::new(to_primitive::<u64>(json_col, data_type))),
Primitive(PrimitiveType::Float16) => todo!(),
Primitive(PrimitiveType::Float32) => Ok(Box::new(to_primitive::<f32>(json_col, data_type))),
Primitive(PrimitiveType::Float64) => Ok(Box::new(to_primitive::<f64>(json_col, data_type))),
Binary => Ok(to_binary::<i32>(json_col, data_type)),
Expand Down
3 changes: 3 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ pub enum PrimitiveType {
UInt32,
/// An unsigned 64-bit integer.
UInt64,
/// A 16-bit floating point number.
Float16,
/// A 32-bit floating point number.
Float32,
/// A 64-bit floating point number.
Expand All @@ -77,6 +79,7 @@ mod private {
impl Sealed for i32 {}
impl Sealed for i64 {}
impl Sealed for i128 {}
impl Sealed for super::f16 {}
impl Sealed for f32 {}
impl Sealed for f64 {}
impl Sealed for super::days_ms {}
Expand Down
110 changes: 110 additions & 0 deletions src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,113 @@ impl Neg for months_days_ns {
Self::new(-self.months(), -self.days(), -self.ns())
}
}

/// Type representation of the Float16 physical type
#[derive(Copy, Clone, Default, Zeroable, Pod)]
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct f16(pub u16);

impl PartialEq for f16 {
#[inline]
fn eq(&self, other: &f16) -> bool {
if self.is_nan() || other.is_nan() {
false
} else {
(self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0)
}
}
}

// see https://github.com/starkat99/half-rs/blob/main/src/binary16.rs
impl f16 {
#[inline]
#[must_use]
pub(crate) const fn is_nan(self) -> bool {
self.0 & 0x7FFFu16 > 0x7C00u16
}

/// Casts this `f16` to `f32`
#[inline]
pub fn as_f32(self) -> f32 {
let i = self.0;
// Check for signed zero
if i & 0x7FFFu16 == 0 {
return f32::from_bits((i as u32) << 16);
}

let half_sign = (i & 0x8000u16) as u32;
let half_exp = (i & 0x7C00u16) as u32;
let half_man = (i & 0x03FFu16) as u32;

// Check for an infinity or NaN when all exponent bits set
if half_exp == 0x7C00u32 {
// Check for signed infinity if mantissa is zero
if half_man == 0 {
let number = (half_sign << 16) | 0x7F80_0000u32;
return f32::from_bits(number);
} else {
// NaN, keep current mantissa but also set most significiant mantissa bit
let number = (half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13);
return f32::from_bits(number);
}
}

// Calculate single-precision components with adjusted exponent
let sign = half_sign << 16;
// Unbias exponent
let unbiased_exp = ((half_exp as i32) >> 10) - 15;

// Check for subnormals, which will be normalized by adjusting exponent
if half_exp == 0 {
// Calculate how much to adjust the exponent by
let e = (half_man as u16).leading_zeros() - 6;

// Rebias and adjust exponent
let exp = (127 - 15 - e) << 23;
let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
return f32::from_bits(sign | exp | man);
}

// Rebias exponent for a normalized normal
let exp = ((unbiased_exp + 127) as u32) << 23;
let man = (half_man & 0x03FFu32) << 13;
f32::from_bits(sign | exp | man)
}
}

impl std::fmt::Debug for f16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.as_f32())
}
}

impl std::fmt::Display for f16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_f32())
}
}

impl NativeType for f16 {
const PRIMITIVE: PrimitiveType = PrimitiveType::Float16;
type Bytes = [u8; 2];
#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
self.0.to_le_bytes()
}

#[inline]
fn to_ne_bytes(&self) -> Self::Bytes {
self.0.to_ne_bytes()
}

#[inline]
fn to_be_bytes(&self) -> Self::Bytes {
self.0.to_be_bytes()
}

#[inline]
fn from_be_bytes(bytes: Self::Bytes) -> Self {
Self(u16::from_be_bytes(bytes))
}
}
4 changes: 3 additions & 1 deletion src/types/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Contains traits and implementations of multi-data used in SIMD.
//! The actual representation is driven by the feature flag `"simd"`, which, if set,
//! uses [`std::simd`].
use super::{days_ms, months_days_ns};
use super::{days_ms, f16, months_days_ns};
use super::{BitChunk, BitChunkIter, NativeType};

/// Describes the ability to convert itself from a [`BitChunk`].
Expand Down Expand Up @@ -129,6 +129,7 @@ pub(super) use native_simd;
// Types do not have specific intrinsics and thus SIMD can't be specialized.
// Therefore, we can declare their MD representation as `[$t; 8]` irrespectively
// of how they are represented in the different channels.
native_simd!(f16x32, f16, 32, u32);
native_simd!(days_msx8, days_ms, 8, u8);
native_simd!(months_days_nsx8, months_days_ns, 8, u8);
native_simd!(i128x8, i128, 8, u8);
Expand Down Expand Up @@ -157,6 +158,7 @@ native!(i8, i8x64);
native!(i16, i16x32);
native!(i32, i32x16);
native!(i64, i64x8);
native!(f16, f16x32);
native!(f32, f32x16);
native!(f64, f64x8);
native!(i128, i128x8);
Expand Down
1 change: 1 addition & 0 deletions src/types/simd/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ native_simd!(i8x64, i8, 64, u64);
native_simd!(i16x32, i16, 32, u32);
native_simd!(i32x16, i32, 16, u16);
native_simd!(i64x8, i64, 8, u8);
native_simd!(f16x32, f16, 32, u32);
native_simd!(f32x16, f32, 16, u16);
native_simd!(f64x8, f64, 8, u8);
2 changes: 2 additions & 0 deletions tests/it/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ fn consistency() {
Int16,
Int32,
Int64,
Float16,
Float32,
Float64,
Timestamp(TimeUnit::Second, None),
Expand Down Expand Up @@ -778,6 +779,7 @@ fn null_array_from_and_to_others() {
typed_test!(UInt32Array, UInt32);
typed_test!(UInt64Array, UInt64);

typed_test!(Float16Array, Float16);
typed_test!(Float32Array, Float32);
typed_test!(Float64Array, Float64);
}
Expand Down
1 change: 1 addition & 0 deletions tests/it/compute/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ fn consistency() {
Int16,
Int32,
Int64,
Float16,
Float32,
Float64,
Interval(IntervalUnit::YearMonth),
Expand Down