Skip to content

Commit

Permalink
Introduce bytemuck to simplify float decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Jefffrey committed Oct 8, 2024
1 parent 8213fd8 commit 2d2647b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 32 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rust-version = "1.73"

[dependencies]
arrow = { version = "52", features = ["prettyprint", "chrono-tz"] }
bytemuck = "1.18.0"
bytes = "1.4"
chrono = { version = "0.4.37", default-features = false, features = ["std"] }
chrono-tz = "0.9"
Expand Down
40 changes: 8 additions & 32 deletions src/encoding/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use std::marker::PhantomData;

use arrow::datatypes::{ArrowPrimitiveType, ToByteSlice};
use bytemuck::cast_slice_mut;
use bytes::{Bytes, BytesMut};
use snafu::ResultExt;

Expand All @@ -29,34 +30,12 @@ use crate::{
use super::{PrimitiveValueDecoder, PrimitiveValueEncoder};

/// Generically represent f32 and f64.
// TODO: figure out how to use num::traits::FromBytes instead of rolling our own?
pub trait Float: num::Float + std::fmt::Debug + num::traits::ToBytes {
const BYTE_SIZE: usize;

fn from_le_bytes(bytes: &[u8]) -> Self;
}

impl Float for f32 {
const BYTE_SIZE: usize = 4;

#[inline]
fn from_le_bytes(bytes: &[u8]) -> Self {
debug_assert!(Self::BYTE_SIZE == bytes.len());
Self::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
}
}

impl Float for f64 {
const BYTE_SIZE: usize = 8;

#[inline]
fn from_le_bytes(bytes: &[u8]) -> Self {
debug_assert!(Self::BYTE_SIZE == bytes.len());
Self::from_le_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
])
}
pub trait Float:
num::Float + std::fmt::Debug + num::traits::ToBytes + bytemuck::NoUninit + bytemuck::AnyBitPattern
{
}
impl Float for f32 {}
impl Float for f64 {}

pub struct FloatDecoder<T: Float, R: std::io::Read> {
reader: R,
Expand All @@ -74,11 +53,8 @@ impl<T: Float, R: std::io::Read> FloatDecoder<T, R> {

impl<T: Float, R: std::io::Read> PrimitiveValueDecoder<T> for FloatDecoder<T, R> {
fn decode(&mut self, out: &mut [T]) -> Result<()> {
let mut buf = vec![0; out.len() * T::BYTE_SIZE];
self.reader.read_exact(&mut buf).context(IoSnafu)?;
for (out_float, bytes) in out.iter_mut().zip(buf.chunks(T::BYTE_SIZE)) {
*out_float = T::from_le_bytes(bytes);
}
let bytes = cast_slice_mut::<T, u8>(out);
self.reader.read_exact(bytes).context(IoSnafu)?;
Ok(())
}
}
Expand Down

0 comments on commit 2d2647b

Please sign in to comment.