From 2324a8f2f49f9a891f1f917c5b46a2a83fc2ecc0 Mon Sep 17 00:00:00 2001 From: Tommy Tran Date: Wed, 6 Oct 2021 18:26:07 +0700 Subject: [PATCH] Added read `Decimal` from parquet (#489) Co-authored-by: thanhtm1 --- parquet_integration/write_parquet.py | 16 +++ src/io/parquet/read/mod.rs | 53 +++++-- src/io/parquet/read/schema/convert.rs | 3 + src/io/parquet/read/schema/metadata.rs | 5 +- src/io/parquet/read/statistics/fixlen.rs | 105 ++++++++++++++ src/io/parquet/read/statistics/mod.rs | 6 + src/io/parquet/read/statistics/primitive.rs | 2 + src/io/parquet/write/fixed_len_bytes.rs | 43 +++++- tests/it/io/parquet/mod.rs | 94 +++++++++++++ tests/it/io/parquet/read.rs | 61 ++++++++ tests/it/io/parquet/write.rs | 145 ++++++++++++++++++++ 11 files changed, 518 insertions(+), 15 deletions(-) create mode 100644 src/io/parquet/read/statistics/fixlen.rs diff --git a/parquet_integration/write_parquet.py b/parquet_integration/write_parquet.py index d941cb4e6df..0d9e556216d 100644 --- a/parquet_integration/write_parquet.py +++ b/parquet_integration/write_parquet.py @@ -1,6 +1,7 @@ import pyarrow as pa import pyarrow.parquet import os +from decimal import Decimal PYARROW_PATH = "fixtures/pyarrow3" @@ -11,6 +12,7 @@ def case_basic_nullable(size=1): string = ["Hello", None, "aa", "", None, "abc", None, None, "def", "aaa"] boolean = [True, None, False, False, None, True, None, None, True, True] string_large = ["ABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCD😃🌚🕳👊"] * 10 + decimal = [Decimal(e) if e is not None else None for e in int64] fields = [ pa.field("int64", pa.int64()), @@ -20,6 +22,10 @@ def case_basic_nullable(size=1): pa.field("date", pa.timestamp("ms")), pa.field("uint32", pa.uint32()), pa.field("string_large", pa.utf8()), + # decimal testing + pa.field("decimal_9", pa.decimal128(9,0)), + pa.field("decimal_18", pa.decimal128(18,0)), + pa.field("decimal_26", pa.decimal128(26,0)), ] schema = pa.schema(fields) @@ -32,6 +38,9 @@ def 
case_basic_nullable(size=1): "date": int64 * size, "uint32": int64 * size, "string_large": string_large * size, + "decimal_9": decimal * size, + "decimal_18": decimal * size, + "decimal_26": decimal * size, }, schema, f"basic_nullable_{size*10}.parquet", @@ -43,6 +52,7 @@ def case_basic_required(size=1): float64 = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] string = ["Hello", "bbb", "aa", "", "bbb", "abc", "bbb", "bbb", "def", "aaa"] boolean = [True, True, False, False, False, True, True, True, True, True] + decimal = [Decimal(e) for e in int64] fields = [ pa.field("int64", pa.int64(), nullable=False), @@ -57,6 +67,9 @@ def case_basic_required(size=1): nullable=False, ), pa.field("uint32", pa.uint32(), nullable=False), + pa.field("decimal_9", pa.decimal128(9,0), nullable=False), + pa.field("decimal_18", pa.decimal128(18,0), nullable=False), + pa.field("decimal_26", pa.decimal128(26,0), nullable=False), ] schema = pa.schema(fields) @@ -68,6 +81,9 @@ def case_basic_required(size=1): "bool": boolean * size, "date": int64 * size, "uint32": int64 * size, + "decimal_9": decimal * size, + "decimal_18": decimal * size, + "decimal_26": decimal * size, }, schema, f"basic_required_{size*10}.parquet", diff --git a/src/io/parquet/read/mod.rs b/src/io/parquet/read/mod.rs index 90ea870e938..f05f2669936 100644 --- a/src/io/parquet/read/mod.rs +++ b/src/io/parquet/read/mod.rs @@ -1,8 +1,5 @@ //! APIs to read from Parquet format. 
-use std::{ - io::{Read, Seek}, - sync::Arc, -}; +use std::{convert::TryInto, io::{Read, Seek}, sync::Arc}; use futures::{AsyncRead, AsyncSeek, Stream}; pub use parquet2::{ @@ -21,11 +18,7 @@ pub use parquet2::{ types::int96_to_i64_ns, }; -use crate::{ - array::{Array, DictionaryKey}, - datatypes::{DataType, IntervalUnit, TimeUnit}, - error::{ArrowError, Result}, -}; +use crate::{array::{Array, DictionaryKey, PrimitiveArray}, datatypes::{DataType, IntervalUnit, TimeUnit}, error::{ArrowError, Result}}; mod binary; mod boolean; @@ -211,7 +204,47 @@ pub fn page_iter_to_array< FixedSizeBinary(_) => Ok(Box::new(fixed_size_binary::iter_to_array( iter, data_type, metadata, )?)), - + Decimal(_, _) => match metadata.descriptor().type_() { + ParquetType::PrimitiveType { physical_type, ..} => match physical_type{ + PhysicalType::Int32 => primitive::iter_to_array( + iter, + metadata, + data_type, + |x: i32| x as i128, + ), + PhysicalType::Int64 => primitive::iter_to_array( + iter, + metadata, + data_type, + |x: i64| x as i128, + ), + PhysicalType::FixedLenByteArray(n) => { + if *n > 16 { + Err(ArrowError::NotYetImplemented(format!( + "Can't decode Decimal128 type from Fixed Size Byte Array of len {:?}", + n + ))) + } else { + let paddings = (0..(16-*n)).map(|_| 0u8).collect::<Vec<_>>(); + fixed_size_binary::iter_to_array(iter, DataType::FixedSizeBinary(*n), metadata) + .map(|e|{ + let a = e.into_iter().map(|v| + v.and_then(|v1| { + [&paddings, v1].concat().try_into().map( + |pad16| i128::from_be_bytes(pad16) + ).ok() + } + ) + ).collect::<Vec<_>>(); + Box::new(PrimitiveArray::<i128>::from(a).to(data_type)) as Box<dyn Array> + } + ) + } + }, + _ => unreachable!() + }, + _ => unreachable!() + }, List(ref inner) => match inner.data_type() { UInt8 => primitive::iter_to_array_nested(iter, metadata, data_type, |x: i32| x as u8), UInt16 => primitive::iter_to_array_nested(iter, metadata, data_type, |x: i32| x as u16), diff --git a/src/io/parquet/read/schema/convert.rs b/src/io/parquet/read/schema/convert.rs index 
1c92683b161..14c7c7c2edc 100644 --- a/src/io/parquet/read/schema/convert.rs +++ b/src/io/parquet/read/schema/convert.rs @@ -167,6 +167,9 @@ pub fn from_int64( ParquetTimeUnit::MICROS(_) => DataType::Time64(TimeUnit::Microsecond), ParquetTimeUnit::NANOS(_) => DataType::Time64(TimeUnit::Nanosecond), }, + (Some(PrimitiveConvertedType::Decimal(precision,scale)), _) => { + DataType::Decimal(*precision as usize, *scale as usize) + } (c, l) => { return Err(ArrowError::NotYetImplemented(format!( "The conversion of (Int64, {:?}, {:?}) to arrow still not implemented", diff --git a/src/io/parquet/read/schema/metadata.rs b/src/io/parquet/read/schema/metadata.rs index 8b511612afe..54508ca05ea 100644 --- a/src/io/parquet/read/schema/metadata.rs +++ b/src/io/parquet/read/schema/metadata.rs @@ -120,7 +120,10 @@ mod tests { "bool", "date", "uint32", - "string_large" + "string_large", + "decimal_9", + "decimal_18", + "decimal_26" ] ); Ok(()) diff --git a/src/io/parquet/read/statistics/fixlen.rs b/src/io/parquet/read/statistics/fixlen.rs new file mode 100644 index 00000000000..f550b3b63ad --- /dev/null +++ b/src/io/parquet/read/statistics/fixlen.rs @@ -0,0 +1,105 @@ +use std::convert::{TryFrom, TryInto}; + +use super::super::schema; +use super::primitive::PrimitiveStatistics; +use crate::datatypes::DataType; +use crate::error::{ArrowError, Result}; +use parquet2::schema::types::ParquetType; +use parquet2::{ + schema::types::PhysicalType, + statistics::{ + FixedLenStatistics as ParquetFixedLenStatistics, Statistics as ParquetStatistics, + }, +}; + +use super::Statistics; + +#[derive(Debug, Clone, PartialEq)] +pub struct FixedLenStatistics { + pub null_count: Option<i64>, + pub distinct_count: Option<i64>, + pub min_value: Option<Vec<u8>>, + pub max_value: Option<Vec<u8>>, + pub data_type: DataType, +} + +impl Statistics for FixedLenStatistics { + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +impl From<&ParquetFixedLenStatistics> for FixedLenStatistics { + fn from(stats: 
&ParquetFixedLenStatistics) -> Self { + let byte_lens = match stats.physical_type() { + PhysicalType::FixedLenByteArray(size) => *size, + _ => unreachable!(), + }; + Self { + null_count: stats.null_count, + distinct_count: stats.distinct_count, + min_value: stats.min_value.clone(), + max_value: stats.max_value.clone(), + data_type: DataType::FixedSizeBinary(byte_lens), + } + } +} + +impl TryFrom<(&ParquetFixedLenStatistics, DataType)> for PrimitiveStatistics<i128> { + type Error = ArrowError; + fn try_from((stats, data_type): (&ParquetFixedLenStatistics, DataType)) -> Result<Self> { + let byte_lens = match stats.physical_type() { + PhysicalType::FixedLenByteArray(size) => *size, + _ => unreachable!(), + }; + if byte_lens > 16 { + Err(ArrowError::Other(format!( + "Can't deserialize i128 from Fixed Len Byte array with lengtg {:?}", + byte_lens + ))) + } else { + let paddings = (0..(16 - byte_lens)).map(|_| 0u8).collect::<Vec<_>>(); + let max_value = stats.max_value.as_ref().and_then(|value| { + [paddings.as_slice(), value] + .concat() + .try_into() + .map(|v| i128::from_be_bytes(v)) + .ok() + }); + + let min_value = stats.min_value.as_ref().and_then(|value| { + [paddings.as_slice(), value] + .concat() + .try_into() + .map(|v| i128::from_be_bytes(v)) + .ok() + }); + Ok(Self { + data_type, + null_count: stats.null_count, + distinct_count: stats.distinct_count, + max_value, + min_value, + }) + } + } +} + +pub(super) fn statistics_from_fix_len( + stats: &ParquetFixedLenStatistics, + type_: &ParquetType, +) -> Result<Box<dyn Statistics>> { + let data_type = schema::to_data_type(type_)?.unwrap(); + + use DataType::*; + Ok(match data_type { + Decimal(_, _) => Box::new(PrimitiveStatistics::<i128>::try_from((stats, data_type))?), + FixedSizeBinary(_) => Box::new(FixedLenStatistics::from(stats)), + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Can't read {:?} from parquet", + other + ))) + } + }) +} diff --git a/src/io/parquet/read/statistics/mod.rs b/src/io/parquet/read/statistics/mod.rs index 
27b2cc2790a..3f554f82958 100644 --- a/src/io/parquet/read/statistics/mod.rs +++ b/src/io/parquet/read/statistics/mod.rs @@ -13,6 +13,8 @@ mod binary; pub use binary::*; mod boolean; pub use boolean::*; +mod fixlen; +pub use fixlen::*; /// Trait representing a deserialized parquet statistics into arrow. pub trait Statistics: std::fmt::Debug { @@ -70,6 +72,10 @@ pub fn deserialize_statistics(stats: &dyn ParquetStatistics) -> Result{ + let stats = stats.as_any().downcast_ref().unwrap(); + fixlen::statistics_from_fix_len(stats, stats.descriptor.type_()) + } _ => Err(ArrowError::NotYetImplemented( "Reading Fixed-len array statistics is not yet supported".to_string(), )), diff --git a/src/io/parquet/read/statistics/primitive.rs b/src/io/parquet/read/statistics/primitive.rs index d669425e507..f807754ea0d 100644 --- a/src/io/parquet/read/statistics/primitive.rs +++ b/src/io/parquet/read/statistics/primitive.rs @@ -54,6 +54,7 @@ pub(super) fn statistics_from_i32( UInt32 => Box::new(PrimitiveStatistics::<u32>::from((stats, data_type))), Int8 => Box::new(PrimitiveStatistics::<i8>::from((stats, data_type))), Int16 => Box::new(PrimitiveStatistics::<i16>::from((stats, data_type))), + Decimal(_, _) => Box::new(PrimitiveStatistics::<i128>::from((stats, data_type))), _ => Box::new(PrimitiveStatistics::<i32>::from((stats, data_type))), }) } @@ -69,6 +70,7 @@ pub(super) fn statistics_from_i64( UInt64 => { Box::new(PrimitiveStatistics::<u64>::from((stats, data_type))) as Box<dyn Statistics> } + Decimal(_, _) => Box::new(PrimitiveStatistics::<i128>::from((stats, data_type))), _ => Box::new(PrimitiveStatistics::<i64>::from((stats, data_type))), }) } diff --git a/src/io/parquet/write/fixed_len_bytes.rs b/src/io/parquet/write/fixed_len_bytes.rs index 9d2cc2c188e..01069f6d0b0 100644 --- a/src/io/parquet/write/fixed_len_bytes.rs +++ b/src/io/parquet/write/fixed_len_bytes.rs @@ -1,9 +1,13 @@ use parquet2::{ - compression::create_codec, encoding::Encoding, metadata::ColumnDescriptor, - page::CompressedDataPage, write::WriteOptions, + 
compression::create_codec, + encoding::Encoding, + metadata::ColumnDescriptor, + page::CompressedDataPage, + statistics::{serialize_statistics, deserialize_statistics, ParquetStatistics}, + write::WriteOptions, }; -use super::utils; +use super::{binary::ord_binary, utils}; use crate::{ array::{Array, FixedSizeBinaryArray}, error::Result, @@ -54,6 +58,12 @@ pub fn array_to_page( buffer }; + let statistics = if options.write_statistics { + build_statistics(array, descriptor.clone()) + } else { + None + }; + utils::build_plain_page( buffer, array.len(), @@ -61,9 +71,34 @@ pub fn array_to_page( uncompressed_page_size, 0, definition_levels_byte_length, - None, + statistics, descriptor, options, Encoding::Plain, ) } + +pub(super) fn build_statistics( + array: &FixedSizeBinaryArray, + descriptor: ColumnDescriptor, +) -> Option<ParquetStatistics> { + let pq_statistics = &ParquetStatistics { + max: None, + min: None, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + }; + deserialize_statistics(pq_statistics,descriptor).map( + |e| serialize_statistics(&*e) + ).ok() +} diff --git a/tests/it/io/parquet/mod.rs b/tests/it/io/parquet/mod.rs index a84219316dd..97a2445a0ba 100644 --- a/tests/it/io/parquet/mod.rs +++ b/tests/it/io/parquet/mod.rs @@ -240,6 +240,30 @@ pub fn pyarrow_nullable(column: usize) -> Box<dyn Array> { let values = Arc::new(PrimitiveArray::<i32>::from_slice([10, 200])); Box::new(DictionaryArray::<i32>::from_data(keys, values)) } + // decimal 9 + 7 => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::<Vec<_>>(); + Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(9, 0))) + } + // decimal 18 + 8 => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::<Vec<_>>(); + 
Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(18, 0))) + } + // decimal 26 + 9 => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::<Vec<_>>(); + Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(26, 0))) + } _ => unreachable!(), } } @@ -289,6 +313,28 @@ pub fn pyarrow_nullable_statistics(column: usize) -> Option<Box<dyn Statistics>> max_value: Some(9), }), 6 => return None, + // Decimal statistics + 7 => Box::new(PrimitiveStatistics::<i128> { + distinct_count: None, + null_count: Some(3), + min_value: Some(0i128), + max_value: Some(9i128), + data_type: DataType::Decimal(9, 0), + }), + 8 => Box::new(PrimitiveStatistics::<i128> { + distinct_count: None, + null_count: Some(3), + min_value: Some(0i128), + max_value: Some(9i128), + data_type: DataType::Decimal(18, 0), + }), + 9 => Box::new(PrimitiveStatistics::<i128> { + distinct_count: None, + null_count: Some(3), + min_value: Some(0i128), + max_value: Some(9i128), + data_type: DataType::Decimal(26, 0), + }), _ => unreachable!(), }) } @@ -316,6 +362,30 @@ pub fn pyarrow_required(column: usize) -> Box<dyn Array> { 2 => Box::new(Utf8Array::<i32>::from_slice(&[ "Hello", "bbb", "aa", "", "bbb", "abc", "bbb", "bbb", "def", "aaa", ])), + // decimal 9 + 6 => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::<Vec<_>>(); + Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(9, 0))) + } + // decimal 18 + 7 => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::<Vec<_>>(); + Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(18, 0))) + } + // decimal 26 + 8 => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::<Vec<_>>(); + Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(26, 0))) + } _ => unreachable!(), } } @@ -341,6 +411,30 @@ pub fn pyarrow_required_statistics(column: usize) -> Option<Box<dyn Statistics>> min_value: Some("".to_string()), max_value: Some("def".to_string()), }), + // decimal_9 + 6 => Box::new(PrimitiveStatistics::<i128> { + 
distinct_count: None, + null_count: Some(0), + min_value: Some(0i128), + max_value: Some(9i128), + data_type: DataType::Decimal(9, 0), + }), + // decimal_18 + 7 => Box::new(PrimitiveStatistics::<i128> { + distinct_count: None, + null_count: Some(0), + min_value: Some(0i128), + max_value: Some(9i128), + data_type: DataType::Decimal(18, 0), + }), + // decimal_26 + 8 => Box::new(PrimitiveStatistics::<i128> { + distinct_count: None, + null_count: Some(0), + min_value: Some(0i128), + max_value: Some(9i128), + data_type: DataType::Decimal(26, 0), + }), _ => unreachable!(), }) } diff --git a/tests/it/io/parquet/read.rs b/tests/it/io/parquet/read.rs index 93544b3d73a..fc6164b5f68 100644 --- a/tests/it/io/parquet/read.rs +++ b/tests/it/io/parquet/read.rs @@ -232,6 +232,67 @@ fn v1_nested_large_binary() -> Result<()> { test_pyarrow_integration(6, 1, "nested", false, false) } +#[test] +fn v1_decimal_9_nullable() -> Result<()> { + test_pyarrow_integration(7, 1, "basic", false, false) +} + +#[test] +fn v1_decimal_9_required() -> Result<()> { + test_pyarrow_integration(6, 1, "basic", false, true) +} + +#[test] +fn v1_decimal_18_nullable() -> Result<()> { + test_pyarrow_integration(8, 1, "basic", false, false) +} + +#[test] +fn v1_decimal_18_required() -> Result<()> { + test_pyarrow_integration(7, 1, "basic", false, true) +} + +#[test] +fn v1_decimal_26_nullable() -> Result<()> { + test_pyarrow_integration(9, 1, "basic", false, false) +} + + +#[test] +fn v1_decimal_26_required() -> Result<()> { + test_pyarrow_integration(8, 1, "basic", false, true) +} + +#[test] +fn v2_decimal_9_nullable() -> Result<()> { + test_pyarrow_integration(7, 2, "basic", false, false) +} + +#[test] +fn v2_decimal_9_required() -> Result<()> { + test_pyarrow_integration(6, 2, "basic", false, true) +} + +#[test] +fn v2_decimal_18_nullable() -> Result<()> { + test_pyarrow_integration(8, 2, "basic", false, false) +} + +#[test] +fn v2_decimal_18_required() -> Result<()> { + test_pyarrow_integration(7, 2, "basic", false, 
true) +} + +#[test] +fn v2_decimal_26_nullable() -> Result<()> { + test_pyarrow_integration(9, 2, "basic", false, false) +} + + +#[test] +fn v2_decimal_26_required() -> Result<()> { + test_pyarrow_integration(8, 2, "basic", false, true) +} /*#[test] fn v2_nested_nested() { let _ = test_pyarrow_integration(7, 1, "nested",false, false); diff --git a/tests/it/io/parquet/write.rs b/tests/it/io/parquet/write.rs index cc2a13253f7..5b1ceb40c34 100644 --- a/tests/it/io/parquet/write.rs +++ b/tests/it/io/parquet/write.rs @@ -366,3 +366,148 @@ fn i32_optional_v2_dict() -> Result<()> { Encoding::RleDictionary, ) } + +// Decimal Testing +#[test] +fn decimal_9_optional_v1() -> Result<()> { + round_trip( + 7, + true, + false, + Version::V1, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_9_required_v1() -> Result<()> { + round_trip( + 6, + false, + false, + Version::V1, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_18_optional_v1() -> Result<()> { + round_trip( + 8, + true, + false, + Version::V1, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_18_required_v1() -> Result<()> { + round_trip( + 7, + false, + false, + Version::V1, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_26_optional_v1() -> Result<()> { + round_trip( + 9, + true, + false, + Version::V1, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_26_required_v1() -> Result<()> { + round_trip( + 8, + false, + false, + Version::V1, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_9_optional_v2() -> Result<()> { + round_trip( + 7, + true, + false, + Version::V2, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_9_required_v2() -> Result<()> { + round_trip( + 6, + false, + false, + Version::V2, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_18_optional_v2() -> Result<()> { + round_trip( + 
8, + true, + false, + Version::V2, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_18_required_v2() -> Result<()> { + round_trip( + 7, + false, + false, + Version::V2, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_26_optional_v2() -> Result<()> { + round_trip( + 9, + true, + false, + Version::V2, + Compression::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn decimal_26_required_v2() -> Result<()> { + round_trip( + 8, + false, + false, + Version::V2, + Compression::Uncompressed, + Encoding::Plain, + ) +}