diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs
index 9ea425983eb..25f01837ee3 100644
--- a/arrow-flight/src/utils.rs
+++ b/arrow-flight/src/utils.rs
@@ -26,6 +26,7 @@ use arrow2::{
     datatypes::*,
     error::{ArrowError, Result},
     io::ipc,
+    io::ipc::gen::Schema::MetadataVersion,
     io::ipc::read::read_record_batch,
     io::ipc::write,
     io::ipc::write::common::{encoded_batch, DictionaryTracker, EncodedData, IpcWriteOptions},
@@ -168,6 +169,7 @@ pub fn flight_data_to_arrow_batch(
         None,
         is_little_endian,
         &dictionaries_by_field,
+        MetadataVersion::V5,
         &mut reader,
         0,
     )
diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs
index 2c378cb938a..bfbebf663fe 100644
--- a/integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use pretty_assertions::assert_eq;
 use std::fs::File;
 
 use clap::{App, Arg};
diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs
index dc80de6ba86..de6951fee43 100644
--- a/integration-testing/src/flight_server_scenarios/integration_test.rs
+++ b/integration-testing/src/flight_server_scenarios/integration_test.rs
@@ -25,6 +25,7 @@ use arrow2::{
     datatypes::*,
     io::ipc,
     io::ipc::gen::Message::{Message, MessageHeader},
+    io::ipc::gen::Schema::MetadataVersion,
     record_batch::RecordBatch,
 };
 use arrow_flight::flight_descriptor::*;
@@ -295,6 +296,7 @@ async fn record_batch_from_message(
         None,
         true,
         &dictionaries_by_field,
+        MetadataVersion::V5,
         &mut reader,
         0,
     );
diff --git a/src/array/equal/mod.rs b/src/array/equal/mod.rs
index 45062e27c4f..3d3433bb860 100644
--- a/src/array/equal/mod.rs
+++ b/src/array/equal/mod.rs
@@ -1,5 +1,3 @@
-use std::unimplemented;
-
 use crate::{
     datatypes::{DataType, IntervalUnit},
     types::{days_ms, NativeType},
@@ -19,6 +17,7 @@ mod list;
 mod null;
 mod primitive;
 mod struct_;
+mod union;
 mod utf8;
 
 impl PartialEq for dyn Array {
@@ -323,7 +322,11 @@ pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool {
             let rhs = rhs.as_any().downcast_ref().unwrap();
             fixed_size_list::equal(lhs, rhs)
         }
-        DataType::Union(_) => unimplemented!(),
+        DataType::Union(_, _, _) => {
+            let lhs = lhs.as_any().downcast_ref().unwrap();
+            let rhs = rhs.as_any().downcast_ref().unwrap();
+            union::equal(lhs, rhs)
+        }
     }
 }
diff --git a/src/array/equal/union.rs b/src/array/equal/union.rs
new file mode 100644
index 00000000000..51b9d960fea
--- /dev/null
+++ b/src/array/equal/union.rs
@@ -0,0 +1,5 @@
+use crate::array::{Array, UnionArray};
+
+pub(super) fn equal(lhs: &UnionArray, rhs: &UnionArray) -> bool {
+    lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter())
+}
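The new `union::equal` is the first consumer of the scalar iterator introduced later in this diff: two unions are equal only when their data types, lengths, and per-slot scalars all agree. A minimal restatement of that contract, usable against two already-built arrays:

```rust
use arrow2::array::{Array, UnionArray};

/// Sketch: mirrors `equal` above — data type (fields, ids, sparseness),
/// length, and every per-slot scalar must match.
fn unions_equal(lhs: &UnionArray, rhs: &UnionArray) -> bool {
    lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter())
}
```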
diff --git a/src/array/ffi.rs b/src/array/ffi.rs
index e707f98d6b0..5ce19cb86ff 100644
--- a/src/array/ffi.rs
+++ b/src/array/ffi.rs
@@ -85,7 +85,7 @@ pub fn buffers_children_dictionary(array: &dyn Array) -> BuffersChildren {
        DataType::LargeList(_) => ffi_dyn!(array, ListArray::<i64>),
        DataType::FixedSizeList(_, _) => ffi_dyn!(array, FixedSizeListArray),
        DataType::Struct(_) => ffi_dyn!(array, StructArray),
-       DataType::Union(_) => unimplemented!(),
+       DataType::Union(_, _, _) => unimplemented!(),
        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
            DataType::Int8 => ffi_dict_dyn!(array, DictionaryArray::<i8>),
            DataType::Int16 => ffi_dict_dyn!(array, DictionaryArray::<i16>),
diff --git a/src/array/growable/mod.rs b/src/array/growable/mod.rs
index db54a6e12e4..dc44a3ef3e5 100644
--- a/src/array/growable/mod.rs
+++ b/src/array/growable/mod.rs
@@ -225,7 +225,7 @@ pub fn make_growable<'a>(
            ))
        }
        DataType::FixedSizeList(_, _) => todo!(),
-       DataType::Union(_) => todo!(),
+       DataType::Union(_, _, _) => todo!(),
        DataType::Dictionary(key, _) => match key.as_ref() {
            DataType::UInt8 => dyn_dict_growable!(u8, arrays, use_validity, capacity),
            DataType::UInt16 => dyn_dict_growable!(u16, arrays, use_validity, capacity),
diff --git a/src/array/mod.rs b/src/array/mod.rs
index 3505a6204a6..db1f644b7f3 100644
--- a/src/array/mod.rs
+++ b/src/array/mod.rs
@@ -56,6 +56,9 @@ pub trait Array: std::fmt::Debug + Send + Sync {
     /// This is `O(1)`.
     #[inline]
     fn null_count(&self) -> usize {
+        if self.data_type() == &DataType::Null {
+            return self.len();
+        };
         self.validity()
             .as_ref()
             .map(|x| x.null_count())
@@ -185,7 +188,7 @@ impl Display for dyn Array {
            DataType::LargeList(_) => fmt_dyn!(self, ListArray::<i64>, f),
            DataType::FixedSizeList(_, _) => fmt_dyn!(self, FixedSizeListArray, f),
            DataType::Struct(_) => fmt_dyn!(self, StructArray, f),
-           DataType::Union(_) => unimplemented!(),
+           DataType::Union(_, _, _) => unimplemented!(),
            DataType::Dictionary(key_type, _) => match key_type.as_ref() {
                DataType::Int8 => fmt_dyn!(self, DictionaryArray::<i8>, f),
                DataType::Int16 => fmt_dyn!(self, DictionaryArray::<i16>, f),
@@ -239,7 +242,7 @@ pub fn new_empty_array(data_type: DataType) -> Box<dyn Array> {
        DataType::LargeList(_) => Box::new(ListArray::<i64>::new_empty(data_type)),
        DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_empty(data_type)),
        DataType::Struct(fields) => Box::new(StructArray::new_empty(&fields)),
-       DataType::Union(_) => unimplemented!(),
+       DataType::Union(_, _, _) => unimplemented!(),
        DataType::Dictionary(key_type, value_type) => match key_type.as_ref() {
            DataType::Int8 => Box::new(DictionaryArray::<i8>::new_empty(*value_type)),
            DataType::Int16 => Box::new(DictionaryArray::<i16>::new_empty(*value_type)),
@@ -293,7 +296,7 @@ pub fn new_null_array(data_type: DataType, length: usize) -> Box<dyn Array> {
        DataType::LargeList(_) => Box::new(ListArray::<i64>::new_null(data_type, length)),
        DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_null(data_type, length)),
        DataType::Struct(fields) => Box::new(StructArray::new_null(&fields, length)),
-       DataType::Union(_) => unimplemented!(),
+       DataType::Union(_, _, _) => unimplemented!(),
        DataType::Dictionary(key_type, value_type) => match key_type.as_ref() {
            DataType::Int8 => Box::new(DictionaryArray::<i8>::new_null(*value_type, length)),
            DataType::Int16 => Box::new(DictionaryArray::<i16>::new_null(*value_type, length)),
@@ -354,7 +357,7 @@ pub fn clone(array: &dyn Array) -> Box<dyn Array> {
        DataType::LargeList(_) => clone_dyn!(array, ListArray::<i64>),
        DataType::FixedSizeList(_, _) => clone_dyn!(array, FixedSizeListArray),
        DataType::Struct(_) => clone_dyn!(array, StructArray),
-       DataType::Union(_) => unimplemented!(),
+       DataType::Union(_, _, _) => unimplemented!(),
        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
            DataType::Int8 => clone_dyn!(array, DictionaryArray::<i8>),
            DataType::Int16 => clone_dyn!(array, DictionaryArray::<i16>),
@@ -380,6 +383,7 @@ mod null;
 mod primitive;
 mod specification;
 mod struct_;
+mod union;
 mod utf8;
 
 mod equal;
@@ -399,6 +403,7 @@ pub use null::NullArray;
 pub use primitive::*;
 pub use specification::Offset;
 pub use struct_::StructArray;
+pub use union::UnionArray;
 pub use utf8::{MutableUtf8Array, Utf8Array, Utf8ValuesIter};
 
 pub(crate) use self::ffi::buffers_children_dictionary;
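The `null_count` change above matters because `NullArray` stores no validity bitmap, so the old fallback reported zero nulls for it. A quick check of the fixed behavior:

```rust
use arrow2::array::new_null_array;
use arrow2::datatypes::DataType;

fn main() {
    // A NullArray carries no validity bitmap; the early return on
    // DataType::Null now counts every slot as null instead of 0.
    let array = new_null_array(DataType::Null, 10);
    assert_eq!(array.null_count(), 10);
}
```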
diff --git a/src/array/union/iterator.rs b/src/array/union/iterator.rs
new file mode 100644
index 00000000000..7a859561c14
--- /dev/null
+++ b/src/array/union/iterator.rs
@@ -0,0 +1,55 @@
+use super::{Array, UnionArray};
+use crate::{scalar::Scalar, trusted_len::TrustedLen};
+
+#[derive(Debug, Clone)]
+pub struct UnionIter<'a> {
+    array: &'a UnionArray,
+    current: usize,
+}
+
+impl<'a> UnionIter<'a> {
+    pub fn new(array: &'a UnionArray) -> Self {
+        Self { array, current: 0 }
+    }
+}
+
+impl<'a> Iterator for UnionIter<'a> {
+    type Item = Box<dyn Scalar>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current == self.array.len() {
+            None
+        } else {
+            let old = self.current;
+            self.current += 1;
+            Some(self.array.value(old))
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.array.len() - self.current;
+        (len, Some(len))
+    }
+}
+
+impl<'a> IntoIterator for &'a UnionArray {
+    type Item = Box<dyn Scalar>;
+    type IntoIter = UnionIter<'a>;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
+
+impl<'a> UnionArray {
+    /// Constructs a new iterator over the scalars of this [`UnionArray`].
+    #[inline]
+    pub fn iter(&'a self) -> UnionIter<'a> {
+        UnionIter::new(self)
+    }
+}
+
+impl<'a> std::iter::ExactSizeIterator for UnionIter<'a> {}
+
+unsafe impl<'a> TrustedLen for UnionIter<'a> {}
diff --git a/src/array/union/mod.rs b/src/array/union/mod.rs
new file mode 100644
index 00000000000..c6f0660a31e
--- /dev/null
+++ b/src/array/union/mod.rs
@@ -0,0 +1,148 @@
+use std::{collections::HashMap, sync::Arc};
+
+use crate::{
+    bitmap::Bitmap,
+    buffer::Buffer,
+    datatypes::{DataType, Field},
+    scalar::{new_scalar, Scalar},
+};
+
+use super::Array;
+
+mod iterator;
+
+/// A union array.
+// How to read a value at slot i:
+// ```
+// let index = self.types()[i] as usize;
+// let field = self.fields()[index];
+// let offset = self.offsets().map(|x| x[index]).unwrap_or(i);
+// let field = field.as_any().downcast to correct type;
+// let value = field.value(offset);
+// ```
+#[derive(Debug, Clone)]
+pub struct UnionArray {
+    types: Buffer<i8>,
+    fields_hash: HashMap<i8, Arc<dyn Array>>,
+    fields: Vec<Arc<dyn Array>>,
+    offsets: Option<Buffer<i32>>,
+    data_type: DataType,
+    offset: usize,
+}
+
+impl UnionArray {
+    pub fn from_data(
+        data_type: DataType,
+        types: Buffer<i8>,
+        fields: Vec<Arc<dyn Array>>,
+        offsets: Option<Buffer<i32>>,
+    ) -> Self {
+        let fields_hash = if let DataType::Union(f, ids, is_sparse) = &data_type {
+            let ids: Vec<i8> = ids
+                .as_ref()
+                .map(|x| x.iter().map(|x| *x as i8).collect())
+                .unwrap_or_else(|| (0..f.len() as i8).collect());
+            if f.len() != fields.len() {
+                panic!(
+                    "The number of `fields` must equal the number of fields in the Union DataType"
+                )
+            };
+            let same_data_types = f
+                .iter()
+                .zip(fields.iter())
+                .all(|(f, array)| f.data_type() == array.data_type());
+            if !same_data_types {
+                panic!("The data types of `fields` must equal the data types in the Union DataType")
+            }
+            if offsets.is_none() != *is_sparse {
+                panic!("The sparseness of a UnionArray must match the absence of offsets")
+            }
+            ids.into_iter().zip(fields.iter().cloned()).collect()
+        } else {
+            panic!("A UnionArray must be created with the corresponding Union DataType")
+        };
+        // not validated:
+        // * `offsets` is valid
+        // * max id < fields.len()
+        Self {
+            data_type,
+            fields_hash,
+            fields,
+            offsets,
+            types,
+            offset: 0,
+        }
+    }
+
+    pub fn offsets(&self) -> &Option<Buffer<i32>> {
+        &self.offsets
+    }
+
+    pub fn fields(&self) -> &Vec<Arc<dyn Array>> {
+        &self.fields
+    }
+
+    pub fn types(&self) -> &Buffer<i8> {
+        &self.types
+    }
+
+    pub fn value(&self, index: usize) -> Box<dyn Scalar> {
+        let field_index = self.types()[index];
+        let field = self.fields_hash[&field_index].as_ref();
+        let offset = self
+            .offsets()
+            .as_ref()
+            .map(|x| x[index] as usize)
+            .unwrap_or(index);
+        new_scalar(field, offset)
+    }
+
+    /// Returns a slice of this [`UnionArray`].
+    /// # Implementation
+    /// This operation is `O(F)` where `F` is the number of fields.
+    /// # Panics
+    /// This function panics iff `offset + length > self.len()`.
+    #[inline]
+    pub fn slice(&self, offset: usize, length: usize) -> Self {
+        Self {
+            data_type: self.data_type.clone(),
+            fields: self.fields.clone(),
+            fields_hash: self.fields_hash.clone(),
+            types: self.types.clone().slice(offset, length),
+            offsets: self.offsets.clone(),
+            offset: self.offset + offset,
+        }
+    }
+}
+
+impl Array for UnionArray {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn len(&self) -> usize {
+        self.types.len()
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn validity(&self) -> &Option<Bitmap> {
+        &None
+    }
+
+    fn slice(&self, offset: usize, length: usize) -> Box<dyn Array> {
+        Box::new(self.slice(offset, length))
+    }
+}
+
+impl UnionArray {
+    pub fn get_fields(data_type: &DataType) -> &[Field] {
+        if let DataType::Union(fields, _, _) = data_type {
+            fields
+        } else {
+            panic!("Wrong DataType passed to UnionArray::get_fields")
+        }
+    }
+}
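The constructor's invariants (field count, matching child data types, offsets iff dense) are easiest to see in use. A minimal sketch of building and reading a sparse union — the `from_slice`/`Buffer::from` helpers are assumed from the crate's existing API and may differ slightly in this revision:

```rust
use std::sync::Arc;

use arrow2::array::{Array, Int32Array, UnionArray, Utf8Array};
use arrow2::buffer::Buffer;
use arrow2::datatypes::{DataType, Field};

fn main() {
    // In a sparse union, every child has the same length as the union itself.
    let ints = Int32Array::from_slice(&[1, 2, 3]);
    let strings = Utf8Array::<i32>::from_slice(&["a", "b", "c"]);

    let fields = vec![
        Field::new("int", DataType::Int32, true),
        Field::new("string", DataType::Utf8, true),
    ];
    // `None` ids default to 0..fields.len(); `true` marks the union as sparse.
    let data_type = DataType::Union(fields, None, true);

    // Slot i holds a value of the child whose id equals types[i].
    let types = Buffer::from(vec![0i8, 1, 0]);

    let array = UnionArray::from_data(
        data_type,
        types,
        vec![
            Arc::new(ints) as Arc<dyn Array>,
            Arc::new(strings) as Arc<dyn Array>,
        ],
        None, // sparse unions carry no offsets; `Some` here would panic
    );

    assert_eq!(array.len(), 3);
    println!("{:?}", array.value(1)); // the Utf8 scalar "b"
}
```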
diff --git a/src/compute/aggregate/memory.rs b/src/compute/aggregate/memory.rs
index 7b1eb239d00..cbace0d8f7d 100644
--- a/src/compute/aggregate/memory.rs
+++ b/src/compute/aggregate/memory.rs
@@ -109,7 +109,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize {
                .sum::<usize>()
                + validity_size(array.validity())
        }
-       Union(_) => unreachable!(),
+       Union(_, _, _) => unreachable!(),
        Dictionary(keys, _) => match keys.as_ref() {
            Int8 => dyn_dict!(array, i8),
            Int16 => dyn_dict!(array, i16),
diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs
index 205a9a041dc..141bb1e189b 100644
--- a/src/datatypes/field.rs
+++ b/src/datatypes/field.rs
@@ -200,8 +200,8 @@ impl Field {
                    ));
                }
            },
-           DataType::Union(nested_fields) => match &from.data_type {
-               DataType::Union(from_nested_fields) => {
+           DataType::Union(nested_fields, _, _) => match &from.data_type {
+               DataType::Union(from_nested_fields, _, _) => {
                    for from_field in from_nested_fields {
                        let mut is_new_field = true;
                        for self_field in nested_fields.iter_mut() {
diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs
index 6603ed80331..5b621982cf0 100644
--- a/src/datatypes/mod.rs
+++ b/src/datatypes/mod.rs
@@ -103,7 +103,8 @@ pub enum DataType {
    /// A nested datatype that contains a number of sub-fields.
    Struct(Vec<Field>),
    /// A nested datatype that can represent slots of differing types.
-   Union(Vec<Field>),
+   /// The second argument holds the optional explicit type ids; the third indicates whether the union is sparse.
+   Union(Vec<Field>, Option<Vec<i32>>, bool),
    /// A dictionary encoded array (`key_type`, `value_type`), where
    /// each array element is an index of `key_type` into an
    /// associated dictionary of `value_type`.
diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs
index b9edf4d544d..c9d26db697a 100644
--- a/src/ffi/schema.rs
+++ b/src/ffi/schema.rs
@@ -316,7 +316,7 @@ fn to_format(data_type: &DataType) -> Result<String> {
        DataType::Struct(_) => "+s",
        DataType::FixedSizeBinary(size) => return Ok(format!("w{}", size)),
        DataType::FixedSizeList(_, size) => return Ok(format!("+w:{}", size)),
-       DataType::Union(_) => todo!(),
+       DataType::Union(_, _, _) => todo!(),
        DataType::Dictionary(index, _) => return to_format(index.as_ref()),
        _ => todo!(),
    }
diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs
index c5452db242d..07d94c5b81f 100644
--- a/src/io/ipc/convert.rs
+++ b/src/io/ipc/convert.rs
@@ -19,6 +19,7 @@
 use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
 use crate::endianess::is_native_little_endian;
+use crate::io::ipc::convert::ipc::UnionMode;
 
 mod ipc {
     pub use super::super::gen::File::*;
@@ -276,6 +277,22 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataType {
            let fsb = field.type_as_decimal().unwrap();
            DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize)
        }
+       ipc::Type::Union => {
+           let type_ = field.type_as_union().unwrap();
+
+           let is_sparse = type_.mode() == UnionMode::Sparse;
+
+           let ids = type_.typeIds().map(|x| x.iter().collect());
+
+           let fields = if let Some(children) = field.children() {
+               (0..children.len())
+                   .map(|i| children.get(i).into())
+                   .collect()
+           } else {
+               vec![]
+           };
+           DataType::Union(fields, ids, is_sparse)
+       }
        t => unimplemented!("Type {:?} not supported", t),
    }
 }
@@ -604,7 +621,27 @@ pub(crate) fn get_fb_field_type<'a>(
                children: Some(fbb.create_vector(&empty_fields[..])),
            }
        }
-       t => unimplemented!("Type {:?} not supported", t),
+       Union(fields, ids, is_sparse) => {
+           let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect();
+
+           let ids = ids.as_ref().map(|ids| fbb.create_vector(ids));
+
+           let mut builder = ipc::UnionBuilder::new(fbb);
+           builder.add_mode(if *is_sparse {
+               UnionMode::Sparse
+           } else {
+               UnionMode::Dense
+           });
+
+           if let Some(ids) = ids {
+               builder.add_typeIds(ids);
+           }
+           FbFieldType {
+               type_type: ipc::Type::Union,
+               type_: builder.finish().as_union_value(),
+               children: Some(fbb.create_vector(&children)),
+           }
+       }
    }
 }
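The flatbuffers mode and ids map one-to-one onto the new `DataType` variant (`UnionMode::Sparse` ↔ `is_sparse == true`, `typeIds` ↔ `Some(ids)`), so a dense union with explicit ids round-trips as:

```rust
use arrow2::datatypes::{DataType, Field};

fn main() {
    let fields = vec![
        Field::new("int", DataType::Int32, true),
        Field::new("string", DataType::Utf8, true),
    ];
    // Slots whose type id is 5 come from "int", 7 from "string".
    let dense = DataType::Union(fields, Some(vec![5, 7]), false);
    assert!(matches!(dense, DataType::Union(_, Some(_), false)));
}
```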
diff --git a/src/io/ipc/read/array/fixed_size_list.rs b/src/io/ipc/read/array/fixed_size_list.rs
index a7416cd9ca3..8fb9b45cbcd 100644
--- a/src/io/ipc/read/array/fixed_size_list.rs
+++ b/src/io/ipc/read/array/fixed_size_list.rs
@@ -1,6 +1,8 @@
 use std::collections::VecDeque;
 use std::io::{Read, Seek};
 
+use gen::Schema::MetadataVersion;
+
 use crate::array::FixedSizeListArray;
 use crate::datatypes::DataType;
 use crate::error::Result;
@@ -18,6 +20,7 @@ pub fn read_fixed_size_list<R: Read + Seek>(
     block_offset: u64,
     is_little_endian: bool,
     compression: Option<BodyCompression>,
+    version: MetadataVersion,
 ) -> Result<FixedSizeListArray> {
     let field_node = field_nodes.pop_front().unwrap().0;
 
@@ -40,6 +43,7 @@ pub fn read_fixed_size_list<R: Read + Seek>(
         block_offset,
         is_little_endian,
         compression,
+        version,
     )?;
     Ok(FixedSizeListArray::from_data(data_type, values, validity))
 }
diff --git a/src/io/ipc/read/array/list.rs b/src/io/ipc/read/array/list.rs
index a876576fa27..61ff8b8612b 100644
--- a/src/io/ipc/read/array/list.rs
+++ b/src/io/ipc/read/array/list.rs
@@ -2,6 +2,8 @@ use std::collections::VecDeque;
 use std::convert::TryInto;
 use std::io::{Read, Seek};
 
+use gen::Schema::MetadataVersion;
+
 use crate::array::{ListArray, Offset};
 use crate::buffer::Buffer;
 use crate::datatypes::DataType;
@@ -20,6 +22,7 @@ pub fn read_list<O: Offset, R: Read + Seek>(
     block_offset: u64,
     is_little_endian: bool,
     compression: Option<BodyCompression>,
+    version: MetadataVersion,
 ) -> Result<ListArray<O>>
 where
     Vec<u8>: TryInto<O::Bytes>,
@@ -56,6 +59,7 @@ where
         block_offset,
         is_little_endian,
         compression,
+        version,
     )?;
     Ok(ListArray::from_data(data_type, offsets, values, validity))
 }
diff --git a/src/io/ipc/read/array/mod.rs b/src/io/ipc/read/array/mod.rs
index 458c62123fb..0dd2610510e 100644
--- a/src/io/ipc/read/array/mod.rs
+++ b/src/io/ipc/read/array/mod.rs
@@ -18,3 +18,5 @@ mod null;
 pub use null::*;
 mod dictionary;
 pub use dictionary::*;
+mod union;
+pub use union::*;
diff --git a/src/io/ipc/read/array/struct_.rs b/src/io/ipc/read/array/struct_.rs
index 95274459731..c259849c37a 100644
--- a/src/io/ipc/read/array/struct_.rs
+++ b/src/io/ipc/read/array/struct_.rs
@@ -1,6 +1,8 @@
 use std::collections::VecDeque;
 use std::io::{Read, Seek};
 
+use gen::Schema::MetadataVersion;
+
 use crate::array::StructArray;
 use crate::datatypes::DataType;
 use crate::error::Result;
@@ -18,6 +20,7 @@ pub fn read_struct<R: Read + Seek>(
     block_offset: u64,
     is_little_endian: bool,
     compression: Option<BodyCompression>,
+    version: MetadataVersion,
 ) -> Result<StructArray> {
     let field_node = field_nodes.pop_front().unwrap().0;
 
@@ -43,6 +46,7 @@ pub fn read_struct<R: Read + Seek>(
                 block_offset,
                 is_little_endian,
                 compression,
+                version,
             )
         })
         .collect::<Result<Vec<_>>>()?;
diff --git a/src/io/ipc/read/array/union.rs b/src/io/ipc/read/array/union.rs
new file mode 100644
index 00000000000..adaac0f13cd
--- /dev/null
+++ b/src/io/ipc/read/array/union.rs
@@ -0,0 +1,99 @@
+use std::collections::VecDeque;
+use std::io::{Read, Seek};
+
+use gen::Schema::MetadataVersion;
+
+use crate::array::UnionArray;
+use crate::datatypes::DataType;
+use crate::error::Result;
+use crate::io::ipc::gen::Message::BodyCompression;
+
+use super::super::super::gen;
+use super::super::deserialize::{read, skip, Node};
+use super::super::read_basic::*;
+
+pub fn read_union<R: Read + Seek>(
+    field_nodes: &mut VecDeque<Node>,
+    data_type: DataType,
+    buffers: &mut VecDeque<&gen::Schema::Buffer>,
+    reader: &mut R,
+    block_offset: u64,
+    is_little_endian: bool,
+    compression: Option<BodyCompression>,
+    version: MetadataVersion,
+) -> Result<UnionArray> {
+    let field_node = field_nodes.pop_front().unwrap().0;
+
+    if version != MetadataVersion::V5 {
+        let _ = buffers.pop_front().unwrap();
+    };
+
+    let types = read_buffer(
+        buffers,
+        field_node.length() as usize,
+        reader,
+        block_offset,
+        is_little_endian,
+        compression,
+    )?;
+
+    let offsets = if let DataType::Union(_, _, is_sparse) = data_type {
+        if !is_sparse {
+            Some(read_buffer(
+                buffers,
+                field_node.length() as usize,
+                reader,
+                block_offset,
+                is_little_endian,
+                compression,
+            )?)
+        } else {
+            None
+        }
+    } else {
+        panic!()
+    };
+
+    let fields = UnionArray::get_fields(&data_type);
+
+    let fields = fields
+        .iter()
+        .map(|field| {
+            read(
+                field_nodes,
+                field.data_type().clone(),
+                buffers,
+                reader,
+                block_offset,
+                is_little_endian,
+                compression,
+                version,
+            )
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok(UnionArray::from_data(data_type, types, fields, offsets))
+}
+
+pub fn skip_union(
+    field_nodes: &mut VecDeque<Node>,
+    data_type: &DataType,
+    buffers: &mut VecDeque<&gen::Schema::Buffer>,
+) {
+    let _ = field_nodes.pop_front().unwrap();
+
+    let _ = buffers.pop_front().unwrap();
+    if let DataType::Union(_, _, is_sparse) = data_type {
+        if !*is_sparse {
+            let _ = buffers.pop_front().unwrap();
+        }
+    } else {
+        panic!()
+    };
+
+    let fields = UnionArray::get_fields(data_type);
+
+    fields
+        .iter()
+        .for_each(|field| skip(field_nodes, field.data_type(), buffers))
+}
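One subtlety: IPC metadata before V5 serialized a validity bitmap for unions, which V5 removed, so `read_union` discards one extra buffer on older files. The per-node buffer accounting, as a standalone sketch:

```rust
use arrow2::io::ipc::gen::Schema::MetadataVersion;

/// Buffers consumed per union field node, in order: a discarded validity
/// buffer (pre-V5 only), the `i8` types buffer, and, for dense unions
/// only, the `i32` offsets buffer.
fn union_buffer_count(version: MetadataVersion, is_sparse: bool) -> usize {
    (version != MetadataVersion::V5) as usize + 1 + (!is_sparse) as usize
}
```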
diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs
index 0cc9fc468ea..e0524c558cd 100644
--- a/src/io/ipc/read/common.rs
+++ b/src/io/ipc/read/common.rs
@@ -19,6 +19,8 @@ use std::collections::{HashMap, VecDeque};
 use std::io::{Read, Seek};
 use std::sync::Arc;
 
+use gen::Schema::MetadataVersion;
+
 use crate::array::*;
 use crate::datatypes::{DataType, Field, Schema};
 use crate::error::{ArrowError, Result};
@@ -96,6 +98,7 @@ pub fn read_record_batch<R: Read + Seek>(
     projection: Option<(&[usize], Arc<Schema>)>,
     is_little_endian: bool,
     dictionaries: &[Option<ArrayRef>],
+    version: MetadataVersion,
     reader: &mut R,
     block_offset: u64,
 ) -> Result<RecordBatch> {
@@ -130,6 +133,7 @@ pub fn read_record_batch<R: Read + Seek>(
                     block_offset,
                     is_little_endian,
                     batch.compression(),
+                    version,
                 )),
                 ProjectionResult::NotSelected(field) => {
                     skip(&mut field_nodes, field.data_type(), &mut buffers);
@@ -152,6 +156,7 @@ pub fn read_record_batch<R: Read + Seek>(
                     block_offset,
                     is_little_endian,
                     batch.compression(),
+                    version,
                 )
             })
             .collect::<Result<Vec<_>>>()?;
@@ -199,6 +204,7 @@ pub fn read_dictionary(
         None,
         is_little_endian,
         dictionaries_by_field,
+        MetadataVersion::V5,
         reader,
         block_offset,
     )?;
diff --git a/src/io/ipc/read/deserialize.rs b/src/io/ipc/read/deserialize.rs
index f244024bb0a..d45c7453342 100644
--- a/src/io/ipc/read/deserialize.rs
+++ b/src/io/ipc/read/deserialize.rs
@@ -9,6 +9,8 @@ use std::{
     sync::Arc,
 };
 
+use gen::Schema::MetadataVersion;
+
 use crate::datatypes::{DataType, IntervalUnit};
 use crate::error::Result;
 use crate::io::ipc::gen::Message::BodyCompression;
@@ -27,6 +29,7 @@ pub fn read<R: Read + Seek>(
     block_offset: u64,
     is_little_endian: bool,
     compression: Option<BodyCompression>,
+    version: MetadataVersion,
 ) -> Result<Arc<dyn Array>> {
     match data_type {
         DataType::Null => {
@@ -229,6 +232,7 @@ pub fn read<R: Read + Seek>(
             block_offset,
             is_little_endian,
             compression,
+            version,
         )
         .map(|x| Arc::new(x) as Arc<dyn Array>),
         DataType::LargeList(_) => read_list::<i64, _>(
@@ -239,6 +243,7 @@ pub fn read<R: Read + Seek>(
             block_offset,
             is_little_endian,
             compression,
+            version,
         )
         .map(|x| Arc::new(x) as Arc<dyn Array>),
         DataType::FixedSizeList(_, _) => read_fixed_size_list(
@@ -249,6 +254,7 @@ pub fn read<R: Read + Seek>(
             block_offset,
             is_little_endian,
             compression,
+            version,
         )
         .map(|x| Arc::new(x) as Arc<dyn Array>),
         DataType::Struct(_) => read_struct(
@@ -259,6 +265,7 @@ pub fn read<R: Read + Seek>(
             block_offset,
             is_little_endian,
             compression,
+            version,
         )
         .map(|x| Arc::new(x) as Arc<dyn Array>),
         DataType::Dictionary(ref key_type, _) => match key_type.as_ref() {
@@ -328,7 +335,17 @@ pub fn read<R: Read + Seek>(
             .map(|x| Arc::new(x) as Arc<dyn Array>),
             _ => unreachable!(),
         },
-        DataType::Union(_) => unimplemented!(),
+        DataType::Union(_, _, _) => read_union(
+            field_nodes,
+            data_type,
+            buffers,
+            reader,
+            block_offset,
+            is_little_endian,
+            compression,
+            version,
+        )
+        .map(|x| Arc::new(x) as Arc<dyn Array>),
     }
 }
 
@@ -367,6 +384,6 @@ pub fn skip(
         DataType::FixedSizeList(_, _) => skip_fixed_size_list(field_nodes, data_type, buffers),
         DataType::Struct(_) => skip_struct(field_nodes, data_type, buffers),
         DataType::Dictionary(_, _) => skip_dictionary(field_nodes, buffers),
-        DataType::Union(_) => unimplemented!(),
+        DataType::Union(_, _, _) => skip_union(field_nodes, data_type, buffers),
     }
 }
diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs
index e70498b26f3..e4656c1caff 100644
--- a/src/io/ipc/read/reader.rs
+++ b/src/io/ipc/read/reader.rs
@@ -217,6 +217,7 @@ pub fn read_batch<R: Read + Seek>(
         projection,
         metadata.is_little_endian,
         &metadata.dictionaries_by_field,
+        metadata.version,
         reader,
         block.offset() as u64 + block.metaDataLength() as u64,
     )
@@ -415,6 +416,16 @@ mod tests {
         test_file("1.0.0-bigendian", "generated_interval")
     }
 
+    #[test]
+    fn read_generated_100_union() -> Result<()> {
+        test_file("1.0.0-littleendian", "generated_union")
+    }
+
+    #[test]
+    fn read_generated_017_union() -> Result<()> {
+        test_file("0.17.1", "generated_union")
+    }
+
     #[test]
     fn read_generated_200_compression_lz4() -> Result<()> {
         test_file("2.0.0-compression", "generated_lz4")
diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs
index 01d10f54dfe..ddf432145dc 100644
--- a/src/io/ipc/read/stream.rs
+++ b/src/io/ipc/read/stream.rs
@@ -18,6 +18,8 @@
 use std::io::Read;
 use std::sync::Arc;
 
+use gen::Schema::MetadataVersion;
+
 use crate::array::*;
 use crate::datatypes::Schema;
 use crate::error::{ArrowError, Result};
@@ -34,6 +36,8 @@ pub struct StreamMetadata {
     /// The schema that is read from the stream's first message
     schema: Arc<Schema>,
 
+    version: MetadataVersion,
+
     /// Whether the incoming stream is little-endian
     is_little_endian: bool,
 }
@@ -57,6 +61,7 @@ pub fn read_stream_metadata<R: Read>(reader: &mut R) -> Result<StreamMetadata> {
     let message = gen::Message::root_as_message(meta_buffer.as_slice())
         .map_err(|err| ArrowError::Ipc(format!("Unable to get root as message: {:?}", err)))?;
 
+    let version = message.version();
     // message header is a Schema, so read it
     let ipc_schema: gen::Schema::Schema = message
         .header_as_schema()
@@ -66,6 +71,7 @@ pub fn read_stream_metadata<R: Read>(reader: &mut R) -> Result<StreamMetadata> {
 
     Ok(StreamMetadata {
         schema,
+        version,
         is_little_endian,
     })
 }
@@ -134,6 +140,7 @@ pub fn read_next<R: Read>(
         None,
         metadata.is_little_endian,
         dictionaries_by_field,
+        metadata.version,
         &mut reader,
         0,
     )
@@ -324,4 +331,9 @@ mod tests {
     fn read_generated_200_compression_zstd() -> Result<()> {
         test_file("2.0.0-compression", "generated_zstd")
     }
+
+    #[test]
+    fn read_generated_017_union() -> Result<()> {
+        test_file("0.17.1", "generated_union")
+    }
 }
diff --git a/src/io/ipc/write/serialize.rs b/src/io/ipc/write/serialize.rs
index 60f529395d0..0fa669452cf 100644
--- a/src/io/ipc/write/serialize.rs
+++ b/src/io/ipc/write/serialize.rs
@@ -16,10 +16,7 @@
 // under the License.
 
 use crate::{
-    array::{
-        Array, BinaryArray, BooleanArray, DictionaryArray, DictionaryKey, FixedSizeBinaryArray,
-        FixedSizeListArray, ListArray, Offset, PrimitiveArray, StructArray, Utf8Array,
-    },
+    array::*,
     bitmap::Bitmap,
     datatypes::{DataType, IntervalUnit},
     endianess::is_native_little_endian,
@@ -236,6 +233,33 @@ pub fn write_struct(
     });
 }
 
+pub fn write_union(
+    array: &dyn Array,
+    buffers: &mut Vec<ipc::Buffer>,
+    arrow_data: &mut Vec<u8>,
+    nodes: &mut Vec<ipc::FieldNode>,
+    offset: &mut i64,
+    is_little_endian: bool,
+) {
+    let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
+
+    write_buffer(array.types(), buffers, arrow_data, offset, is_little_endian);
+
+    if let Some(offsets) = array.offsets() {
+        write_buffer(offsets, buffers, arrow_data, offset, is_little_endian);
+    }
+    array.fields().iter().for_each(|array| {
+        write(
+            array.as_ref(),
+            buffers,
+            arrow_data,
+            nodes,
+            offset,
+            is_little_endian,
+        )
+    });
+}
+
 fn write_fixed_size_list(
     array: &dyn Array,
     buffers: &mut Vec<ipc::Buffer>,
@@ -467,7 +491,9 @@ pub fn write(
                 true,
             );
         }
-        DataType::Union(_) => unimplemented!(),
+        DataType::Union(_, _, _) => {
+            write_union(array, buffers, arrow_data, nodes, offset, is_little_endian);
+        }
     }
 }
diff --git a/src/io/ipc/write/writer.rs b/src/io/ipc/write/writer.rs
index 1f108dd94d8..e642648914d 100644
--- a/src/io/ipc/write/writer.rs
+++ b/src/io/ipc/write/writer.rs
@@ -330,6 +330,17 @@ mod tests {
         test_file("1.0.0-bigendian", "generated_decimal")
     }
 
+    #[test]
+    fn write_100_union() -> Result<()> {
+        test_file("1.0.0-littleendian", "generated_union")?;
+        test_file("1.0.0-bigendian", "generated_union")
+    }
+
+    #[test]
+    fn write_generated_017_union() -> Result<()> {
+        test_file("0.17.1", "generated_union")
+    }
+
     #[test]
     fn write_sliced_utf8() -> Result<()> {
         use crate::array::{Array, Utf8Array};
diff --git a/src/io/json/read/deserialize.rs b/src/io/json/read/deserialize.rs
index 8a02c95f7e9..b32d15b5841 100644
--- a/src/io/json/read/deserialize.rs
+++ b/src/io/json/read/deserialize.rs
@@ -255,7 +255,7 @@ pub fn read(rows: &[&Value], data_type: DataType) -> Arc<dyn Array> {
         /*
         DataType::FixedSizeBinary(_) => Box::new(FixedSizeBinaryArray::new_empty(data_type)),
         DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_empty(data_type)),
-        DataType::Union(_) => unimplemented!(),
+        DataType::Union(_, _, _) => unimplemented!(),
         DataType::Decimal(_, _) => Box::new(PrimitiveArray::<i128>::new_empty(data_type)),
         */
     }
diff --git a/src/io/json_integration/mod.rs b/src/io/json_integration/mod.rs
index be4880fefa8..29a015ce1b5 100644
--- a/src/io/json_integration/mod.rs
+++ b/src/io/json_integration/mod.rs
@@ -138,5 +138,7 @@ pub struct ArrowJsonColumn {
     pub data: Option<Vec<Value>>,
     #[serde(rename = "OFFSET")]
     pub offset: Option<Vec<Value>>, // leaving as Value as 64-bit offsets are strings
+    #[serde(rename = "TYPE_ID")]
+    pub type_id: Option<Vec<Value>>, // for union types
     pub children: Option<Vec<ArrowJsonColumn>>,
 }
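`write_union` mirrors the reader: the types buffer first, offsets only when dense, then every child recursively, so files written here can be re-read by `read_union` above. The emitted buffer order, spelled out as a sketch:

```rust
/// Sketch of the buffer order `write_union` emits for one union column;
/// `read_union` consumes buffers in exactly this order.
fn union_buffer_order(is_sparse: bool) -> Vec<&'static str> {
    let mut order = vec!["types (i8)"];
    if !is_sparse {
        order.push("offsets (i32)");
    }
    order.push("children, recursively, in field order");
    order
}
```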
diff --git a/src/io/json_integration/read.rs b/src/io/json_integration/read.rs
index 6063335a95b..0d0e4c6f968 100644
--- a/src/io/json_integration/read.rs
+++ b/src/io/json_integration/read.rs
@@ -22,7 +22,7 @@ use serde_json::Value;
 
 use crate::{
     array::*,
-    bitmap::Bitmap,
+    bitmap::{Bitmap, MutableBitmap},
     buffer::Buffer,
     datatypes::{DataType, Field, IntervalUnit, Schema},
     error::{ArrowError, Result},
@@ -33,9 +33,12 @@ use crate::{
 use super::{ArrowJsonBatch, ArrowJsonColumn, ArrowJsonDictionaryBatch};
 
 fn to_validity(validity: &Option<Vec<u8>>) -> Option<Bitmap> {
-    validity
-        .as_ref()
-        .map(|x| x.iter().map(|is_valid| *is_valid == 1).collect::<Bitmap>())
+    validity.as_ref().and_then(|x| {
+        x.iter()
+            .map(|is_valid| *is_valid == 1)
+            .collect::<MutableBitmap>()
+            .into()
+    })
 }
 
 fn to_offsets<O: Offset>(offsets: Option<&Vec<Value>>) -> Buffer<O> {
@@ -174,10 +177,7 @@ fn to_list<O: Offset>(
     data_type: DataType,
     dictionaries: &HashMap<i64, ArrowJsonDictionaryBatch>,
 ) -> Result<Arc<dyn Array>> {
-    let validity = json_col
-        .validity
-        .as_ref()
-        .map(|x| x.iter().map(|is_valid| *is_valid == 1).collect::<Bitmap>());
+    let validity = to_validity(&json_col.validity);
 
     let child_field = ListArray::<O>::get_child_field(&data_type);
     let children = &json_col.children.as_ref().unwrap()[0];
@@ -222,21 +222,15 @@ pub fn to_array(
     match data_type {
         DataType::Null => Ok(Arc::new(NullArray::from_data(json_col.count))),
         DataType::Boolean => {
-            let array = json_col
-                .validity
+            let validity = to_validity(&json_col.validity);
+            let values = json_col
+                .data
                 .as_ref()
                 .unwrap()
                 .iter()
-                .zip(json_col.data.as_ref().unwrap())
-                .map(|(is_valid, value)| {
-                    if *is_valid == 1 {
-                        Some(value.as_bool().unwrap())
-                    } else {
-                        None
-                    }
-                })
-                .collect::<BooleanArray>();
-            Ok(Arc::new(array))
+                .map(|value| value.as_bool().unwrap())
+                .collect::<Bitmap>();
+            Ok(Arc::new(BooleanArray::from_data(values, validity)))
         }
         DataType::Int8 => Ok(Arc::new(to_primitive::<i8>(json_col, data_type.clone()))),
         DataType::Int16 => Ok(Arc::new(to_primitive::<i16>(json_col, data_type.clone()))),
@@ -322,9 +316,51 @@ pub fn to_array(
             _ => unreachable!(),
         },
         DataType::Float16 => unreachable!(),
-        DataType::Union(_) => Err(ArrowError::NotYetImplemented(
-            "Union not supported".to_string(),
-        )),
+        DataType::Union(fields, _, _) => {
+            let fields = fields
+                .iter()
+                .zip(json_col.children.as_ref().unwrap())
+                .map(|(field, col)| to_array(field, col, dictionaries))
+                .collect::<Result<Vec<_>>>()?;
+
+            let types = json_col
+                .type_id
+                .as_ref()
+                .map(|x| {
+                    x.iter()
+                        .map(|value| match value {
+                            Value::Number(x) => {
+                                x.as_i64().and_then(num::cast::cast::<i64, i8>).unwrap()
+                            }
+                            Value::String(x) => x.parse::<i8>().ok().unwrap(),
+                            _ => {
+                                panic!()
+                            }
+                        })
+                        .collect()
+                })
+                .unwrap_or_default();
+
+            let offsets = json_col
+                .offset
+                .as_ref()
+                .map(|x| {
+                    Some(
+                        x.iter()
+                            .map(|value| match value {
+                                Value::Number(x) => {
+                                    x.as_i64().and_then(num::cast::cast::<i64, i32>).unwrap()
+                                }
+                                _ => panic!(),
+                            })
+                            .collect(),
+                    )
+                })
+                .unwrap_or_default();
+
+            let array = UnionArray::from_data(data_type.clone(), types, fields, offsets);
+            Ok(Arc::new(array))
+        }
     }
 }
diff --git a/src/io/json_integration/schema.rs b/src/io/json_integration/schema.rs
index 5ed21db1085..b6c85bc22df 100644
--- a/src/io/json_integration/schema.rs
+++ b/src/io/json_integration/schema.rs
@@ -56,7 +56,7 @@ impl ToJson for DataType {
                 json!({"name": "fixedsizebinary", "byteWidth": byte_width})
             }
             DataType::Struct(_) => json!({"name": "struct"}),
-            DataType::Union(_) => json!({"name": "union"}),
+            DataType::Union(_, _, _) => json!({"name": "union"}),
             DataType::List(_) => json!({ "name": "list"}),
             DataType::LargeList(_) => json!({ "name": "largelist"}),
             DataType::FixedSizeList(_, length) => {
@@ -333,6 +333,19 @@ fn to_data_type(item: &Value, mut children: Vec<Field>) -> Result<DataType> {
             }
         }
         "struct" => DataType::Struct(children),
+        "union" => {
+            let is_sparse = if let Some(Value::String(mode)) = item.get("mode") {
+                mode == "SPARSE"
+            } else {
+                return Err(ArrowError::Schema("union requires mode".to_string()));
+            };
+            let ids = if let Some(Value::Array(ids)) = item.get("typeIds") {
+                Some(ids.iter().map(|x| x.as_i64().unwrap() as i32).collect())
+            } else {
+                return Err(ArrowError::Schema("union requires ids".to_string()));
+            };
+            DataType::Union(children, ids, is_sparse)
+        }
         other => {
             return Err(ArrowError::Schema(format!(
                 "invalid json value type \"{}\"",
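For reference, the integration-format JSON these branches consume looks roughly like this (a hand-written sketch, not copied from the generated test files): the field's type declares `mode` and `typeIds`, and the column carries `TYPE_ID` plus, for dense unions, `OFFSET`:

```rust
use serde_json::json;

fn main() {
    // Field metadata parsed by `to_data_type`:
    let field_type = json!({"name": "union", "mode": "DENSE", "typeIds": [0, 1]});
    // Column parsed by `to_array`; `TYPE_ID` feeds `types`, `OFFSET` feeds `offsets`.
    let column = json!({
        "name": "u",
        "count": 2,
        "TYPE_ID": [0, 1],
        "OFFSET": [0, 0],
        "children": []
    });
    println!("{} {}", field_type, column);
}
```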
diff --git a/src/io/json_integration/write.rs b/src/io/json_integration/write.rs
index c99ec404326..fcb1ebffec3 100644
--- a/src/io/json_integration/write.rs
+++ b/src/io/json_integration/write.rs
@@ -25,6 +25,7 @@ pub fn from_record_batch(batch: &RecordBatch) -> ArrowJsonBatch {
                 validity: Some(validity),
                 data: Some(data),
                 offset: None,
+                type_id: None,
                 children: None,
             }
         }
@@ -34,6 +35,7 @@
                 validity: None,
                 data: None,
                 offset: None,
+                type_id: None,
                 children: None,
             },
         };
diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs
index 9ec6417ad66..149703d15c2 100644
--- a/src/scalar/mod.rs
+++ b/src/scalar/mod.rs
@@ -130,7 +130,7 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box<dyn Scalar> {
         }
         FixedSizeBinary(_) => todo!(),
         FixedSizeList(_, _) => todo!(),
-        Union(_) => todo!(),
+        Union(_, _, _) => todo!(),
         Dictionary(_, _) => todo!(),
     }
 }