From b82d43a5d6e33f59f197d59ceacf4fe0a88565b1 Mon Sep 17 00:00:00 2001 From: taichong Date: Wed, 1 Mar 2023 19:35:51 +0800 Subject: [PATCH] support decimal256(76,0) --- parquet_integration/write_parquet.py | 4 +++ src/io/parquet/read/deserialize/simple.rs | 2 +- src/io/parquet/read/mod.rs | 13 ++++--- src/io/parquet/read/row_group.rs | 3 -- src/io/parquet/read/statistics/fixlen.rs | 8 ++--- src/io/parquet/read/statistics/mod.rs | 2 +- src/io/parquet/write/mod.rs | 4 +-- src/io/parquet/write/schema.rs | 3 +- src/types/native.rs | 10 +++--- tests/it/io/parquet/mod.rs | 24 +++++++++++++ tests/it/io/parquet/read.rs | 25 +++++++++++++ tests/it/io/parquet/write.rs | 44 +++++++++++++++++++++++ 12 files changed, 120 insertions(+), 22 deletions(-) diff --git a/parquet_integration/write_parquet.py b/parquet_integration/write_parquet.py index d6e6dfe1b51..acfd819d57c 100644 --- a/parquet_integration/write_parquet.py +++ b/parquet_integration/write_parquet.py @@ -36,6 +36,7 @@ def case_basic_nullable() -> Tuple[dict, pa.Schema, str]: pa.field("decimal256_18", pa.decimal256(18, 0)), pa.field("decimal256_26", pa.decimal256(26, 0)), pa.field("decimal256_39", pa.decimal256(39, 0)), + pa.field("decimal256_76", pa.decimal256(76, 0)), pa.field("timestamp_us", pa.timestamp("us")), pa.field("timestamp_s", pa.timestamp("s")), pa.field("emoji", pa.utf8()), @@ -58,6 +59,7 @@ def case_basic_nullable() -> Tuple[dict, pa.Schema, str]: "decimal256_18": decimal, "decimal256_26": decimal, "decimal256_39": decimal, + "decimal256_76": decimal, "timestamp_us": int64, "timestamp_s": int64, "emoji": emoji, @@ -94,6 +96,7 @@ def case_basic_required() -> Tuple[dict, pa.Schema, str]: pa.field("decimal256_18", pa.decimal256(18, 0), nullable=False), pa.field("decimal256_26", pa.decimal256(26, 0), nullable=False), pa.field("decimal256_39", pa.decimal256(39, 0), nullable=False), + pa.field("decimal256_76", pa.decimal256(76, 0), nullable=False), ] schema = pa.schema(fields) @@ -112,6 +115,7 @@ def case_basic_required() -> Tuple[dict, pa.Schema, str]: "decimal256_18": decimal, "decimal256_26": decimal, "decimal256_39": decimal, + "decimal256_76": decimal, }, schema, f"basic_required_10.parquet", diff --git a/src/io/parquet/read/deserialize/simple.rs b/src/io/parquet/read/deserialize/simple.rs index 0b1e3926239..b4b614980e8 100644 --- a/src/io/parquet/read/deserialize/simple.rs +++ b/src/io/parquet/read/deserialize/simple.rs @@ -285,7 +285,7 @@ pub fn page_iter_to_arrays<'a, I: Pages + 'a>( let values = array .values() .chunks_exact(n) - .map(|value: &[u8]| super::super::convert_i256(value, n)) + .map(super::super::convert_i256) .collect::>(); let validity = array.validity().cloned(); diff --git a/src/io/parquet/read/mod.rs b/src/io/parquet/read/mod.rs index b9d60c68cf0..709eb221f01 100644 --- a/src/io/parquet/read/mod.rs +++ b/src/io/parquet/read/mod.rs @@ -83,9 +83,14 @@ fn convert_i128(value: &[u8], n: usize) -> i128 { i128::from_be_bytes(bytes) >> (8 * (16 - n)) } -fn convert_i256(value: &[u8], n: usize) -> i256 { +fn convert_i256(value: &[u8]) -> i256 { let mut bytes = [0u8; 32]; - bytes[..n].copy_from_slice(value); - - i256(i256::from_be_bytes(bytes).0 >> (8 * (32 - n))) + let mut neg_bytes = [255u8; 32]; + if value[0] >= 128 { + neg_bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(neg_bytes) + } else { + bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(bytes) + } } diff --git a/src/io/parquet/read/row_group.rs b/src/io/parquet/read/row_group.rs index e475010e606..176c6e83182 100644 --- a/src/io/parquet/read/row_group.rs +++ b/src/io/parquet/read/row_group.rs @@ -225,9 +225,6 @@ pub fn to_deserializer<'a>( (columns, types) } else { - for (meta, chunk) in columns.clone() { - println!("the meta is {:?},\nthe chunk is {:?}", meta, chunk); - } let (columns, types): (Vec<_>, Vec<_>) = columns .into_iter() .map(|(column_meta, chunk)| { diff --git a/src/io/parquet/read/statistics/fixlen.rs b/src/io/parquet/read/statistics/fixlen.rs index 1362d22bb97..04d881da3b1 100644 --- a/src/io/parquet/read/statistics/fixlen.rs +++ b/src/io/parquet/read/statistics/fixlen.rs @@ -3,9 +3,10 @@ use parquet2::statistics::{FixedLenStatistics, Statistics as ParquetStatistics}; use crate::array::*; use crate::error::Result; +use crate::io::parquet::read::convert_i256; use crate::types::{days_ms, i256}; -use super::super::{convert_days_ms, convert_i128, convert_i256}; +use super::super::{convert_days_ms, convert_i128}; pub(super) fn push_i128( from: Option<&dyn ParquetStatistics>, @@ -61,7 +62,6 @@ pub(super) fn push_i256_with_i128( pub(super) fn push_i256( from: Option<&dyn ParquetStatistics>, - n: usize, min: &mut dyn MutableArray, max: &mut dyn MutableArray, ) -> Result<()> { @@ -75,8 +75,8 @@ pub(super) fn push_i256( .unwrap(); let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); - min.push(from.and_then(|s| s.min_value.as_deref().map(|x| convert_i256(x, n)))); - max.push(from.and_then(|s| s.max_value.as_deref().map(|x| convert_i256(x, n)))); + min.push(from.and_then(|s| s.min_value.as_deref().map(convert_i256))); + max.push(from.and_then(|s| s.max_value.as_deref().map(convert_i256))); Ok(()) } diff --git a/src/io/parquet/read/statistics/mod.rs b/src/io/parquet/read/statistics/mod.rs index 27f137f0f34..f3c1ed9e8de 100644 --- a/src/io/parquet/read/statistics/mod.rs +++ b/src/io/parquet/read/statistics/mod.rs @@ -530,7 +530,7 @@ fn push( ParquetPhysicalType::FixedLenByteArray(n) if *n > 32 => Err(Error::NotYetImplemented( format!("Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}"), )), - ParquetPhysicalType::FixedLenByteArray(n) => fixlen::push_i256(from, *n, min, max), + ParquetPhysicalType::FixedLenByteArray(_) => fixlen::push_i256(from, min, max), _ => unreachable!(), }, Binary => binary::push::(from, min, max), diff --git a/src/io/parquet/write/mod.rs b/src/io/parquet/write/mod.rs index 34c07a535b6..fad4ff6ad90 100644 --- a/src/io/parquet/write/mod.rs +++ b/src/io/parquet/write/mod.rs @@ -518,7 +518,7 @@ pub fn array_to_page_simple( ); fixed_len_bytes::array_to_page(&array, options, type_, statistics) } else { - let size = decimal_length_from_precision(precision); + let size = 32; let array = array .as_any() .downcast_ref::>() @@ -532,7 +532,7 @@ pub fn array_to_page_simple( }; let mut values = Vec::::with_capacity(size * array.len()); array.values().iter().for_each(|x| { - let bytes = &x.to_be_bytes()[32 - size..]; + let bytes = &x.to_be_bytes(); values.extend_from_slice(bytes) }); let array = FixedSizeBinaryArray::new( diff --git a/src/io/parquet/write/schema.rs b/src/io/parquet/write/schema.rs index 68e7786e49a..f4af963a8ed 100644 --- a/src/io/parquet/write/schema.rs +++ b/src/io/parquet/write/schema.rs @@ -329,10 +329,9 @@ pub fn to_parquet_type(field: &Field) -> Result { None, )?) } else { - let len = decimal_length_from_precision(precision); Ok(ParquetType::try_from_primitive( name, - PhysicalType::FixedLenByteArray(len), + PhysicalType::FixedLenByteArray(32), repetition, None, None, diff --git a/src/types/native.rs b/src/types/native.rs index 40b79bcf1a2..f66ceb8403e 100644 --- a/src/types/native.rs +++ b/src/types/native.rs @@ -577,14 +577,14 @@ impl NativeType for i256 { let mut bytes = [0u8; 32]; let (a, b) = self.0.into_words(); - let b = b.to_be_bytes(); + let a = a.to_be_bytes(); (0..16).for_each(|i| { - bytes[i] = b[i]; + bytes[i] = a[i]; }); - let a = a.to_be_bytes(); + let b = b.to_be_bytes(); (0..16).for_each(|i| { - bytes[i + 16] = a[i]; + bytes[i + 16] = b[i]; }); bytes @@ -592,7 +592,7 @@ impl NativeType for i256 { #[inline] fn from_be_bytes(bytes: Self::Bytes) -> Self { - let (b, a) = bytes.split_at(16); + let (a, b) = bytes.split_at(16); let a: [u8; 16] = a.try_into().unwrap(); let b: [u8; 16] = b.try_into().unwrap(); let a = i128::from_be_bytes(a); diff --git a/tests/it/io/parquet/mod.rs b/tests/it/io/parquet/mod.rs index 83e16f1951e..60b9f44cae4 100644 --- a/tests/it/io/parquet/mod.rs +++ b/tests/it/io/parquet/mod.rs @@ -558,6 +558,13 @@ pub fn pyarrow_nullable(column: &str) -> Box { .collect::>(); Box::new(PrimitiveArray::::from(values).to(DataType::Decimal256(39, 0))) } + "decimal256_76" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| i256(x.as_i256()))) + .collect::>(); + Box::new(PrimitiveArray::::from(values).to(DataType::Decimal256(76, 0))) + } "timestamp_us" => Box::new( PrimitiveArray::::from(i64_values) .to(DataType::Timestamp(TimeUnit::Microsecond, None)), @@ -684,6 +691,16 @@ pub fn pyarrow_nullable_statistics(column: &str) -> Statistics { Int256Array::from_slice([i256(9.as_i256())]).to(DataType::Decimal256(39, 0)), ), }, + "decimal256_76" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new( + Int256Array::from_slice([i256(-(256.as_i256()))]).to(DataType::Decimal256(76, 0)), + ), + max_value: Box::new( + Int256Array::from_slice([i256(9.as_i256())]).to(DataType::Decimal256(76, 0)), + ), + }, "timestamp_us" => Statistics { distinct_count: UInt64Array::from([None]).boxed(), null_count: UInt64Array::from([Some(3)]).boxed(), @@ -792,6 +809,13 @@ pub fn pyarrow_required(column: &str) -> Box { .collect::>(); Box::new(PrimitiveArray::::from(values).to(DataType::Decimal256(39, 0))) } + "decimal256_76" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| i256(x.as_i256()))) + .collect::>(); + Box::new(PrimitiveArray::::from(values).to(DataType::Decimal256(76, 0))) + } _ => unreachable!(), } } diff --git a/tests/it/io/parquet/read.rs b/tests/it/io/parquet/read.rs index 58e4c685bc2..786bdf6f96d 100644 --- a/tests/it/io/parquet/read.rs +++ b/tests/it/io/parquet/read.rs @@ -441,6 +441,16 @@ fn v1_decimal256_39_required() -> Result<()> { test_pyarrow_integration("decimal256_39", 1, "basic", false, true, None) } +#[test] +fn v1_decimal256_76_nullable() -> Result<()> { + test_pyarrow_integration("decimal256_76", 1, "basic", false, false, None) +} + +#[test] +fn v1_decimal256_76_required() -> Result<()> { + test_pyarrow_integration("decimal256_76", 1, "basic", false, true, None) +} + #[test] fn v2_decimal_9_nullable() -> Result<()> { test_pyarrow_integration("decimal_9", 2, "basic", false, false, None) @@ -496,6 +506,11 @@ fn v2_decimal256_39_nullable() -> Result<()> { test_pyarrow_integration("decimal256_39", 2, "basic", false, false, None) } +#[test] +fn v2_decimal256_76_nullable() -> Result<()> { + test_pyarrow_integration("decimal256_76", 2, "basic", false, false, None) +} + #[test] fn v1_timestamp_us_nullable() -> Result<()> { test_pyarrow_integration("timestamp_us", 1, "basic", false, false, None) @@ -566,6 +581,16 @@ fn v2_decimal256_39_required_dict() -> Result<()> { test_pyarrow_integration("decimal256_39", 2, "basic", true, true, None) } +#[test] +fn v2_decimal256_76_required() -> Result<()> { + test_pyarrow_integration("decimal256_76", 2, "basic", false, true, None) +} + +#[test] +fn v2_decimal256_76_required_dict() -> Result<()> { + test_pyarrow_integration("decimal256_76", 2, "basic", true, true, None) +} + #[test] fn v1_struct_required_optional() -> Result<()> { test_pyarrow_integration("struct", 1, "struct", false, false, None) diff --git a/tests/it/io/parquet/write.rs b/tests/it/io/parquet/write.rs index 0c2b8e05d6e..e874d706d6c 100644 --- a/tests/it/io/parquet/write.rs +++ b/tests/it/io/parquet/write.rs @@ -627,6 +627,28 @@ fn decimal256_39_required_v1() -> Result<()> { ) } +#[test] +fn decimal256_76_optional_v1() -> Result<()> { + round_trip( + "decimal256_76", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn decimal256_76_required_v1() -> Result<()> { + round_trip( + "decimal256_76", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + #[test] fn decimal_9_optional_v2() -> Result<()> { round_trip( @@ -781,6 +803,28 @@ fn decimal256_39_required_v2() -> Result<()> { ) } +#[test] +fn decimal256_76_optional_v2() -> Result<()> { + round_trip( + "decimal256_76", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn decimal256_76_required_v2() -> Result<()> { + round_trip( + "decimal256_76", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + #[test] fn struct_v1() -> Result<()> { round_trip(