diff --git a/src/io/orc/read/mod.rs b/src/io/orc/read/mod.rs index 629300dbed2..f4723f807ee 100644 --- a/src/io/orc/read/mod.rs +++ b/src/io/orc/read/mod.rs @@ -1,10 +1,11 @@ //! APIs to read from [ORC format](https://orc.apache.org). use std::io::{Read, Seek, SeekFrom}; -use crate::array::{Array, BooleanArray, Float32Array, Int32Array, Int64Array}; +use crate::array::{Array, BooleanArray, Int64Array, PrimitiveArray}; use crate::bitmap::{Bitmap, MutableBitmap}; use crate::datatypes::{DataType, Field, Schema}; use crate::error::Error; +use crate::types::NativeType; use orc_format::fallible_streaming_iterator::FallibleStreamingIterator; use orc_format::proto::stream::Kind; @@ -102,11 +103,11 @@ fn deserialize_validity( } /// Deserializes column `column` from `stripe`, assumed to represent a f32 -fn deserialize_f32( +fn deserialize_float( data_type: DataType, stripe: &Stripe, column: usize, -) -> Result { +) -> Result, Error> { let mut scratch = vec![]; let num_rows = stripe.number_of_rows(); @@ -118,23 +119,29 @@ fn deserialize_f32( if let Some(validity) = &validity { let mut validity_iter = validity.iter(); while let Some(chunk) = chunks.next()? { - let mut valid_iter = decode::deserialize_f32(chunk); + let mut valid_iter = chunk + .chunks_exact(std::mem::size_of::()) + .map(|chunk| T::from_le_bytes(chunk.try_into().unwrap_or_default())); let iter = validity_iter.by_ref().map(|is_valid| { if is_valid { valid_iter.next().unwrap() } else { - 0.0f32 + T::default() } }); values.extend(iter); } } else { while let Some(chunk) = chunks.next()? { - values.extend(decode::deserialize_f32(chunk)); + values.extend( + chunk + .chunks_exact(std::mem::size_of::()) + .map(|chunk| T::from_le_bytes(chunk.try_into().unwrap_or_default())), + ); } } - Float32Array::try_new(data_type, values.into(), validity) + PrimitiveArray::try_new(data_type, values.into(), validity) } /// Deserializes column `column` from `stripe`, assumed to represent a boolean array @@ -266,11 +273,14 @@ fn deserialize_i64( } /// Deserializes column `column` from `stripe`, assumed to represent a boolean array -fn deserialize_i32( +fn deserialize_int( data_type: DataType, stripe: &Stripe, column: usize, -) -> Result { +) -> Result, Error> +where + T: NativeType + TryFrom, +{ let num_rows = stripe.number_of_rows(); let mut scratch = vec![]; @@ -278,7 +288,7 @@ fn deserialize_i32( let mut chunks = stripe.get_bytes(column, Kind::Data, std::mem::take(&mut scratch))?; - let mut values = Vec::with_capacity(num_rows); + let mut values = Vec::::with_capacity(num_rows); if let Some(validity) = &validity { let validity_iter = validity.iter(); @@ -292,30 +302,51 @@ fn deserialize_i32( iter = IntIter::new(chunks.next()?.unwrap()); iter.next().transpose()?.unwrap() }; - values.push(item as i32); + let item = item + .try_into() + .map_err(|_| Error::ExternalFormat("value uncastable".to_string()))?; + values.push(item); } else { - values.push(0); + values.push(T::default()); } } } else { while let Some(chunk) = chunks.next()? { decode::SignedRleV2Iter::new(chunk).try_for_each(|run| { - run.map(|run| match run { + run.map_err(Error::from).and_then(|run| match run { decode::SignedRleV2Run::Direct(values_iter) => { - values.extend(values_iter.map(|x| x as i32)) + for item in values_iter { + let item = item.try_into().map_err(|_| { + Error::ExternalFormat("value uncastable".to_string()) + })?; + values.push(item); + } + Ok(()) } decode::SignedRleV2Run::Delta(values_iter) => { - values.extend(values_iter.map(|x| x as i32)) + for item in values_iter { + let item = item.try_into().map_err(|_| { + Error::ExternalFormat("value uncastable".to_string()) + })?; + values.push(item); + } + Ok(()) } decode::SignedRleV2Run::ShortRepeat(values_iter) => { - values.extend(values_iter.map(|x| x as i32)) + for item in values_iter { + let item = item.try_into().map_err(|_| { + Error::ExternalFormat("value uncastable".to_string()) + })?; + values.push(item); + } + Ok(()) } }) })?; } } - Int32Array::try_new(data_type, values.into(), validity) + PrimitiveArray::try_new(data_type, values.into(), validity) } /// Deserializes column `column` from `stripe`, assumed @@ -327,9 +358,12 @@ pub fn deserialize( ) -> Result, Error> { match data_type { DataType::Boolean => deserialize_bool(data_type, stripe, column).map(|x| x.boxed()), - DataType::Int32 => deserialize_i32(data_type, stripe, column).map(|x| x.boxed()), + DataType::Int8 => deserialize_int::(data_type, stripe, column).map(|x| x.boxed()), + DataType::Int16 => deserialize_int::(data_type, stripe, column).map(|x| x.boxed()), + DataType::Int32 => deserialize_int::(data_type, stripe, column).map(|x| x.boxed()), DataType::Int64 => deserialize_i64(data_type, stripe, column).map(|x| x.boxed()), - DataType::Float32 => deserialize_f32(data_type, stripe, column).map(|x| x.boxed()), + DataType::Float32 => deserialize_float::(data_type, stripe, column).map(|x| x.boxed()), + DataType::Float64 => deserialize_float::(data_type, stripe, column).map(|x| x.boxed()), dt => return Err(Error::nyi(format!("Reading {dt:?} from ORC"))), } } diff --git a/src/types/native.rs b/src/types/native.rs index 7e7711589b6..fc100017a76 100644 --- a/src/types/native.rs +++ b/src/types/native.rs @@ -28,7 +28,8 @@ pub trait NativeType: + std::ops::Index + std::ops::IndexMut + for<'a> TryFrom<&'a [u8]> - + std::fmt::Debug; + + std::fmt::Debug + + Default; /// To bytes in little endian fn to_le_bytes(&self) -> Self::Bytes; @@ -36,6 +37,9 @@ pub trait NativeType: /// To bytes in big endian fn to_be_bytes(&self) -> Self::Bytes; + /// From bytes in little endian + fn from_le_bytes(bytes: Self::Bytes) -> Self; + /// From bytes in big endian fn from_be_bytes(bytes: Self::Bytes) -> Self; } @@ -56,6 +60,11 @@ macro_rules! native_type { Self::to_be_bytes(*self) } + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self::from_le_bytes(bytes) + } + #[inline] fn from_be_bytes(bytes: Self::Bytes) -> Self { Self::from_be_bytes(bytes) @@ -137,6 +146,21 @@ impl NativeType for days_ms { result } + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut days = [0; 4]; + days[0] = bytes[0]; + days[1] = bytes[1]; + days[2] = bytes[2]; + days[3] = bytes[3]; + let mut ms = [0; 4]; + ms[0] = bytes[4]; + ms[1] = bytes[5]; + ms[2] = bytes[6]; + ms[3] = bytes[7]; + Self(i32::from_le_bytes(days), i32::from_le_bytes(ms)) + } + #[inline] fn from_be_bytes(bytes: Self::Bytes) -> Self { let mut days = [0; 4]; @@ -228,6 +252,29 @@ impl NativeType for months_days_ns { result } + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut months = [0; 4]; + months[0] = bytes[0]; + months[1] = bytes[1]; + months[2] = bytes[2]; + months[3] = bytes[3]; + let mut days = [0; 4]; + days[0] = bytes[4]; + days[1] = bytes[5]; + days[2] = bytes[6]; + days[3] = bytes[7]; + let mut ns = [0; 8]; + (0..8).for_each(|i| { + ns[i] = bytes[8 + i]; + }); + Self( + i32::from_le_bytes(months), + i32::from_le_bytes(days), + i64::from_le_bytes(ns), + ) + } + #[inline] fn from_be_bytes(bytes: Self::Bytes) -> Self { let mut months = [0; 4]; @@ -446,6 +493,11 @@ impl NativeType for f16 { fn from_be_bytes(bytes: Self::Bytes) -> Self { Self(u16::from_be_bytes(bytes)) } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_le_bytes(bytes)) + } } #[cfg(test)] diff --git a/tests/it/io/orc/read.rs b/tests/it/io/orc/read.rs index 7347d01e65d..f6dd8fdeba1 100644 --- a/tests/it/io/orc/read.rs +++ b/tests/it/io/orc/read.rs @@ -9,7 +9,7 @@ fn infer() -> Result<(), Error> { let (_, footer, _) = format::read::read_metadata(&mut reader)?; let schema = read::infer_schema(&footer)?; - assert_eq!(schema.fields.len(), 6); + assert_eq!(schema.fields.len(), 8); Ok(()) } @@ -31,6 +31,24 @@ fn float32() -> Result<(), Error> { Ok(()) } +#[test] +fn float64() -> Result<(), Error> { + let mut reader = std::fs::File::open("fixtures/pyorc/test.orc").unwrap(); + let (ps, footer, _) = format::read::read_metadata(&mut reader)?; + let stripe = read::read_stripe(&mut reader, footer.stripes[0].clone(), ps.compression())?; + + assert_eq!( + read::deserialize(DataType::Float64, &stripe, 7)?, + Float64Array::from([Some(1.0), Some(2.0), None, Some(4.0), Some(5.0)]).boxed() + ); + + assert_eq!( + read::deserialize(DataType::Float64, &stripe, 8)?, + Float64Array::from([Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]).boxed() + ); + Ok(()) +} + #[test] fn boolean() -> Result<(), Error> { let mut reader = std::fs::File::open("fixtures/pyorc/test.orc").unwrap(); diff --git a/tests/it/io/orc/write.py b/tests/it/io/orc/write.py index 633d28b8dd9..3416f1d21f4 100644 --- a/tests/it/io/orc/write.py +++ b/tests/it/io/orc/write.py @@ -10,6 +10,8 @@ "bool_required": [True, False, True, True, False], "int_nulable": [5, -5, None, 5, 5], "int_required": [5, -5, 1, 5, 5], + "double_nulable": [1.0, 2.0, None, 4.0, 5.0], + "double_required": [1.0, 2.0, 3.0, 4.0, 5.0], } def infer_schema(data): @@ -26,6 +28,8 @@ def infer_schema(data): dt = "string" else: raise NotImplementedError + if key.startswith("double"): + dt = "double" schema += key + ":" + dt + "," schema = schema[:-1] + ">"