From acee177ab716d196cc2830299be66ca3b5b5f354 Mon Sep 17 00:00:00 2001 From: Frank Murphy Date: Tue, 20 Dec 2022 17:27:14 -0500 Subject: [PATCH] Fix calculation of nested rep levels --- src/io/parquet/write/nested/rep.rs | 110 ++++++++++++++++++++++------- 1 file changed, 83 insertions(+), 27 deletions(-) diff --git a/src/io/parquet/write/nested/rep.rs b/src/io/parquet/write/nested/rep.rs index 68c7609acae..f9f50fc1b53 100644 --- a/src/io/parquet/write/nested/rep.rs +++ b/src/io/parquet/write/nested/rep.rs @@ -5,23 +5,40 @@ trait DebugIter: Iterator + std::fmt::Debug {} impl + std::fmt::Debug> DebugIter for A {} -fn iter<'a>(nested: &'a [Nested]) -> Vec> { +enum RepIter<'a> { + Required(Box), + Repeated(Box), +} + +fn iter<'a>(nested: &'a [Nested]) -> Vec> { nested .iter() .filter_map(|nested| match nested { Nested::Primitive(_, _, _) => None, - Nested::List(nested) => Some(Box::new(to_length(nested.offsets)) as Box), - Nested::LargeList(nested) => { - Some(Box::new(to_length(nested.offsets)) as Box) - } - Nested::Struct(_, _, length) => { - Some(Box::new(std::iter::repeat(1usize).take(*length)) as Box) + Nested::List(nested) => Some(RepIter::Repeated( + Box::new(to_length(nested.offsets)) as Box + )), + Nested::LargeList(nested) => Some(RepIter::Repeated( + Box::new(to_length(nested.offsets)) as Box, + )), + Nested::Struct(_, is_optional, length) => { + let iter = Box::new(std::iter::repeat(1usize).take(*length)) as Box; + // if true { + if *is_optional { + Some(RepIter::Repeated(iter)) + } else { + Some(RepIter::Required(iter)) + } } }) .collect() } pub fn num_values(nested: &[Nested]) -> usize { + _num_values(nested, true) +} + +fn _num_values(nested: &[Nested], count_required: bool) -> usize { let iterators = iter(nested); let depth = iterators.len(); @@ -29,6 +46,17 @@ pub fn num_values(nested: &[Nested]) -> usize { .into_iter() .enumerate() .map(|(index, lengths)| { + let lengths = match lengths { + RepIter::Repeated(i) => i, + RepIter::Required(i) => { + if count_required { + i + } else { + Box::new(std::iter::empty()) as Box + } + } + }; + if index == depth - 1 { lengths .map(|length| if length == 0 { 1 } else { length }) @@ -66,9 +94,16 @@ pub struct RepLevelsIter<'a> { impl<'a> RepLevelsIter<'a> { pub fn new(nested: &'a [Nested]) -> Self { - let remaining_values = num_values(nested); + let remaining_values = _num_values(nested, false); + + let iter: Vec<_> = iter(nested) + .into_iter() + .flat_map(|i| match i { + RepIter::Repeated(ls) => Some(ls), + _ => None, + }) + .collect(); - let iter = iter(nested); let remaining = std::iter::repeat(0).take(iter.len()).collect(); Self { @@ -85,25 +120,30 @@ impl<'a> Iterator for RepLevelsIter<'a> { type Item = u32; fn next(&mut self) -> Option { - if *self.remaining.last().unwrap() > 0 { - *self.remaining.last_mut().unwrap() -= 1; - - let total = self.total; - self.total = 0; - let r = Some((self.current_level - total) as u32); - - for level in 0..self.current_level - 1 { - let level = self.remaining.len() - level - 1; - if self.remaining[level] == 0 { - self.current_level -= 1; - self.remaining[level.saturating_sub(1)] -= 1; + match self.remaining.last() { + Some(v) => { + if *v > 0 { + *self.remaining.last_mut().unwrap() -= 1; + + let total = self.total; + self.total = 0; + let r = Some((self.current_level - total) as u32); + + for level in 0..self.current_level - 1 { + let level = self.remaining.len() - level - 1; + if self.remaining[level] == 0 { + self.current_level -= 1; + self.remaining[level.saturating_sub(1)] -= 1; + } + } + if self.remaining[0] == 0 { + self.current_level -= 1; + } + self.remaining_values -= 1; + return r; } } - if self.remaining[0] == 0 { - self.current_level -= 1; - } - self.remaining_values -= 1; - return r; + None => return None, } self.total = 0; @@ -151,7 +191,23 @@ mod tests { Nested::Struct(None, false, 10), Nested::Primitive(None, true, 10), ]; - let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let expected = vec![]; + + test(nested, expected) + } + + #[test] + fn struct_required_nested_list() { + let nested = vec![ + Nested::Struct(None, false, 5), + Nested::List(ListNested { + is_optional: false, + offsets: &[0i32, 2, 2, 4, 6, 8], + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![0, 1, 0, 0, 1, 0, 1, 0, 1]; test(nested, expected) }