Skip to content

Commit

Permalink
Further simplify Float encoding
Browse files: browse the repository at this point in the history
  • Loading branch information
Jefffrey committed Oct 9, 2024
1 parent 2d2647b commit 22cb4c5
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ rust-version = "1.73"

[dependencies]
arrow = { version = "52", features = ["prettyprint", "chrono-tz"] }
bytemuck = "1.18.0"
bytemuck = { version = "1.18.0", features = ["must_cast"] }
bytes = "1.4"
chrono = { version = "0.4.37", default-features = false, features = ["std"] }
chrono-tz = "0.9"
Expand Down
120 changes: 47 additions & 73 deletions src/encoding/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

use std::marker::PhantomData;

use arrow::datatypes::{ArrowPrimitiveType, ToByteSlice};
use bytemuck::cast_slice_mut;
use bytemuck::{must_cast_slice, must_cast_slice_mut};
use bytes::{Bytes, BytesMut};
use snafu::ResultExt;

Expand All @@ -29,20 +28,20 @@ use crate::{

use super::{PrimitiveValueDecoder, PrimitiveValueEncoder};

/// Generically represent f32 and f64.
/// Collect all the required traits we need on floats.
pub trait Float:
num::Float + std::fmt::Debug + num::traits::ToBytes + bytemuck::NoUninit + bytemuck::AnyBitPattern
num::Float + std::fmt::Debug + bytemuck::NoUninit + bytemuck::AnyBitPattern
{
}
impl Float for f32 {}
impl Float for f64 {}

pub struct FloatDecoder<T: Float, R: std::io::Read> {
pub struct FloatDecoder<F: Float, R: std::io::Read> {
reader: R,
phantom: std::marker::PhantomData<T>,
phantom: std::marker::PhantomData<F>,
}

impl<T: Float, R: std::io::Read> FloatDecoder<T, R> {
impl<F: Float, R: std::io::Read> FloatDecoder<F, R> {
pub fn new(reader: R) -> Self {
Self {
reader,
Expand All @@ -51,9 +50,9 @@ impl<T: Float, R: std::io::Read> FloatDecoder<T, R> {
}
}

impl<T: Float, R: std::io::Read> PrimitiveValueDecoder<T> for FloatDecoder<T, R> {
fn decode(&mut self, out: &mut [T]) -> Result<()> {
let bytes = cast_slice_mut::<T, u8>(out);
impl<F: Float, R: std::io::Read> PrimitiveValueDecoder<F> for FloatDecoder<F, R> {
fn decode(&mut self, out: &mut [F]) -> Result<()> {
let bytes = must_cast_slice_mut::<F, u8>(out);
self.reader.read_exact(bytes).context(IoSnafu)?;
Ok(())
}
Expand All @@ -62,42 +61,32 @@ impl<T: Float, R: std::io::Read> PrimitiveValueDecoder<T> for FloatDecoder<T, R>
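/// No special run encoding is applied to floats/doubles; they are stored using their IEEE 754
/// floating point bit layout. This encoder simply copies the bytes of incoming floats/doubles
/// into its internal byte buffer.
/// NOTE(review): the copy is a plain byte reinterpretation of the values, so the output
/// presumably follows the host's byte order; confirm this matches what the format expects.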
pub struct FloatValueEncoder<T: ArrowPrimitiveType>
where
T::Native: Float,
{
pub struct FloatEncoder<F: Float> {
data: BytesMut,
_phantom: PhantomData<T>,
_phantom: PhantomData<F>,
}

impl<T: ArrowPrimitiveType> EstimateMemory for FloatValueEncoder<T>
where
T::Native: Float,
{
impl<F: Float> EstimateMemory for FloatEncoder<F> {
fn estimate_memory_size(&self) -> usize {
self.data.len()
}
}

impl<T: ArrowPrimitiveType> PrimitiveValueEncoder<T::Native> for FloatValueEncoder<T>
where
T::Native: Float,
{
impl<F: Float> PrimitiveValueEncoder<F> for FloatEncoder<F> {
fn new() -> Self {
Self {
data: BytesMut::new(),
_phantom: Default::default(),
}
}

fn write_one(&mut self, value: T::Native) {
let bytes = value.to_byte_slice();
self.data.extend_from_slice(bytes);
fn write_one(&mut self, value: F) {
self.write_slice(&[value]);
}

fn write_slice(&mut self, values: &[T::Native]) {
let bytes = values.to_byte_slice();
self.data.extend_from_slice(bytes)
fn write_slice(&mut self, values: &[F]) {
let bytes = must_cast_slice::<F, u8>(values);
self.data.extend_from_slice(bytes);
}

fn take_inner(&mut self) -> Bytes {
Expand All @@ -111,70 +100,58 @@ mod tests {
use std::f64::consts as f64c;
use std::io::Cursor;

use super::*;
use proptest::prelude::*;

fn float_to_bytes<F: Float>(input: &[F]) -> Vec<u8> {
input
.iter()
.flat_map(|f| f.to_le_bytes().as_ref().to_vec())
.collect()
}
use super::*;

fn assert_roundtrip<F: Float>(input: Vec<F>) {
let bytes = float_to_bytes(&input);
fn roundtrip_helper<F: Float>(input: &[F]) -> Result<Vec<F>> {
let mut encoder = FloatEncoder::<F>::new();
encoder.write_slice(input);
let bytes = encoder.take_inner();
let bytes = Cursor::new(bytes);

let mut iter = FloatDecoder::<F, _>::new(bytes);
let mut actual = vec![F::zero(); input.len()];
iter.decode(&mut actual).unwrap();
iter.decode(&mut actual)?;

Ok(actual)
}

fn assert_roundtrip<F: Float>(input: Vec<F>) {
let actual = roundtrip_helper(&input).unwrap();
assert_eq!(input, actual);
}

#[test]
fn test_float_iter_empty() {
assert_roundtrip::<f32>(vec![]);
proptest! {
#[test]
fn roundtrip_f32(values: Vec<f32>) {
let out = roundtrip_helper(&values)?;
prop_assert_eq!(out, values);
}

#[test]
fn roundtrip_f64(values: Vec<f64>) {
let out = roundtrip_helper(&values)?;
prop_assert_eq!(out, values);
}
}

#[test]
fn test_double_iter_empty() {
fn test_float_edge_cases() {
assert_roundtrip::<f32>(vec![]);
assert_roundtrip::<f64>(vec![]);
}

#[test]
fn test_float_iter_one() {
assert_roundtrip(vec![f32c::PI]);
}

#[test]
fn test_double_iter_one() {
assert_roundtrip(vec![f64c::PI]);
}

#[test]
fn test_float_iter_nan() {
let bytes = float_to_bytes(&[f32::NAN]);
let bytes = Cursor::new(bytes);

let mut iter = FloatDecoder::<f32, _>::new(bytes);
let mut actual = vec![0.0; 1];
iter.decode(&mut actual).unwrap();
let actual = roundtrip_helper(&[f32::NAN]).unwrap();
assert!(actual[0].is_nan());
}

#[test]
fn test_double_iter_nan() {
let bytes = float_to_bytes(&[f64::NAN]);
let bytes = Cursor::new(bytes);

let mut iter = FloatDecoder::<f64, _>::new(bytes);
let mut actual = vec![0.0; 1];
iter.decode(&mut actual).unwrap();
let actual = roundtrip_helper(&[f64::NAN]).unwrap();
assert!(actual[0].is_nan());
}

#[test]
fn test_float_iter_many() {
fn test_float_many() {
assert_roundtrip(vec![
f32::NEG_INFINITY,
f32::MIN,
Expand All @@ -186,10 +163,7 @@ mod tests {
f32::MAX,
f32::INFINITY,
]);
}

#[test]
fn test_double_iter_many() {
assert_roundtrip(vec![
f64::NEG_INFINITY,
f64::MIN,
Expand Down
6 changes: 3 additions & 3 deletions src/writer/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::{
encoding::{
boolean::BooleanEncoder,
byte::ByteRleEncoder,
float::FloatValueEncoder,
float::FloatEncoder,
integer::{rle_v2::RleV2Encoder, NInt, SignedEncoding, UnsignedEncoding},
PrimitiveValueEncoder,
},
Expand Down Expand Up @@ -390,8 +390,8 @@ where
}
}

pub type FloatColumnEncoder = PrimitiveColumnEncoder<Float32Type, FloatValueEncoder<Float32Type>>;
pub type DoubleColumnEncoder = PrimitiveColumnEncoder<Float64Type, FloatValueEncoder<Float64Type>>;
pub type FloatColumnEncoder = PrimitiveColumnEncoder<Float32Type, FloatEncoder<f32>>;
pub type DoubleColumnEncoder = PrimitiveColumnEncoder<Float64Type, FloatEncoder<f64>>;
pub type ByteColumnEncoder = PrimitiveColumnEncoder<Int8Type, ByteRleEncoder>;
pub type Int16ColumnEncoder = PrimitiveColumnEncoder<Int16Type, RleV2Encoder<i16, SignedEncoding>>;
pub type Int32ColumnEncoder = PrimitiveColumnEncoder<Int32Type, RleV2Encoder<i32, SignedEncoding>>;
Expand Down

0 comments on commit 22cb4c5

Please sign in to comment.