From 8b9f1e22f94e944d748b1bc2e3e3b21f78f27743 Mon Sep 17 00:00:00 2001
From: "Jorge C. Leitao"
Date: Sun, 13 Mar 2022 12:24:49 +0000
Subject: [PATCH] Added write of Struct

---
 src/io/avro/read/nested.rs     |  1 +
 src/io/avro/read/schema.rs     |  8 ++---
 src/io/avro/write/schema.rs    |  9 ++++-
 src/io/avro/write/serialize.rs | 63 +++++++++++++++++++++++++++++++-
 tests/it/io/avro/read.rs       |  4 +--
 tests/it/io/avro/write.rs      | 66 ++++++++++++++++++++++++++++++++++
 6 files changed, 141 insertions(+), 10 deletions(-)

diff --git a/src/io/avro/read/nested.rs b/src/io/avro/read/nested.rs
index 9a6992959d1..f229c9510fa 100644
--- a/src/io/avro/read/nested.rs
+++ b/src/io/avro/read/nested.rs
@@ -218,6 +218,7 @@ impl DynMutableStructArray {
 
     #[inline]
     fn push_null(&mut self) {
+        self.values.iter_mut().for_each(|x| x.push_null());
         match &mut self.validity {
             Some(validity) => validity.push(false),
             None => self.init_validity(),
diff --git a/src/io/avro/read/schema.rs b/src/io/avro/read/schema.rs
index ff72879d78b..0aa48e4330d 100644
--- a/src/io/avro/read/schema.rs
+++ b/src/io/avro/read/schema.rs
@@ -112,7 +112,7 @@ fn schema_to_field(schema: &AvroSchema, name: Option<&str>, props: Metadata) ->
                 DataType::Union(fields, None, UnionMode::Dense)
             }
         }
-        AvroSchema::Record(Record { name, fields, .. }) => {
+        AvroSchema::Record(Record { fields, .. }) => {
             let fields = fields
                 .iter()
                 .map(|field| {
@@ -120,11 +120,7 @@ fn schema_to_field(schema: &AvroSchema, name: Option<&str>, props: Metadata) ->
                     if let Some(doc) = &field.doc {
                         props.insert("avro::doc".to_string(), doc.clone());
                     }
-                    schema_to_field(
-                        &field.schema,
-                        Some(&format!("{}.{}", name, field.name)),
-                        props,
-                    )
+                    schema_to_field(&field.schema, Some(&field.name), props)
                 })
                 .collect::<Result<_>>()?;
             DataType::Struct(fields)
diff --git a/src/io/avro/write/schema.rs b/src/io/avro/write/schema.rs
index 6db94b84a04..5804a9188f3 100644
--- a/src/io/avro/write/schema.rs
+++ b/src/io/avro/write/schema.rs
@@ -1,5 +1,5 @@
 use avro_schema::{
-    BytesLogical, Field as AvroField, Fixed, FixedLogical, IntLogical, LongLogical,
+    BytesLogical, Field as AvroField, Fixed, FixedLogical, IntLogical, LongLogical, Record,
     Schema as AvroSchema,
 };
 
@@ -39,6 +39,13 @@ fn _type_to_schema(data_type: &DataType) -> Result<AvroSchema> {
         DataType::LargeList(inner) | DataType::List(inner) => AvroSchema::Array(Box::new(
             type_to_schema(&inner.data_type, inner.is_nullable)?,
         )),
+        DataType::Struct(fields) => AvroSchema::Record(Record::new(
+            "",
+            fields
+                .iter()
+                .map(field_to_field)
+                .collect::<Result<Vec<_>>>()?,
+        )),
         DataType::Date32 => AvroSchema::Int(Some(IntLogical::Date)),
         DataType::Time32(TimeUnit::Millisecond) => AvroSchema::Int(Some(IntLogical::Time)),
         DataType::Time64(TimeUnit::Microsecond) => AvroSchema::Long(Some(LongLogical::Time)),
diff --git a/src/io/avro/write/serialize.rs b/src/io/avro/write/serialize.rs
index 71ae79bfe4e..e6c18fbaf8e 100644
--- a/src/io/avro/write/serialize.rs
+++ b/src/io/avro/write/serialize.rs
@@ -1,4 +1,4 @@
-use avro_schema::Schema as AvroSchema;
+use avro_schema::{Record, Schema as AvroSchema};
 
 use crate::bitmap::utils::zip_validity;
 use crate::datatypes::{IntervalUnit, PhysicalType, PrimitiveType};
@@ -116,6 +116,56 @@ fn list_optional<'a, O: Offset>(array: &'a ListArray<O>, schema: &AvroSchema) ->
     ))
 }
 
+fn struct_required<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> {
+    let schemas = schema.fields.iter().map(|x| &x.schema);
+    let mut inner = array
+        .values()
+        .iter()
+        .zip(schemas)
+        .map(|(x, schema)| new_serializer(x.as_ref(), schema))
+        .collect::<Vec<_>>();
+
+    Box::new(BufStreamingIterator::new(
+        0..array.len(),
+        move |_, buf| {
+            inner
+                .iter_mut()
+                .for_each(|item| buf.extend_from_slice(item.next().unwrap()))
+        },
+        vec![],
+    ))
+}
+
+fn struct_optional<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> {
+    let schemas = schema.fields.iter().map(|x| &x.schema);
+    let mut inner = array
+        .values()
+        .iter()
+        .zip(schemas)
+        .map(|(x, schema)| new_serializer(x.as_ref(), schema))
+        .collect::<Vec<_>>();
+
+    let iterator = zip_validity(0..array.len(), array.validity().as_ref().map(|x| x.iter()));
+
+    Box::new(BufStreamingIterator::new(
+        iterator,
+        move |maybe, buf| {
+            util::zigzag_encode(maybe.is_some() as i64, buf).unwrap();
+            if maybe.is_some() {
+                inner
+                    .iter_mut()
+                    .for_each(|item| buf.extend_from_slice(item.next().unwrap()))
+            } else {
+                // skip the item
+                inner.iter_mut().for_each(|item| {
+                    let _ = item.next().unwrap();
+                });
+            }
+        },
+        vec![],
+    ))
+}
+
 /// Creates a [`StreamingIterator`] trait object that presents items from `array`
 /// encoded according to `schema`.
 /// # Panic
@@ -375,6 +425,17 @@ pub fn new_serializer<'a>(array: &'a dyn Array, schema: &AvroSchema) -> BoxSeria
             };
             list_optional::<i64>(array.as_any().downcast_ref().unwrap(), schema)
         }
+        (PhysicalType::Struct, AvroSchema::Record(inner)) => {
+            struct_required(array.as_any().downcast_ref().unwrap(), inner)
+        }
+        (PhysicalType::Struct, AvroSchema::Union(inner)) => {
+            let inner = if let AvroSchema::Record(inner) = &inner[1] {
+                inner
+            } else {
+                unreachable!("The schema declaration does not match the deserialization")
+            };
+            struct_optional(array.as_any().downcast_ref().unwrap(), inner)
+        }
         (a, b) => todo!("{:?} -> {:?} not supported", a, b),
     }
 }
diff --git a/tests/it/io/avro/read.rs b/tests/it/io/avro/read.rs
index 5efd42518b6..1c5fc567480 100644
--- a/tests/it/io/avro/read.rs
+++ b/tests/it/io/avro/read.rs
@@ -69,7 +69,7 @@ pub(super) fn schema() -> (AvroSchema, Schema) {
         ),
         Field::new(
             "i",
-            DataType::Struct(vec![Field::new("bla.e", DataType::Float64, false)]),
+            DataType::Struct(vec![Field::new("e", DataType::Float64, false)]),
             false,
         ),
         Field::new(
@@ -103,7 +103,7 @@ pub(super) fn data() -> Chunk<Arc<dyn Array>> {
         Arc::new(Utf8Array::<i32>::from([Some("foo"), None])),
         array.into_arc(),
         Arc::new(StructArray::from_data(
-            DataType::Struct(vec![Field::new("bla.e", DataType::Float64, false)]),
+            DataType::Struct(vec![Field::new("e", DataType::Float64, false)]),
             vec![Arc::new(PrimitiveArray::<f64>::from_slice([1.0, 2.0]))],
             None,
         )),
diff --git a/tests/it/io/avro/write.rs b/tests/it/io/avro/write.rs
index 833dfc5ef59..8bf3c2f60b6 100644
--- a/tests/it/io/avro/write.rs
+++ b/tests/it/io/avro/write.rs
@@ -230,3 +230,69 @@ fn check_large_format() -> Result<()> {
 
     Ok(())
 }
+
+fn struct_schema() -> Schema {
+    Schema::from(vec![
+        Field::new(
+            "struct",
+            DataType::Struct(vec![
+                Field::new("item1", DataType::Int32, false),
+                Field::new("item2", DataType::Int32, true),
+            ]),
+            false,
+        ),
+        Field::new(
+            "struct nullable",
+            DataType::Struct(vec![
+                Field::new("item1", DataType::Int32, false),
+                Field::new("item2", DataType::Int32, true),
+            ]),
+            true,
+        ),
+    ])
+}
+
+fn struct_data() -> Chunk<Box<dyn Array>> {
+    let struct_dt = DataType::Struct(vec![
+        Field::new("item1", DataType::Int32, false),
+        Field::new("item2", DataType::Int32, true),
+    ]);
+
+    Chunk::new(vec![
+        Box::new(StructArray::new(
+            struct_dt.clone(),
+            vec![
+                Arc::new(PrimitiveArray::<i32>::from_slice([1, 2])),
+                Arc::new(PrimitiveArray::<i32>::from([None, Some(1)])),
+            ],
+            None,
+        )),
+        Box::new(StructArray::new(
+            struct_dt,
+            vec![
+                Arc::new(PrimitiveArray::<i32>::from_slice([1, 2])),
+                Arc::new(PrimitiveArray::<i32>::from([None, Some(1)])),
+            ],
+            Some([true, false].into()),
+        )),
+    ])
+}
+
+#[test]
+fn struct_() -> Result<()> {
+    let write_schema = struct_schema();
+    let write_data = struct_data();
+
+    let data = write_avro(&write_data, &write_schema, None)?;
+    let (result, read_schema) = read_avro(&data, None)?;
+
+    let expected_schema = struct_schema();
+    assert_eq!(read_schema, expected_schema);
+
+    let expected_data = struct_data();
+    for (c1, c2) in result.columns().iter().zip(expected_data.columns().iter()) {
+        assert_eq!(c1.as_ref(), c2.as_ref());
+    }
+
+    Ok(())
+}
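
Note on the encoding implemented by `struct_optional` in this patch: a nullable struct is declared as an Avro union ["null", record], so each row starts with the zigzag-encoded index of the union branch (0 for null, 1 for the record), followed by the record's fields encoded back to back; on null rows the child serializers are still advanced so they stay aligned with the parent. The sketch below only illustrates that byte layout; `zigzag_encode` and `encode_row` are local stand-ins written for this note, not arrow2 APIs.

// Illustrative sketch of the wire layout produced by `struct_optional`.
// `zigzag_encode` mirrors the behaviour of arrow2's `util::zigzag_encode`.
fn zigzag_encode(n: i64, buf: &mut Vec<u8>) {
    let mut z = ((n << 1) ^ (n >> 63)) as u64; // zigzag, then varint
    loop {
        let mut b = (z & 0x7f) as u8;
        z >>= 7;
        if z != 0 {
            b |= 0x80; // continuation bit
        }
        buf.push(b);
        if z == 0 {
            break;
        }
    }
}

// Hypothetical helper: writes one row of a two-field nullable struct whose
// fields have already been encoded by their child serializers.
fn encode_row(fields: Option<(&[u8], &[u8])>, buf: &mut Vec<u8>) {
    zigzag_encode(fields.is_some() as i64, buf); // branch index: 0 = null, 1 = record
    if let Some((a, b)) = fields {
        buf.extend_from_slice(a); // a record body is just its fields, in order
        buf.extend_from_slice(b);
    }
}

fn main() {
    let mut buf = vec![];
    encode_row(Some((&[2], &[4])), &mut buf); // valid row: branch 1 + both fields
    encode_row(None, &mut buf); // null row: only the branch index is written
    assert_eq!(buf, [2, 2, 4, 0]); // zigzag(1) = 2, field bytes, zigzag(0) = 0
}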