From 22cb4c58634c854c0c6122b89e3f5eb6d3798de4 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Wed, 9 Oct 2024 18:36:48 +1100 Subject: [PATCH] Further simplify Float encoding --- Cargo.toml | 2 +- src/encoding/float.rs | 120 +++++++++++++++++------------------------- src/writer/column.rs | 6 +-- 3 files changed, 51 insertions(+), 77 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1f976608..9940261e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ rust-version = "1.73" [dependencies] arrow = { version = "52", features = ["prettyprint", "chrono-tz"] } -bytemuck = "1.18.0" +bytemuck = { version = "1.18.0", features = ["must_cast"] } bytes = "1.4" chrono = { version = "0.4.37", default-features = false, features = ["std"] } chrono-tz = "0.9" diff --git a/src/encoding/float.rs b/src/encoding/float.rs index f5679ff0..5b9fa7ec 100644 --- a/src/encoding/float.rs +++ b/src/encoding/float.rs @@ -17,8 +17,7 @@ use std::marker::PhantomData; -use arrow::datatypes::{ArrowPrimitiveType, ToByteSlice}; -use bytemuck::cast_slice_mut; +use bytemuck::{must_cast_slice, must_cast_slice_mut}; use bytes::{Bytes, BytesMut}; use snafu::ResultExt; @@ -29,20 +28,20 @@ use crate::{ use super::{PrimitiveValueDecoder, PrimitiveValueEncoder}; -/// Generically represent f32 and f64. +/// Collect all the required traits we need on floats. pub trait Float: - num::Float + std::fmt::Debug + num::traits::ToBytes + bytemuck::NoUninit + bytemuck::AnyBitPattern + num::Float + std::fmt::Debug + bytemuck::NoUninit + bytemuck::AnyBitPattern { } impl Float for f32 {} impl Float for f64 {} -pub struct FloatDecoder { +pub struct FloatDecoder { reader: R, - phantom: std::marker::PhantomData, + phantom: std::marker::PhantomData, } -impl FloatDecoder { +impl FloatDecoder { pub fn new(reader: R) -> Self { Self { reader, @@ -51,9 +50,9 @@ impl FloatDecoder { } } -impl PrimitiveValueDecoder for FloatDecoder { - fn decode(&mut self, out: &mut [T]) -> Result<()> { - let bytes = cast_slice_mut::(out); +impl PrimitiveValueDecoder for FloatDecoder { + fn decode(&mut self, out: &mut [F]) -> Result<()> { + let bytes = must_cast_slice_mut::(out); self.reader.read_exact(bytes).context(IoSnafu)?; Ok(()) } @@ -62,27 +61,18 @@ impl PrimitiveValueDecoder for FloatDecoder /// No special run encoding for floats/doubles, they are stored as their IEEE 754 floating /// point bit layout. This encoder simply copies incoming floats/doubles to its internal /// byte buffer. -pub struct FloatValueEncoder -where - T::Native: Float, -{ +pub struct FloatEncoder { data: BytesMut, - _phantom: PhantomData, + _phantom: PhantomData, } -impl EstimateMemory for FloatValueEncoder -where - T::Native: Float, -{ +impl EstimateMemory for FloatEncoder { fn estimate_memory_size(&self) -> usize { self.data.len() } } -impl PrimitiveValueEncoder for FloatValueEncoder -where - T::Native: Float, -{ +impl PrimitiveValueEncoder for FloatEncoder { fn new() -> Self { Self { data: BytesMut::new(), @@ -90,14 +80,13 @@ where } } - fn write_one(&mut self, value: T::Native) { - let bytes = value.to_byte_slice(); - self.data.extend_from_slice(bytes); + fn write_one(&mut self, value: F) { + self.write_slice(&[value]); } - fn write_slice(&mut self, values: &[T::Native]) { - let bytes = values.to_byte_slice(); - self.data.extend_from_slice(bytes) + fn write_slice(&mut self, values: &[F]) { + let bytes = must_cast_slice::(values); + self.data.extend_from_slice(bytes); } fn take_inner(&mut self) -> Bytes { @@ -111,70 +100,58 @@ mod tests { use std::f64::consts as f64c; use std::io::Cursor; - use super::*; + use proptest::prelude::*; - fn float_to_bytes(input: &[F]) -> Vec { - input - .iter() - .flat_map(|f| f.to_le_bytes().as_ref().to_vec()) - .collect() - } + use super::*; - fn assert_roundtrip(input: Vec) { - let bytes = float_to_bytes(&input); + fn roundtrip_helper(input: &[F]) -> Result> { + let mut encoder = FloatEncoder::::new(); + encoder.write_slice(input); + let bytes = encoder.take_inner(); let bytes = Cursor::new(bytes); let mut iter = FloatDecoder::::new(bytes); let mut actual = vec![F::zero(); input.len()]; - iter.decode(&mut actual).unwrap(); + iter.decode(&mut actual)?; + Ok(actual) + } + + fn assert_roundtrip(input: Vec) { + let actual = roundtrip_helper(&input).unwrap(); assert_eq!(input, actual); } - #[test] - fn test_float_iter_empty() { - assert_roundtrip::(vec![]); + proptest! { + #[test] + fn roundtrip_f32(values: Vec) { + let out = roundtrip_helper(&values)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_f64(values: Vec) { + let out = roundtrip_helper(&values)?; + prop_assert_eq!(out, values); + } } #[test] - fn test_double_iter_empty() { + fn test_float_edge_cases() { + assert_roundtrip::(vec![]); assert_roundtrip::(vec![]); - } - #[test] - fn test_float_iter_one() { assert_roundtrip(vec![f32c::PI]); - } - - #[test] - fn test_double_iter_one() { assert_roundtrip(vec![f64c::PI]); - } - #[test] - fn test_float_iter_nan() { - let bytes = float_to_bytes(&[f32::NAN]); - let bytes = Cursor::new(bytes); - - let mut iter = FloatDecoder::::new(bytes); - let mut actual = vec![0.0; 1]; - iter.decode(&mut actual).unwrap(); + let actual = roundtrip_helper(&[f32::NAN]).unwrap(); assert!(actual[0].is_nan()); - } - - #[test] - fn test_double_iter_nan() { - let bytes = float_to_bytes(&[f64::NAN]); - let bytes = Cursor::new(bytes); - - let mut iter = FloatDecoder::::new(bytes); - let mut actual = vec![0.0; 1]; - iter.decode(&mut actual).unwrap(); + let actual = roundtrip_helper(&[f64::NAN]).unwrap(); assert!(actual[0].is_nan()); } #[test] - fn test_float_iter_many() { + fn test_float_many() { assert_roundtrip(vec![ f32::NEG_INFINITY, f32::MIN, @@ -186,10 +163,7 @@ mod tests { f32::MAX, f32::INFINITY, ]); - } - #[test] - fn test_double_iter_many() { assert_roundtrip(vec![ f64::NEG_INFINITY, f64::MIN, diff --git a/src/writer/column.rs b/src/writer/column.rs index 1174c17f..f1ec72f4 100644 --- a/src/writer/column.rs +++ b/src/writer/column.rs @@ -30,7 +30,7 @@ use crate::{ encoding::{ boolean::BooleanEncoder, byte::ByteRleEncoder, - float::FloatValueEncoder, + float::FloatEncoder, integer::{rle_v2::RleV2Encoder, NInt, SignedEncoding, UnsignedEncoding}, PrimitiveValueEncoder, }, @@ -390,8 +390,8 @@ where } } -pub type FloatColumnEncoder = PrimitiveColumnEncoder>; -pub type DoubleColumnEncoder = PrimitiveColumnEncoder>; +pub type FloatColumnEncoder = PrimitiveColumnEncoder>; +pub type DoubleColumnEncoder = PrimitiveColumnEncoder>; pub type ByteColumnEncoder = PrimitiveColumnEncoder; pub type Int16ColumnEncoder = PrimitiveColumnEncoder>; pub type Int32ColumnEncoder = PrimitiveColumnEncoder>;