From 3b641b15b2b0d1a03bbc7f0ca1bdd082b794c79c Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 29 Apr 2022 05:48:38 +0000 Subject: [PATCH] Fixed writing required list --- src/io/parquet/write/levels.rs | 69 ++++++++++++++++++++++++---------- tests/it/io/parquet/mod.rs | 2 +- tests/it/io/parquet/read.rs | 2 +- tests/it/io/parquet/write.rs | 24 ++++++++++++ 4 files changed, 75 insertions(+), 22 deletions(-) diff --git a/src/io/parquet/write/levels.rs b/src/io/parquet/write/levels.rs index 725ebc2c96f..68d3c0b343c 100644 --- a/src/io/parquet/write/levels.rs +++ b/src/io/parquet/write/levels.rs @@ -72,9 +72,14 @@ impl Iterator for RepLevelsIter<'_, O> { } } +enum OffsetsIter<'a, O> { + Optional(std::iter::Zip, BitmapIter<'a>>), + Required(std::slice::Windows<'a, O>), +} + /// Iterator adapter of parquet / dremel definition levels pub struct DefLevelsIter<'a, O: Offset> { - iter: std::iter::Zip, Box + 'a>>, + iter: OffsetsIter<'a, O>, primitive_validity: Option>, remaining: usize, is_valid: bool, @@ -92,15 +97,12 @@ impl<'a, O: Offset> DefLevelsIter<'a, O> { let primitive_validity = primitive_validity.map(|x| x.iter()); - let validity = validity - .map(|x| Box::new(x.iter()) as Box>) - .unwrap_or_else(|| { - Box::new(std::iter::repeat(true).take(offsets.len() - 1)) - as Box> - }); + let iter = validity + .map(|x| OffsetsIter::Optional(offsets.windows(2).zip(x.iter()))) + .unwrap_or_else(|| OffsetsIter::Required(offsets.windows(2))); Self { - iter: offsets.windows(2).zip(validity), + iter, primitive_validity, remaining: 0, length: 0, @@ -115,18 +117,31 @@ impl Iterator for DefLevelsIter<'_, O> { fn next(&mut self) -> Option { if self.remaining == self.length { - if let Some((w, is_valid)) = self.iter.next() { - let start = w[0].to_usize(); - let end = w[1].to_usize(); - self.length = end - start; - self.remaining = 0; - self.is_valid = is_valid; - if self.length == 0 { - self.total_size -= 1; - return Some(is_valid as u32); + match &mut self.iter { + OffsetsIter::Optional(iter) => { + let (w, is_valid) = iter.next()?; + let start = w[0].to_usize(); + let end = w[1].to_usize(); + self.length = end - start; + self.remaining = 0; + self.is_valid = is_valid; + if self.length == 0 { + self.total_size -= 1; + return Some(self.is_valid as u32); + } + } + OffsetsIter::Required(iter) => { + let w = iter.next()?; + let start = w[0].to_usize(); + let end = w[1].to_usize(); + self.length = end - start; + self.remaining = 0; + self.is_valid = true; + if self.length == 0 { + self.total_size -= 1; + return Some(0); + } } - } else { - return None; } } self.remaining += 1; @@ -218,7 +233,7 @@ pub fn write_def_levels( validity: Option<&Bitmap>, version: Version, ) -> Result<()> { - let num_bits = 2; + let num_bits = 1 + validity.is_some() as u8; match version { Version::V1 => { @@ -252,6 +267,7 @@ mod tests { #[test] fn test_def_levels() { + // [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] let offsets = [0, 2, 2, 5, 8, 8, 11, 11, 12].as_ref(); let validity = Some(Bitmap::from([ true, false, true, true, true, true, false, true, @@ -269,4 +285,17 @@ mod tests { .collect::>(); assert_eq!(result, expected) } + + #[test] + fn test_def_levels1() { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let offsets = [0, 2, 2, 5, 8, 8, 11, 11, 12].as_ref(); + let validity = None; + let primitive_validity = None; + let expected = vec![1u32, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1]; + + let result = DefLevelsIter::new(offsets, validity.as_ref(), primitive_validity.as_ref()) + .collect::>(); + assert_eq!(result, expected) + } } diff --git a/tests/it/io/parquet/mod.rs b/tests/it/io/parquet/mod.rs index 3733a018d83..7804804e90b 100644 --- a/tests/it/io/parquet/mod.rs +++ b/tests/it/io/parquet/mod.rs @@ -183,7 +183,7 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { match column { "list_int64_required_required" => { - // [[0, 1], [], [2, None, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] let data_type = DataType::List(Box::new(Field::new("item", DataType::Int64, false))); Box::new(ListArray::::from_data( data_type, offsets, values, None, diff --git a/tests/it/io/parquet/read.rs b/tests/it/io/parquet/read.rs index 8e9d5b06cad..f8283a099e1 100644 --- a/tests/it/io/parquet/read.rs +++ b/tests/it/io/parquet/read.rs @@ -249,7 +249,7 @@ fn v1_nested_i16_dict() -> Result<()> { } #[test] -fn v2_nested_i16_required_dict() -> Result<()> { +fn v1_nested_i16_required_dict() -> Result<()> { test_pyarrow_integration( "list_int64_required_required", 1, diff --git a/tests/it/io/parquet/write.rs b/tests/it/io/parquet/write.rs index ef332185af4..424f04bd1d2 100644 --- a/tests/it/io/parquet/write.rs +++ b/tests/it/io/parquet/write.rs @@ -265,6 +265,30 @@ fn list_int64_optional_v1() -> Result<()> { ) } +#[test] +fn list_int64_required_required_v1() -> Result<()> { + round_trip( + "list_int64_required_required", + false, + true, + Version::V1, + CompressionOptions::Uncompressed, + Encoding::Plain, + ) +} + +#[test] +fn list_int64_required_required_v2() -> Result<()> { + round_trip( + "list_int64_required_required", + false, + true, + Version::V2, + CompressionOptions::Uncompressed, + Encoding::Plain, + ) +} + #[test] fn list_bool_optional_v2() -> Result<()> { round_trip(