diff --git a/src/io/parquet/write/nested/def.rs b/src/io/parquet/write/nested/def.rs index f3cb804feff..4426d3a3cfe 100644 --- a/src/io/parquet/write/nested/def.rs +++ b/src/io/parquet/write/nested/def.rs @@ -69,7 +69,6 @@ fn iter<'a>(nested: &'a [Nested]) -> Vec> { pub struct DefLevelsIter<'a> { // iterators of validities and lengths. E.g. [[[None,b,c], None], None] -> [[(true, 2), (false, 0)], [(true, 3), (false, 0)], [(false, 1), (true, 1), (true, 1)]] iter: Vec>, - primitive_validity: Box, // vector containing the remaining number of values of each iterator. // e.g. the iters [[2, 2], [3, 4, 1, 2]] after the first iteration will return [2, 3], // and remaining will be [2, 3]. @@ -88,21 +87,15 @@ pub struct DefLevelsIter<'a> { } impl<'a> DefLevelsIter<'a> { - pub fn new(nested: &'a [Nested], offset: usize) -> Self { + pub fn new(nested: &'a [Nested], _offset: usize) -> Self { let remaining_values = num_values(nested); - let mut primitive_validity = iter(&nested[nested.len() - 1..]).pop().unwrap(); - if offset > 0 { - primitive_validity.nth(offset - 1); - } - - let iter = iter(&nested[..nested.len() - 1]); - let remaining = std::iter::repeat(0).take(iter.len()).collect(); - let validity = std::iter::repeat(0).take(iter.len()).collect(); + let iter = iter(nested); + let remaining = vec![0; iter.len()]; + let validity = vec![0; iter.len()]; Self { iter, - primitive_validity, remaining, validity, total: 0, @@ -152,19 +145,13 @@ impl<'a> Iterator for DefLevelsIter<'a> { *x = x.saturating_sub(1) } - let primitive = if self.current_level == self.remaining.len() { - self.primitive_validity.next()?.0 - } else { - 0 - }; - let r = Some(self.total + empty_contrib + primitive); + let r = Some(self.total + empty_contrib); - for level in 0..self.current_level.saturating_sub(1) { - let level = self.remaining.len() - level - 1; - if self.remaining[level] == 0 { + for index in (1..self.current_level).rev() { + if self.remaining[index] == 0 { self.current_level -= 1; - self.remaining[level - 1] -= 1; - self.total -= self.validity[level]; + self.remaining[index - 1] -= 1; + self.total -= self.validity[index]; } } if self.remaining[0] == 0 { @@ -519,4 +506,106 @@ mod tests { test(nested, expected) } + + #[test] + fn nested_list_struct_list_nullable1() { + /* + [ + [{"a": ["b"]}, None], + ] + */ + + let a = [true].into(); + let b = [true, false].into(); + let c = [true, false].into(); + let d = [true].into(); + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: &[0, 2], + validity: Some(&a), + }), + Nested::Struct(Some(&b), true, 2), + Nested::List(ListNested { + is_optional: true, + offsets: &[0, 1, 1], + validity: Some(&c), + }), + Nested::Primitive(Some(&d), true, 1), + ]; + /* + 0 6 + 1 6 + 0 0 + 0 6 + 1 2 + */ + let expected = vec![6, 2]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_list_nullable() { + /* + [List(ListNested { is_optional: true, offsets: [0, 2, 2, 5, 8, 8, 11, 11, 12], validity: Some([0b10111101]) }), Struct(Some([0b11110111, 0b____1111]), true, 12), List(ListNested { is_optional: true, offsets: [0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8], validity: Some([0b00010111, 0b____1111]) }), Primitive(Some([0b11011111]), true, 8)] + + [ + [{"a": ["a"]}, {"a": ["b"]}], + None, + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": ["d"]}, {"a": [None]}, {"a": ["c", "d"]}], + None, + [{"a": []}], + ] + */ + let a = [true, false, true, true, true, true, false, true].into(); + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ] + .into(); + let c = [ + true, true, true, false, true, false, false, false, true, true, true, true, + ] + .into(); + let d = [true, true, true, true, true, false, true, true].into(); + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: &[0, 2, 2, 5, 8, 8, 11, 11, 12], + validity: Some(&a), + }), + // 0b11110111, 0b____1111 + Nested::Struct(Some(&b), true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: &[0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8], + validity: Some(&c), + }), + Nested::Primitive(Some(&d), true, 8), + ]; + /* + 0 6 + 1 6 + 0 0 + 0 6 + 1 2 + 1 6 + 0 3 + 1 3 + 1 3 + 0 1 + 0 6 + 1 5 + 1 6 + 2 6 + 0 0 + 0 4 + */ + let expected = vec![6, 6, 0, 6, 2, 6, 3, 3, 3, 1, 6, 5, 6, 6, 0, 4]; + + test(nested, expected) + } } diff --git a/src/io/parquet/write/nested/rep.rs b/src/io/parquet/write/nested/rep.rs index 6d88bbd1637..35841648dd6 100644 --- a/src/io/parquet/write/nested/rep.rs +++ b/src/io/parquet/write/nested/rep.rs @@ -110,11 +110,10 @@ impl<'a> Iterator for RepLevelsIter<'a> { let r = Some((self.current_level - self.total) as u32); // update - for level in 0..self.current_level.saturating_sub(1) { - let level = self.remaining.len() - level - 1; - if self.remaining[level] == 0 { + for index in (1..self.current_level).rev() { + if self.remaining[index] == 0 { self.current_level -= 1; - self.remaining[level - 1] -= 1; + self.remaining[index - 1] -= 1; } } if self.remaining[0] == 0 { diff --git a/tests/it/io/parquet/write.rs b/tests/it/io/parquet/write.rs index 04383a10a2c..d110298bc68 100644 --- a/tests/it/io/parquet/write.rs +++ b/tests/it/io/parquet/write.rs @@ -414,6 +414,18 @@ fn v1_nested_struct_list_nullable() -> Result<()> { ) } +#[test] +fn v1_nested_list_struct_list_nullable() -> Result<()> { + round_trip_opt_stats( + "list_struct_list_nullable", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + true, + ) +} + #[test] fn utf8_optional_v2_delta() -> Result<()> { round_trip(