Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
TCeason committed Feb 22, 2023
1 parent 0b3a4ab commit 2540288
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 4 deletions.
4 changes: 4 additions & 0 deletions parquet_integration/write_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def case_basic_nullable() -> Tuple[dict, pa.Schema, str]:
pa.field("decimal_9", pa.decimal128(9, 0)),
pa.field("decimal_18", pa.decimal128(18, 0)),
pa.field("decimal_26", pa.decimal128(26, 0)),
pa.field("decimal_39", pa.decimal256(39, 0)),
pa.field("timestamp_us", pa.timestamp("us")),
pa.field("timestamp_s", pa.timestamp("s")),
pa.field("emoji", pa.utf8()),
Expand All @@ -51,6 +52,7 @@ def case_basic_nullable() -> Tuple[dict, pa.Schema, str]:
"decimal_9": decimal,
"decimal_18": decimal,
"decimal_26": decimal,
"decimal_39": decimal,
"timestamp_us": int64,
"timestamp_s": int64,
"emoji": emoji,
Expand Down Expand Up @@ -83,6 +85,7 @@ def case_basic_required() -> Tuple[dict, pa.Schema, str]:
pa.field("decimal_9", pa.decimal128(9, 0), nullable=False),
pa.field("decimal_18", pa.decimal128(18, 0), nullable=False),
pa.field("decimal_26", pa.decimal128(26, 0), nullable=False),
pa.field("decimal_39", pa.decimal256(39, 0), nullable=False),
]
schema = pa.schema(fields)

Expand All @@ -97,6 +100,7 @@ def case_basic_required() -> Tuple[dict, pa.Schema, str]:
"decimal_9": decimal,
"decimal_18": decimal,
"decimal_26": decimal,
"decimal_39": decimal,
},
schema,
f"basic_required_10.parquet",
Expand Down
11 changes: 11 additions & 0 deletions src/io/parquet/read/indexes/fixed_len_binary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use parquet2::indexes::PageIndex;

use crate::types::{i256, NativeType};
use crate::{
array::{Array, FixedSizeBinaryArray, MutableFixedSizeBinaryArray, PrimitiveArray},
datatypes::{DataType, PhysicalType, PrimitiveType},
Expand Down Expand Up @@ -42,6 +43,16 @@ fn deserialize_binary_iter<'a, I: TrustedLen<Item = Option<&'a Vec<u8>>>>(
})
})))
}
PhysicalType::Primitive(PrimitiveType::Int256) => {
    Box::new(PrimitiveArray::from_trusted_len_iter(iter.map(|v| {
        v.map(|x| {
            // The value is read as big-endian, so when it is shorter than
            // 32 bytes it must be right-aligned in the buffer. The leading
            // bytes are filled with the sign (0xFF for negative, 0x00
            // otherwise) so two's-complement values keep their magnitude
            // and sign. Left-aligning into a zeroed buffer would multiply
            // the value by 256^(32 - n).
            let n = x.len();
            let fill = if x.first().map_or(false, |b| b & 0x80 != 0) {
                0xFF
            } else {
                0x00
            };
            let mut bytes = [fill; 32];
            bytes[32 - n..].copy_from_slice(x);
            i256::from_be_bytes(bytes)
        })
    })))
}
_ => {
let mut a = MutableFixedSizeBinaryArray::try_new(
data_type,
Expand Down
12 changes: 12 additions & 0 deletions src/io/parquet/read/indexes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ fn deserialize(
))),
}
}
// 256-bit integers can only be deserialized from a fixed-length byte array
// page index; any other parquet physical type is reported as not yet
// implemented.
PhysicalType::Primitive(PrimitiveType::Int256) => {
    let index = indexes.pop_front().unwrap();
    match index.physical_type() {
        parquet2::schema::types::PhysicalType::FixedLenByteArray(_) => {
            let index = index.as_any().downcast_ref::<FixedLenByteIndex>().unwrap();
            Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into())
        }
        // The message names the actual target type (int256), not int64.
        other => Err(Error::nyi(format!(
            "Deserialize {other:?} to arrow's int256"
        ))),
    }
}
PhysicalType::Primitive(PrimitiveType::UInt8)
| PhysicalType::Primitive(PrimitiveType::UInt16)
| PhysicalType::Primitive(PrimitiveType::UInt32)
Expand Down
25 changes: 24 additions & 1 deletion src/io/parquet/read/statistics/fixlen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use parquet2::statistics::{FixedLenStatistics, Statistics as ParquetStatistics};

use crate::array::*;
use crate::error::Result;
use crate::types::days_ms;
use crate::io::parquet::read::convert_i256;
use crate::types::{days_ms, i256};

use super::super::{convert_days_ms, convert_i128};

Expand All @@ -28,6 +29,28 @@ pub(super) fn push_i128(
Ok(())
}

/// Appends one min/max statistics entry to the `i256` accumulators.
///
/// `min` and `max` must be `MutablePrimitiveArray<i256>` and `from`, when
/// present, must be `FixedLenStatistics`; the downcasts panic otherwise.
/// A missing `from`, or a missing min/max value inside it, pushes a null.
/// `n` is forwarded unchanged to `convert_i256` together with the raw bytes.
pub(super) fn push_i256(
    from: Option<&dyn ParquetStatistics>,
    n: usize,
    min: &mut dyn MutableArray,
    max: &mut dyn MutableArray,
) -> Result<()> {
    let min_values: &mut MutablePrimitiveArray<i256> =
        min.as_mut_any().downcast_mut().unwrap();
    let max_values: &mut MutablePrimitiveArray<i256> =
        max.as_mut_any().downcast_mut().unwrap();
    let stats: Option<&FixedLenStatistics> = match from {
        Some(s) => Some(s.as_any().downcast_ref().unwrap()),
        None => None,
    };

    let lower = match stats {
        Some(s) => s.min_value.as_deref().map(|bytes| convert_i256(bytes, n)),
        None => None,
    };
    min_values.push(lower);

    let upper = match stats {
        Some(s) => s.max_value.as_deref().map(|bytes| convert_i256(bytes, n)),
        None => None,
    };
    max_values.push(upper);

    Ok(())
}

pub(super) fn push(
from: Option<&dyn ParquetStatistics>,
min: &mut dyn MutableArray,
Expand Down
7 changes: 7 additions & 0 deletions src/io/parquet/read/statistics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,13 @@ fn push(
ParquetPhysicalType::FixedLenByteArray(n) => fixlen::push_i128(from, *n, min, max),
_ => unreachable!(),
},
// A 256-bit value occupies up to 32 bytes, so only fixed-length arrays wider
// than that are rejected. A limit of 16 is the 128-bit bound and would refuse
// every array that actually needs the wider type.
Decimal256(_, _) => match physical_type {
    ParquetPhysicalType::FixedLenByteArray(n) if *n > 32 => Err(Error::NotYetImplemented(
        format!("Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}"),
    )),
    ParquetPhysicalType::FixedLenByteArray(n) => fixlen::push_i256(from, *n, min, max),
    _ => unreachable!(),
},
Binary => binary::push::<i32>(from, min, max),
LargeBinary => binary::push::<i64>(from, min, max),
Utf8 => utf8::push::<i32>(from, min, max),
Expand Down
4 changes: 1 addition & 3 deletions src/io/parquet/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,9 @@ pub fn array_to_page_simple(

fixed_len_bytes::array_to_page(array, options, type_, statistics)
}
DataType::Decimal256(precision, _) => {
DataType::Decimal256(_, _) => {
let type_ = type_;
let precision = *precision;
let size = 16;
println!("the array is {:?}", array.clone());
let array = array
.as_any()
.downcast_ref::<PrimitiveArray<i256>>()
Expand Down
26 changes: 26 additions & 0 deletions tests/it/io/parquet/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use ethnum::AsI256;
use std::io::{Cursor, Read, Seek};

use arrow2::types::i256;
use arrow2::{
array::*,
bitmap::Bitmap,
Expand Down Expand Up @@ -528,6 +530,13 @@ pub fn pyarrow_nullable(column: &str) -> Box<dyn Array> {
.collect::<Vec<_>>();
Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(26, 0)))
}
"decimal_39" => {
let values = i64_values
.iter()
.map(|x| x.map(|x| i256(x.as_i256())))
.collect::<Vec<_>>();
Box::new(PrimitiveArray::<i256>::from(values).to(DataType::Decimal256(39, 0)))
}
"timestamp_us" => Box::new(
PrimitiveArray::<i64>::from(i64_values)
.to(DataType::Timestamp(TimeUnit::Microsecond, None)),
Expand Down Expand Up @@ -614,6 +623,16 @@ pub fn pyarrow_nullable_statistics(column: &str) -> Statistics {
min_value: Box::new(Int128Array::from_slice([-256]).to(DataType::Decimal(26, 0))),
max_value: Box::new(Int128Array::from_slice([9]).to(DataType::Decimal(26, 0))),
},
"decimal_39" => Statistics {
distinct_count: UInt64Array::from([None]).boxed(),
null_count: UInt64Array::from([Some(3)]).boxed(),
min_value: Box::new(
Int256Array::from_slice([i256(-(256.as_i256()))]).to(DataType::Decimal256(39, 0)),
),
max_value: Box::new(
Int256Array::from_slice([i256(9.as_i256())]).to(DataType::Decimal256(39, 0)),
),
},
"timestamp_us" => Statistics {
distinct_count: UInt64Array::from([None]).boxed(),
null_count: UInt64Array::from([Some(3)]).boxed(),
Expand Down Expand Up @@ -694,6 +713,13 @@ pub fn pyarrow_required(column: &str) -> Box<dyn Array> {
.collect::<Vec<_>>();
Box::new(PrimitiveArray::<i128>::from(values).to(DataType::Decimal(26, 0)))
}
"decimal_39" => {
let values = i64_values
.iter()
.map(|x| x.map(|x| i256(x.as_i256())))
.collect::<Vec<_>>();
Box::new(PrimitiveArray::<i256>::from(values).to(DataType::Decimal256(39, 0)))
}
_ => unreachable!(),
}
}
Expand Down
25 changes: 25 additions & 0 deletions tests/it/io/parquet/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,16 @@ fn v1_decimal_26_required() -> Result<()> {
test_pyarrow_integration("decimal_26", 1, "basic", false, true, None)
}

/// Runs `test_pyarrow_integration` on the `decimal_39` column with version
/// argument `1`.
/// NOTE(review): judging by the test names in this file, the two booleans
/// appear to be (dictionary, required) — confirm against the helper.
#[test]
fn v1_decimal_39_nullable() -> Result<()> {
    let column = "decimal_39";
    let version = 1;
    test_pyarrow_integration(column, version, "basic", false, false, None)
}

/// Runs `test_pyarrow_integration` on the `decimal_39` column with version
/// argument `1` and the last boolean set.
/// NOTE(review): judging by the test names in this file, the two booleans
/// appear to be (dictionary, required) — confirm against the helper.
#[test]
fn v1_decimal_39_required() -> Result<()> {
    let column = "decimal_39";
    let version = 1;
    test_pyarrow_integration(column, version, "basic", false, true, None)
}

#[test]
fn v2_decimal_9_nullable() -> Result<()> {
test_pyarrow_integration("decimal_9", 2, "basic", false, false, None)
Expand Down Expand Up @@ -436,6 +446,11 @@ fn v2_decimal_26_nullable() -> Result<()> {
test_pyarrow_integration("decimal_26", 2, "basic", false, false, None)
}

/// Runs `test_pyarrow_integration` on the `decimal_39` column with version
/// argument `2`.
/// NOTE(review): judging by the test names in this file, the two booleans
/// appear to be (dictionary, required) — confirm against the helper.
#[test]
fn v2_decimal_39_nullable() -> Result<()> {
    let column = "decimal_39";
    let version = 2;
    test_pyarrow_integration(column, version, "basic", false, false, None)
}

#[test]
fn v1_timestamp_us_nullable() -> Result<()> {
test_pyarrow_integration("timestamp_us", 1, "basic", false, false, None)
Expand Down Expand Up @@ -466,6 +481,16 @@ fn v2_decimal_26_required_dict() -> Result<()> {
test_pyarrow_integration("decimal_26", 2, "basic", true, true, None)
}

/// Runs `test_pyarrow_integration` on the `decimal_39` column with version
/// argument `2` and the last boolean set.
/// NOTE(review): judging by the test names in this file, the two booleans
/// appear to be (dictionary, required) — confirm against the helper.
#[test]
fn v2_decimal_39_required() -> Result<()> {
    let column = "decimal_39";
    let version = 2;
    test_pyarrow_integration(column, version, "basic", false, true, None)
}

/// Runs `test_pyarrow_integration` on the `decimal_39` column with version
/// argument `2` and both booleans set.
/// NOTE(review): judging by the test names in this file, the two booleans
/// appear to be (dictionary, required) — confirm against the helper.
#[test]
fn v2_decimal_39_required_dict() -> Result<()> {
    let column = "decimal_39";
    let version = 2;
    test_pyarrow_integration(column, version, "basic", true, true, None)
}

#[test]
fn v1_struct_required_optional() -> Result<()> {
test_pyarrow_integration("struct", 1, "struct", false, false, None)
Expand Down
44 changes: 44 additions & 0 deletions tests/it/io/parquet/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,28 @@ fn decimal_26_required_v1() -> Result<()> {
)
}

/// Calls `round_trip` for the `decimal_39` column with the `"nullable"`
/// argument, `Version::V1`, no compression and plain encoding.
#[test]
fn decimal_39_optional_v1() -> Result<()> {
    let compression = CompressionOptions::Uncompressed;
    let encodings = vec![Encoding::Plain];
    round_trip("decimal_39", "nullable", Version::V1, compression, encodings)
}

/// Calls `round_trip` for the `decimal_39` column with the `"required"`
/// argument, `Version::V1`, no compression and plain encoding.
#[test]
fn decimal_39_required_v1() -> Result<()> {
    let compression = CompressionOptions::Uncompressed;
    let encodings = vec![Encoding::Plain];
    round_trip("decimal_39", "required", Version::V1, compression, encodings)
}

#[test]
fn decimal_9_optional_v2() -> Result<()> {
round_trip(
Expand Down Expand Up @@ -604,6 +626,28 @@ fn decimal_26_required_v2() -> Result<()> {
)
}

/// Calls `round_trip` for the `decimal_39` column with the `"nullable"`
/// argument, `Version::V2`, no compression and plain encoding.
#[test]
fn decimal_39_optional_v2() -> Result<()> {
    let compression = CompressionOptions::Uncompressed;
    let encodings = vec![Encoding::Plain];
    round_trip("decimal_39", "nullable", Version::V2, compression, encodings)
}

/// Calls `round_trip` for the `decimal_39` column with the `"required"`
/// argument, `Version::V2`, no compression and plain encoding.
#[test]
fn decimal_39_required_v2() -> Result<()> {
    let compression = CompressionOptions::Uncompressed;
    let encodings = vec![Encoding::Plain];
    round_trip("decimal_39", "required", Version::V2, compression, encodings)
}

#[test]
fn struct_v1() -> Result<()> {
round_trip(
Expand Down

0 comments on commit 2540288

Please sign in to comment.