From 9e329f704e5e4e216cafae66c38571c20b9a5cc0 Mon Sep 17 00:00:00 2001 From: Frank Murphy Date: Tue, 20 Dec 2022 17:27:14 -0500 Subject: [PATCH] Fix calculation of nested rep levels --- src/io/parquet/write/nested/rep.rs | 111 +++++++++++++++++++++-------- tests/it/io/parquet/write.rs | 81 +++++++++++++++++++++ 2 files changed, 163 insertions(+), 29 deletions(-) diff --git a/src/io/parquet/write/nested/rep.rs b/src/io/parquet/write/nested/rep.rs index c7c5d15eb8e..8497815c49f 100644 --- a/src/io/parquet/write/nested/rep.rs +++ b/src/io/parquet/write/nested/rep.rs @@ -5,21 +5,31 @@ trait DebugIter: Iterator + std::fmt::Debug {} impl + std::fmt::Debug> DebugIter for A {} -fn iter<'a>(nested: &'a [Nested]) -> Vec> { +enum RepIter<'a> { + Required(Box), + Repeated(Box), +} + +fn iter<'a>(nested: &'a [Nested]) -> Vec> { nested .iter() .enumerate() .filter_map(|(i, nested)| match nested { Nested::Primitive(_, _, _) => None, - Nested::List(nested) => Some(Box::new(to_length(nested.offsets)) as Box), - Nested::LargeList(nested) => { - Some(Box::new(to_length(nested.offsets)) as Box) - } - Nested::Struct(_, _, length) => { - // only return 1, 1, 1, (x len) if struct is outer structure. - // otherwise treat as leaf + Nested::List(nested) => Some(RepIter::Repeated( + Box::new(to_length(nested.offsets)) as Box + )), + Nested::LargeList(nested) => Some(RepIter::Repeated( + Box::new(to_length(nested.offsets)) as Box, + )), + Nested::Struct(_, is_optional, length) => { + let iter = Box::new(std::iter::repeat(1usize).take(*length)) as Box; if i == 0 { - Some(Box::new(std::iter::repeat(1usize).take(*length)) as Box) + if *is_optional { + Some(RepIter::Repeated(iter)) + } else { + Some(RepIter::Required(iter)) + } } else { None } @@ -29,6 +39,10 @@ fn iter<'a>(nested: &'a [Nested]) -> Vec> { } pub fn num_values(nested: &[Nested]) -> usize { + _num_values(nested, true) +} + +fn _num_values(nested: &[Nested], count_required: bool) -> usize { let iterators = iter(nested); let depth = iterators.len(); @@ -36,6 +50,17 @@ pub fn num_values(nested: &[Nested]) -> usize { .into_iter() .enumerate() .map(|(index, lengths)| { + let lengths = match lengths { + RepIter::Repeated(i) => i, + RepIter::Required(i) => { + if count_required { + i + } else { + Box::new(std::iter::empty()) as Box + } + } + }; + if index == depth - 1 { lengths .map(|length| if length == 0 { 1 } else { length }) @@ -73,9 +98,16 @@ pub struct RepLevelsIter<'a> { impl<'a> RepLevelsIter<'a> { pub fn new(nested: &'a [Nested]) -> Self { - let remaining_values = num_values(nested); + let remaining_values = _num_values(nested, false); + + let iter: Vec<_> = iter(nested) + .into_iter() + .flat_map(|i| match i { + RepIter::Repeated(ls) => Some(ls), + _ => None, + }) + .collect(); - let iter = iter(nested); let remaining = std::iter::repeat(0).take(iter.len()).collect(); Self { @@ -92,25 +124,30 @@ impl<'a> Iterator for RepLevelsIter<'a> { type Item = u32; fn next(&mut self) -> Option { - if *self.remaining.last().unwrap() > 0 { - *self.remaining.last_mut().unwrap() -= 1; - - let total = self.total; - self.total = 0; - let r = Some((self.current_level - total) as u32); - - for level in 0..self.current_level - 1 { - let level = self.remaining.len() - level - 1; - if self.remaining[level] == 0 { - self.current_level -= 1; - self.remaining[level.saturating_sub(1)] -= 1; + match self.remaining.last() { + Some(v) => { + if *v > 0 { + *self.remaining.last_mut().unwrap() -= 1; + + let total = self.total; + self.total = 0; + let r = Some((self.current_level - total) as u32); + + for level in 0..self.current_level - 1 { + let level = self.remaining.len() - level - 1; + if self.remaining[level] == 0 { + self.current_level -= 1; + self.remaining[level.saturating_sub(1)] -= 1; + } + } + if self.remaining[0] == 0 { + self.current_level -= 1; + } + self.remaining_values -= 1; + return r; } } - if self.remaining[0] == 0 { - self.current_level -= 1; - } - self.remaining_values -= 1; - return r; + None => return None, } self.total = 0; @@ -158,7 +195,23 @@ mod tests { Nested::Struct(None, false, 10), Nested::Primitive(None, true, 10), ]; - let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let expected = vec![]; + + test(nested, expected) + } + + #[test] + fn struct_required_nested_list() { + let nested = vec![ + Nested::Struct(None, false, 5), + Nested::List(ListNested { + is_optional: false, + offsets: &[0i32, 2, 2, 4, 6, 8], + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![0, 1, 0, 0, 1, 0, 1, 0, 1]; test(nested, expected) } diff --git a/tests/it/io/parquet/write.rs b/tests/it/io/parquet/write.rs index ef5a18888d2..7effb2224a9 100644 --- a/tests/it/io/parquet/write.rs +++ b/tests/it/io/parquet/write.rs @@ -1,7 +1,9 @@ use std::io::Cursor; +use arrow2::array::*; use arrow2::error::Result; use arrow2::io::parquet::write::*; +use arrow2::offset::Offsets; use super::*; @@ -73,6 +75,45 @@ fn round_trip_opt_stats( Ok(()) } +fn round_trip_native( + schema: Schema, + chunk: Chunk>, + version: Version, + compression: CompressionOptions, + encodings: Vec>, +) -> Result<()> { + let options = WriteOptions { + write_statistics: true, + compression, + version, + data_pagesize_limit: None, + }; + + let row_groups = RowGroupIterator::try_new( + vec![Ok(chunk.clone())].into_iter(), + &schema, + options, + encodings, + )?; + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::try_new(writer, schema.clone(), options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + for (field, array) in schema.fields.iter().zip(chunk.arrays().iter()) { + let (result, _) = read_column(&mut Cursor::new(data.clone()), field.name.as_str())?; + assert_eq!(array, result.as_ref()); + } + + Ok(()) +} + #[test] fn int64_optional_v1() -> Result<()> { round_trip( @@ -585,3 +626,43 @@ fn struct_v2() -> Result<()> { vec![Encoding::Plain, Encoding::Plain], ) } + +#[test] +fn round_trip_nested_required_struct() -> Result<()> { + let inner = PrimitiveArray::::from_values(vec![ + 2.0, 2.0, 4.0, 4.0, 6.0, 6.0, 8.0, 8.0, 10.0, 10.0, + ]); + + let offsets = vec![0, 2, 2, 4, 6, 8]; + let child = ListArray::new( + DataType::List(Box::new(Field::new( + "inner", + inner.data_type().clone(), + false, + ))), + Offsets::try_from(offsets).unwrap().into(), + inner.boxed(), + None, + ); + + let nested = StructArray::new( + DataType::Struct(vec![Field::new("child", child.data_type().clone(), false)]), + vec![child.clone().boxed()], + None, + ); + + let schema = Schema { + fields: vec![Field::new("nested", nested.data_type().clone(), false)], + metadata: Metadata::default(), + }; + + let chunk = Chunk::new(vec![nested.to_boxed()]); + + round_trip_native( + schema, + chunk, + Version::V1, + CompressionOptions::Uncompressed, + vec![vec![Encoding::Plain]], + ) +}