Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Fix calculation of nested rep levels #1355

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 82 additions & 29 deletions src/io/parquet/write/nested/rep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,31 @@ trait DebugIter: Iterator<Item = usize> + std::fmt::Debug {}

impl<A: Iterator<Item = usize> + std::fmt::Debug> DebugIter for A {}

fn iter<'a>(nested: &'a [Nested]) -> Vec<Box<dyn DebugIter + 'a>> {
/// Per-level length iterator produced by `iter`, tagged by how the level is treated.
///
/// Both variants wrap the same boxed `DebugIter` (an iterator of `usize` lengths); only the
/// tag differs.
enum RepIter<'a> {
/// Lengths of a level that is not repeated. In the visible code this tag is produced only
/// for a non-optional outermost struct; `_num_values` counts it only when `count_required`
/// is set, and `RepLevelsIter::new` drops it.
Required(Box<dyn DebugIter + 'a>),
/// Lengths of a level that is always counted by `_num_values` and kept by
/// `RepLevelsIter::new`. In the visible code this tag is produced for `List`, `LargeList`
/// and an optional outermost struct.
Repeated(Box<dyn DebugIter + 'a>),
}

fn iter<'a>(nested: &'a [Nested]) -> Vec<RepIter<'a>> {
nested
.iter()
.enumerate()
.filter_map(|(i, nested)| match nested {
Nested::Primitive(_, _, _) => None,
Nested::List(nested) => Some(Box::new(to_length(nested.offsets)) as Box<dyn DebugIter>),
Nested::LargeList(nested) => {
Some(Box::new(to_length(nested.offsets)) as Box<dyn DebugIter>)
}
Nested::Struct(_, _, length) => {
// only return 1, 1, 1, (x len) if struct is outer structure.
// otherwise treat as leaf
Nested::List(nested) => Some(RepIter::Repeated(
Box::new(to_length(nested.offsets)) as Box<dyn DebugIter>
)),
Nested::LargeList(nested) => Some(RepIter::Repeated(
Box::new(to_length(nested.offsets)) as Box<dyn DebugIter>,
)),
Nested::Struct(_, is_optional, length) => {
let iter = Box::new(std::iter::repeat(1usize).take(*length)) as Box<dyn DebugIter>;
if i == 0 {
Some(Box::new(std::iter::repeat(1usize).take(*length)) as Box<dyn DebugIter>)
if *is_optional {
Some(RepIter::Repeated(iter))
} else {
Some(RepIter::Required(iter))
}
} else {
None
}
Expand All @@ -29,13 +39,28 @@ fn iter<'a>(nested: &'a [Nested]) -> Vec<Box<dyn DebugIter + 'a>> {
}

/// Public entry point for counting the values described by `nested`.
///
/// Delegates to `_num_values` with `count_required` enabled, so lengths tagged as
/// `RepIter::Required` are included in the count.
pub fn num_values(nested: &[Nested]) -> usize {
    let count_required = true;
    _num_values(nested, count_required)
}

fn _num_values(nested: &[Nested], count_required: bool) -> usize {
let iterators = iter(nested);
let depth = iterators.len();

iterators
.into_iter()
.enumerate()
.map(|(index, lengths)| {
let lengths = match lengths {
RepIter::Repeated(i) => i,
RepIter::Required(i) => {
if count_required {
i
} else {
Box::new(std::iter::empty()) as Box<dyn DebugIter>
}
}
};

if index == depth - 1 {
lengths
.map(|length| if length == 0 { 1 } else { length })
Expand Down Expand Up @@ -73,9 +98,16 @@ pub struct RepLevelsIter<'a> {

impl<'a> RepLevelsIter<'a> {
pub fn new(nested: &'a [Nested]) -> Self {
let remaining_values = num_values(nested);
let remaining_values = _num_values(nested, false);

let iter: Vec<_> = iter(nested)
.into_iter()
.flat_map(|i| match i {
RepIter::Repeated(ls) => Some(ls),
_ => None,
})
.collect();

let iter = iter(nested);
let remaining = std::iter::repeat(0).take(iter.len()).collect();

Self {
Expand All @@ -92,25 +124,30 @@ impl<'a> Iterator for RepLevelsIter<'a> {
type Item = u32;

fn next(&mut self) -> Option<Self::Item> {
if *self.remaining.last().unwrap() > 0 {
*self.remaining.last_mut().unwrap() -= 1;

let total = self.total;
self.total = 0;
let r = Some((self.current_level - total) as u32);

for level in 0..self.current_level - 1 {
let level = self.remaining.len() - level - 1;
if self.remaining[level] == 0 {
self.current_level -= 1;
self.remaining[level.saturating_sub(1)] -= 1;
match self.remaining.last() {
Some(v) => {
if *v > 0 {
*self.remaining.last_mut().unwrap() -= 1;

let total = self.total;
self.total = 0;
let r = Some((self.current_level - total) as u32);

for level in 0..self.current_level - 1 {
let level = self.remaining.len() - level - 1;
if self.remaining[level] == 0 {
self.current_level -= 1;
self.remaining[level.saturating_sub(1)] -= 1;
}
}
if self.remaining[0] == 0 {
self.current_level -= 1;
}
self.remaining_values -= 1;
return r;
}
}
if self.remaining[0] == 0 {
self.current_level -= 1;
}
self.remaining_values -= 1;
return r;
None => return None,
}

self.total = 0;
Expand Down Expand Up @@ -158,7 +195,23 @@ mod tests {
Nested::Struct(None, false, 10),
Nested::Primitive(None, true, 10),
];
let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let expected = vec![];

test(nested, expected)
}

#[test]
fn struct_required_nested_list() {
    // Non-optional list whose consecutive offset differences are 2, 0, 2, 2, 2.
    let list = Nested::List(ListNested {
        is_optional: false,
        offsets: &[0i32, 2, 2, 4, 6, 8],
        validity: None,
    });

    // Non-optional outer struct -> list -> primitive leaf.
    let levels = vec![
        Nested::Struct(None, false, 5),
        list,
        Nested::Primitive(None, false, 10),
    ];

    // Each non-empty list contributes `0, 1`; the empty second list contributes a single `0`.
    test(levels, vec![0, 1, 0, 0, 1, 0, 1, 0, 1])
}
Expand Down
81 changes: 81 additions & 0 deletions tests/it/io/parquet/write.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::io::Cursor;

use arrow2::array::*;
use arrow2::error::Result;
use arrow2::io::parquet::write::*;
use arrow2::offset::Offsets;

use super::*;

Expand Down Expand Up @@ -73,6 +75,45 @@ fn round_trip_opt_stats(
Ok(())
}

/// Writes `chunk` to an in-memory parquet file and checks that every column reads back equal.
///
/// Row groups are built with `RowGroupIterator`, written through `FileWriter` into a
/// `Cursor`, and each schema field is then read back by name with `read_column` and compared
/// against the corresponding input array.
fn round_trip_native(
    schema: Schema,
    chunk: Chunk<Box<dyn Array>>,
    version: Version,
    compression: CompressionOptions,
    encodings: Vec<Vec<Encoding>>,
) -> Result<()> {
    let write_options = WriteOptions {
        version,
        compression,
        write_statistics: true,
        data_pagesize_limit: None,
    };

    // The input stream holds exactly one chunk: a clone of the one under test.
    let chunks = vec![Ok(chunk.clone())].into_iter();
    let row_groups = RowGroupIterator::try_new(chunks, &schema, write_options, encodings)?;

    let mut file_writer = FileWriter::try_new(Cursor::new(vec![]), schema.clone(), write_options)?;
    for row_group in row_groups {
        file_writer.write(row_group?)?;
    }
    file_writer.end(None)?;

    // Unwrap the `FileWriter`, then the `Cursor`, to get at the written bytes.
    let bytes = file_writer.into_inner().into_inner();

    // Every field is read from a fresh cursor over its own copy of the bytes.
    let columns = schema.fields.iter().zip(chunk.arrays().iter());
    for (field, expected) in columns {
        let mut reader = Cursor::new(bytes.clone());
        let (actual, _) = read_column(&mut reader, field.name.as_str())?;
        assert_eq!(expected, actual.as_ref());
    }

    Ok(())
}

#[test]
fn int64_optional_v1() -> Result<()> {
round_trip(
Expand Down Expand Up @@ -585,3 +626,43 @@ fn struct_v2() -> Result<()> {
vec![Encoding::Plain, Encoding::Plain],
)
}

#[test]
fn round_trip_nested_required_struct() -> Result<()> {
    // Leaf values of the list; the offsets below only reach index 8 of these ten values.
    let values = PrimitiveArray::<f64>::from_values(vec![
        2.0, 2.0, 4.0, 4.0, 6.0, 6.0, 8.0, 8.0, 10.0, 10.0,
    ]);
    let item_field = Field::new("inner", values.data_type().clone(), false);

    // Six offsets describe five lists; the second one (2..2) is empty.
    let list_offsets = Offsets::try_from(vec![0, 2, 2, 4, 6, 8]).unwrap();
    let child = ListArray::new(
        DataType::List(Box::new(item_field)),
        list_offsets.into(),
        values.boxed(),
        None,
    );

    // Wrap the list in a struct with a single "child" field and no validity.
    let child_field = Field::new("child", child.data_type().clone(), false);
    let nested = StructArray::new(
        DataType::Struct(vec![child_field]),
        vec![child.clone().boxed()],
        None,
    );

    let nested_field = Field::new("nested", nested.data_type().clone(), false);
    let schema = Schema {
        fields: vec![nested_field],
        metadata: Metadata::default(),
    };

    round_trip_native(
        schema,
        Chunk::new(vec![nested.to_boxed()]),
        Version::V1,
        CompressionOptions::Uncompressed,
        vec![vec![Encoding::Plain]],
    )
}