From 959d54931f0dd532e2509b4bdcf0f244e2e6bd37 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Sat, 14 Aug 2021 11:25:41 +0100 Subject: [PATCH] Add `UnionArray` (#283) --- README.md | 7 +- arrow-flight/src/utils.rs | 2 + .../tests/test_sql.py | 35 +++ guide/src/high_level.md | 1 + .../integration_test.rs | 2 + src/array/display.rs | 13 +- src/array/equal/mod.rs | 26 +- src/array/equal/union.rs | 5 + src/array/ffi.rs | 2 +- src/array/growable/mod.rs | 2 +- src/array/mod.rs | 15 +- src/array/union/ffi.rs | 58 ++++ src/array/union/iterator.rs | 55 ++++ src/array/union/mod.rs | 275 ++++++++++++++++++ src/compute/aggregate/memory.rs | 17 +- src/datatypes/field.rs | 4 +- src/datatypes/mod.rs | 3 +- src/ffi/array.rs | 1 + src/ffi/schema.rs | 33 ++- src/io/ipc/convert.rs | 39 ++- src/io/ipc/read/array/fixed_size_list.rs | 4 + src/io/ipc/read/array/list.rs | 4 + src/io/ipc/read/array/mod.rs | 2 + src/io/ipc/read/array/struct_.rs | 4 + src/io/ipc/read/array/union.rs | 99 +++++++ src/io/ipc/read/common.rs | 6 + src/io/ipc/read/deserialize.rs | 21 +- src/io/ipc/read/reader.rs | 11 + src/io/ipc/read/stream.rs | 12 + src/io/ipc/write/serialize.rs | 36 ++- src/io/ipc/write/writer.rs | 11 + src/io/json/read/deserialize.rs | 1 - src/io/json_integration/mod.rs | 2 + src/io/json_integration/read.rs | 82 ++++-- src/io/json_integration/schema.rs | 15 +- src/io/json_integration/write.rs | 2 + src/io/print.rs | 34 ++- src/scalar/mod.rs | 2 +- 38 files changed, 886 insertions(+), 57 deletions(-) create mode 100644 src/array/equal/union.rs create mode 100644 src/array/union/ffi.rs create mode 100644 src/array/union/iterator.rs create mode 100644 src/array/union/mod.rs create mode 100644 src/io/ipc/read/array/union.rs diff --git a/README.md b/README.md index b6192ae5ec0..61a5c5439d8 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,8 @@ venv/bin/python parquet_integration/write_parquet.py * `MutableArray` API to work in-memory in-place. * faster IPC reader (different design that avoids an extra copy of all data) * IPC supports 2.0 (compression) -* FFI support for dictionary-encoded arrays +* FFI support for dictionary-encoded arrays and union array +* All implemented arrow types pass IPC integration tests against other implementations ### Parquet @@ -81,7 +82,7 @@ venv/bin/python parquet_integration/write_parquet.py ## Features in the original not available in this crate * Parquet read and write of struct and nested lists. -* Union and Map types +* Map types ## Features in this crate not in pyarrow @@ -90,7 +91,7 @@ venv/bin/python parquet_integration/write_parquet.py ## Features in pyarrow not in this crate -Too many to enumerate; e.g. nested dictionary arrays, union, map, nested parquet. +Too many to enumerate; e.g. nested dictionary arrays, map, nested parquet. 
## How to develop diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 9ea425983eb..25f01837ee3 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -26,6 +26,7 @@ use arrow2::{ datatypes::*, error::{ArrowError, Result}, io::ipc, + io::ipc::gen::Schema::MetadataVersion, io::ipc::read::read_record_batch, io::ipc::write, io::ipc::write::common::{encoded_batch, DictionaryTracker, EncodedData, IpcWriteOptions}, @@ -168,6 +169,7 @@ pub fn flight_data_to_arrow_batch( None, is_little_endian, &dictionaries_by_field, + MetadataVersion::V5, &mut reader, 0, ) diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 5260536342a..d891eb27b90 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -193,3 +193,38 @@ def test_dict(self): b.validate(full=True) assert a.to_pylist() == b.to_pylist() assert a.type == b.type + + def test_sparse_union(self): + """ + Python -> Rust -> Python + """ + a = pyarrow.UnionArray.from_sparse( + pyarrow.array([0, 1, 1, 0, 1], pyarrow.int8()), + [ + pyarrow.array(["a", "", "", "", "c"], pyarrow.utf8()), + pyarrow.array([0, 1, 2, None, 0], pyarrow.int64()), + ], + ) + b = arrow_pyarrow_integration_testing.round_trip(a) + + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type + + def test_dense_union(self): + """ + Python -> Rust -> Python + """ + a = pyarrow.UnionArray.from_dense( + pyarrow.array([0, 1, 1, 0, 1], pyarrow.int8()), + pyarrow.array([0, 1, 2, 3, 4], type=pyarrow.int32()), + [ + pyarrow.array(["a", "", "", "", "c"], pyarrow.utf8()), + pyarrow.array([0, 1, 2, None, 0], pyarrow.int64()), + ], + ) + b = arrow_pyarrow_integration_testing.round_trip(a) + + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type diff --git a/guide/src/high_level.md b/guide/src/high_level.md index 48d1b3819ec..b757fd2d742 100644 --- a/guide/src/high_level.md +++ b/guide/src/high_level.md @@ -145,6 +145,7 @@ There is a many-to-one relationship between `DataType` and an Array (i.e. 
a phys | `FixedSizeBinary(_)` | `FixedSizeBinaryArray` | | `FixedSizeList(_,_)` | `FixedSizeListArray` | | `Struct(_)` | `StructArray` | +| `Union(_,_,_)` | `UnionArray` | | `Dictionary(UInt8,_)` | `DictionaryArray` | | `Dictionary(UInt16,_)`| `DictionaryArray` | | `Dictionary(UInt32,_)`| `DictionaryArray` | diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index dc80de6ba86..de6951fee43 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -25,6 +25,7 @@ use arrow2::{ datatypes::*, io::ipc, io::ipc::gen::Message::{Message, MessageHeader}, + io::ipc::gen::Schema::MetadataVersion, record_batch::RecordBatch, }; use arrow_flight::flight_descriptor::*; @@ -295,6 +296,7 @@ async fn record_batch_from_message( None, true, &dictionaries_by_field, + MetadataVersion::V5, &mut reader, 0, ); diff --git a/src/array/display.rs b/src/array/display.rs index b90fe527953..d049c408c8e 100644 --- a/src/array/display.rs +++ b/src/array/display.rs @@ -222,7 +222,18 @@ pub fn get_value_display<'a>(array: &'a dyn Array) -> Box Strin string }) } - Union(_) => todo!(), + Union(_, _, _) => { + let array = array.as_any().downcast_ref::().unwrap(); + let displays = array + .fields() + .iter() + .map(|x| get_display(x.as_ref())) + .collect::>(); + Box::new(move |row: usize| { + let (field, index) = array.index(row); + displays[field](index) + }) + } } } diff --git a/src/array/equal/mod.rs b/src/array/equal/mod.rs index 45062e27c4f..be2b1c1976c 100644 --- a/src/array/equal/mod.rs +++ b/src/array/equal/mod.rs @@ -1,14 +1,9 @@ -use std::unimplemented; - use crate::{ datatypes::{DataType, IntervalUnit}, types::{days_ms, NativeType}, }; -use super::{ - primitive::PrimitiveArray, Array, BinaryArray, BooleanArray, DictionaryArray, DictionaryKey, - FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, Offset, StructArray, Utf8Array, -}; +use super::*; mod binary; mod boolean; @@ -19,6 +14,7 @@ mod list; mod null; mod primitive; mod struct_; +mod union; mod utf8; impl PartialEq for dyn Array { @@ -147,6 +143,18 @@ impl PartialEq<&dyn Array> for DictionaryArray { } } +impl PartialEq for UnionArray { + fn eq(&self, other: &Self) -> bool { + equal(self, other) + } +} + +impl PartialEq<&dyn Array> for UnionArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + /// Logically compares two [`Array`]s. 
/// Two arrays are logically equal if and only if: /// * their data types are equal @@ -323,7 +331,11 @@ pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { let rhs = rhs.as_any().downcast_ref().unwrap(); fixed_size_list::equal(lhs, rhs) } - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + union::equal(lhs, rhs) + } } } diff --git a/src/array/equal/union.rs b/src/array/equal/union.rs new file mode 100644 index 00000000000..51b9d960fea --- /dev/null +++ b/src/array/equal/union.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, UnionArray}; + +pub(super) fn equal(lhs: &UnionArray, rhs: &UnionArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/src/array/ffi.rs b/src/array/ffi.rs index e707f98d6b0..094035b7284 100644 --- a/src/array/ffi.rs +++ b/src/array/ffi.rs @@ -85,7 +85,7 @@ pub fn buffers_children_dictionary(array: &dyn Array) -> BuffersChildren { DataType::LargeList(_) => ffi_dyn!(array, ListArray::), DataType::FixedSizeList(_, _) => ffi_dyn!(array, FixedSizeListArray), DataType::Struct(_) => ffi_dyn!(array, StructArray), - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => ffi_dyn!(array, UnionArray), DataType::Dictionary(key_type, _) => match key_type.as_ref() { DataType::Int8 => ffi_dict_dyn!(array, DictionaryArray::), DataType::Int16 => ffi_dict_dyn!(array, DictionaryArray::), diff --git a/src/array/growable/mod.rs b/src/array/growable/mod.rs index db54a6e12e4..dc44a3ef3e5 100644 --- a/src/array/growable/mod.rs +++ b/src/array/growable/mod.rs @@ -225,7 +225,7 @@ pub fn make_growable<'a>( )) } DataType::FixedSizeList(_, _) => todo!(), - DataType::Union(_) => todo!(), + DataType::Union(_, _, _) => todo!(), DataType::Dictionary(key, _) => match key.as_ref() { DataType::UInt8 => dyn_dict_growable!(u8, arrays, use_validity, capacity), DataType::UInt16 => dyn_dict_growable!(u16, arrays, use_validity, capacity), diff --git a/src/array/mod.rs b/src/array/mod.rs index 3505a6204a6..37aa8f7d318 100644 --- a/src/array/mod.rs +++ b/src/array/mod.rs @@ -56,6 +56,9 @@ pub trait Array: std::fmt::Debug + Send + Sync { /// This is `O(1)`. 
#[inline] fn null_count(&self) -> usize { + if self.data_type() == &DataType::Null { + return self.len(); + }; self.validity() .as_ref() .map(|x| x.null_count()) @@ -185,7 +188,7 @@ impl Display for dyn Array { DataType::LargeList(_) => fmt_dyn!(self, ListArray::, f), DataType::FixedSizeList(_, _) => fmt_dyn!(self, FixedSizeListArray, f), DataType::Struct(_) => fmt_dyn!(self, StructArray, f), - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => fmt_dyn!(self, UnionArray, f), DataType::Dictionary(key_type, _) => match key_type.as_ref() { DataType::Int8 => fmt_dyn!(self, DictionaryArray::, f), DataType::Int16 => fmt_dyn!(self, DictionaryArray::, f), @@ -239,7 +242,7 @@ pub fn new_empty_array(data_type: DataType) -> Box { DataType::LargeList(_) => Box::new(ListArray::::new_empty(data_type)), DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_empty(data_type)), DataType::Struct(fields) => Box::new(StructArray::new_empty(&fields)), - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => Box::new(UnionArray::new_empty(data_type)), DataType::Dictionary(key_type, value_type) => match key_type.as_ref() { DataType::Int8 => Box::new(DictionaryArray::::new_empty(*value_type)), DataType::Int16 => Box::new(DictionaryArray::::new_empty(*value_type)), @@ -293,7 +296,7 @@ pub fn new_null_array(data_type: DataType, length: usize) -> Box { DataType::LargeList(_) => Box::new(ListArray::::new_null(data_type, length)), DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_null(data_type, length)), DataType::Struct(fields) => Box::new(StructArray::new_null(&fields, length)), - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => Box::new(UnionArray::new_null(data_type, length)), DataType::Dictionary(key_type, value_type) => match key_type.as_ref() { DataType::Int8 => Box::new(DictionaryArray::::new_null(*value_type, length)), DataType::Int16 => Box::new(DictionaryArray::::new_null(*value_type, length)), @@ -354,7 +357,7 @@ pub fn clone(array: &dyn Array) -> Box { DataType::LargeList(_) => clone_dyn!(array, ListArray::), DataType::FixedSizeList(_, _) => clone_dyn!(array, FixedSizeListArray), DataType::Struct(_) => clone_dyn!(array, StructArray), - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => clone_dyn!(array, UnionArray), DataType::Dictionary(key_type, _) => match key_type.as_ref() { DataType::Int8 => clone_dyn!(array, DictionaryArray::), DataType::Int16 => clone_dyn!(array, DictionaryArray::), @@ -380,6 +383,7 @@ mod null; mod primitive; mod specification; mod struct_; +mod union; mod utf8; mod equal; @@ -399,6 +403,7 @@ pub use null::NullArray; pub use primitive::*; pub use specification::Offset; pub use struct_::StructArray; +pub use union::UnionArray; pub use utf8::{MutableUtf8Array, Utf8Array, Utf8ValuesIter}; pub(crate) use self::ffi::buffers_children_dictionary; @@ -498,6 +503,8 @@ mod tests { DataType::Utf8, DataType::Binary, DataType::List(Box::new(Field::new("a", DataType::Binary, true))), + DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, true), + DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, false), ]; let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0); assert!(a); diff --git a/src/array/union/ffi.rs b/src/array/union/ffi.rs new file mode 100644 index 00000000000..7028480e049 --- /dev/null +++ b/src/array/union/ffi.rs @@ -0,0 +1,58 @@ +use std::sync::Arc; + +use crate::{array::FromFfi, error::Result, ffi}; + +use super::super::{ffi::ToFfi, 
Array};
+use super::UnionArray;
+
+unsafe impl ToFfi for UnionArray {
+    fn buffers(&self) -> Vec<Option<std::ptr::NonNull<u8>>> {
+        if let Some(offsets) = &self.offsets {
+            vec![
+                None,
+                std::ptr::NonNull::new(self.types.as_ptr() as *mut u8),
+                std::ptr::NonNull::new(offsets.as_ptr() as *mut u8),
+            ]
+        } else {
+            vec![None, std::ptr::NonNull::new(self.types.as_ptr() as *mut u8)]
+        }
+    }
+
+    fn offset(&self) -> usize {
+        self.offset
+    }
+
+    fn children(&self) -> Vec<Arc<dyn Array>> {
+        self.fields.clone()
+    }
+}
+
+unsafe impl<A: ffi::ArrowArrayRef> FromFfi<A> for UnionArray {
+    fn try_from_ffi(array: A) -> Result<Self> {
+        let field = array.field()?;
+        let data_type = field.data_type().clone();
+        let fields = Self::get_fields(field.data_type());
+
+        let mut types = unsafe { array.buffer::<i8>(0) }?;
+        let offsets = if Self::is_sparse(&data_type) {
+            None
+        } else {
+            Some(unsafe { array.buffer::<i32>(1) }?)
+        };
+
+        let length = array.array().len();
+        let offset = array.array().offset();
+        let fields = (0..fields.len())
+            .map(|index| {
+                let child = array.child(index)?;
+                Ok(ffi::try_from(child)?.into())
+            })
+            .collect::<Result<Vec<Arc<dyn Array>>>>()?;
+
+        if offset > 0 {
+            types = types.slice(offset, length);
+        };
+
+        Ok(Self::from_data(data_type, types, fields, offsets))
+    }
+}
diff --git a/src/array/union/iterator.rs b/src/array/union/iterator.rs
new file mode 100644
index 00000000000..7a859561c14
--- /dev/null
+++ b/src/array/union/iterator.rs
@@ -0,0 +1,55 @@
+use super::{Array, UnionArray};
+use crate::{scalar::Scalar, trusted_len::TrustedLen};
+
+#[derive(Debug, Clone)]
+pub struct UnionIter<'a> {
+    array: &'a UnionArray,
+    current: usize,
+}
+
+impl<'a> UnionIter<'a> {
+    pub fn new(array: &'a UnionArray) -> Self {
+        Self { array, current: 0 }
+    }
+}
+
+impl<'a> Iterator for UnionIter<'a> {
+    type Item = Box<dyn Scalar>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current == self.array.len() {
+            None
+        } else {
+            let old = self.current;
+            self.current += 1;
+            Some(self.array.value(old))
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.array.len() - self.current;
+        (len, Some(len))
+    }
+}
+
+impl<'a> IntoIterator for &'a UnionArray {
+    type Item = Box<dyn Scalar>;
+    type IntoIter = UnionIter<'a>;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
+
+impl<'a> UnionArray {
+    /// Constructs a new iterator over the [`Scalar`]s of this array.
+    #[inline]
+    pub fn iter(&'a self) -> UnionIter<'a> {
+        UnionIter::new(self)
+    }
+}
+
+impl<'a> std::iter::ExactSizeIterator for UnionIter<'a> {}
+
+unsafe impl<'a> TrustedLen for UnionIter<'a> {}
diff --git a/src/array/union/mod.rs b/src/array/union/mod.rs
new file mode 100644
index 00000000000..bccf74812e8
--- /dev/null
+++ b/src/array/union/mod.rs
@@ -0,0 +1,275 @@
+use std::{collections::HashMap, sync::Arc};
+
+use crate::{
+    array::{display::get_value_display, display_fmt, new_empty_array, new_null_array, Array},
+    bitmap::Bitmap,
+    buffer::Buffer,
+    datatypes::{DataType, Field},
+    scalar::{new_scalar, Scalar},
+};
+
+mod ffi;
+mod iterator;
+
+type FieldEntry = (usize, Arc<dyn Array>);
+
+/// [`UnionArray`] represents an array whose slots can each hold a value of a different type.
+///
+// How to read a value at slot i:
+// ```
+// let field_index = self.types()[i] as usize;
+// let field = self.fields()[field_index];
+// let slot = self.offsets().map(|x| x[i]).unwrap_or(i);
+// let value = field.as_any().downcast to the correct type, then .value(slot);
+// ```
+#[derive(Debug, Clone)]
+pub struct UnionArray {
+    types: Buffer<i8>,
+    // `None` when the union has no explicit type ids; fields are then selected by position.
+    fields_hash: Option<HashMap<i8, FieldEntry>>,
+    fields: Vec<Arc<dyn Array>>,
+    offsets: Option<Buffer<i32>>,
+    data_type: DataType,
+    offset: usize,
+}
+
+impl UnionArray {
+    pub fn new_null(data_type: DataType, length: usize) -> Self {
+        if let DataType::Union(f, _, is_sparse) = &data_type {
+            let fields = f
+                .iter()
+                .map(|x| new_null_array(x.data_type().clone(), length).into())
+                .collect();
+
+            let offsets = if *is_sparse {
+                None
+            } else {
+                Some((0..length as i32).collect::<Buffer<i32>>())
+            };
+
+            // all slots select the first field (type id 0)
+            let types = Buffer::new_zeroed(length);
+
+            Self::from_data(data_type, types, fields, offsets)
+        } else {
+            panic!("Union struct must be created with the corresponding Union DataType")
+        }
+    }
+
+    pub fn new_empty(data_type: DataType) -> Self {
+        if let DataType::Union(f, _, is_sparse) = &data_type {
+            let fields = f
+                .iter()
+                .map(|x| new_empty_array(x.data_type().clone()).into())
+                .collect();
+
+            let offsets = if *is_sparse {
+                None
+            } else {
+                Some(Buffer::new())
+            };
+
+            Self {
+                data_type,
+                fields_hash: None,
+                fields,
+                offsets,
+                types: Buffer::new(),
+                offset: 0,
+            }
+        } else {
+            panic!("Union struct must be created with the corresponding Union DataType")
+        }
+    }
+
+    pub fn from_data(
+        data_type: DataType,
+        types: Buffer<i8>,
+        fields: Vec<Arc<dyn Array>>,
+        offsets: Option<Buffer<i32>>,
+    ) -> Self {
+        let fields_hash = if let DataType::Union(f, ids, is_sparse) = &data_type {
+            if f.len() != fields.len() {
+                panic!(
+                    "The number of `fields` must equal the number of fields in the Union DataType"
+                )
+            };
+            let same_data_types = f
+                .iter()
+                .zip(fields.iter())
+                .all(|(f, array)| f.data_type() == array.data_type());
+            if !same_data_types {
+                panic!("The data type of each field must equal the data type declared for it in the Union DataType")
+            }
+            if offsets.is_none() != *is_sparse {
+                panic!("Sparse unions must not have offsets and dense unions must have offsets")
+            }
+            ids.as_ref().map(|ids| {
+                ids.iter()
+                    .map(|x| *x as i8)
+                    .enumerate()
+                    .zip(fields.iter().cloned())
+                    .map(|((i, type_), field)| (type_, (i, field)))
+                    .collect()
+            })
+        } else {
+            panic!("Union struct must be created with the corresponding Union DataType")
+        };
+        // not validated:
+        // * `offsets` is valid
+        // * max id < fields.len()
+        Self {
+            data_type,
+            fields_hash,
+            fields,
+            offsets,
+            types,
+            offset: 0,
+        }
+    }
+
+    pub fn offsets(&self) -> &Option<Buffer<i32>> {
+        &self.offsets
+    }
+
+    pub fn fields(&self) -> &Vec<Arc<dyn Array>> {
+        &self.fields
+    }
+
+    pub fn types(&self) -> &Buffer<i8> {
+        &self.types
+    }
+
+    #[inline]
+    fn field(&self, type_: i8) -> &Arc<dyn Array> {
+        self.fields_hash
+            .as_ref()
+            .map(|x| &x[&type_].1)
+            .unwrap_or_else(|| &self.fields[type_ as usize])
+    }
+
+    #[inline]
+    fn field_slot(&self, index: usize) -> usize {
+        self.offsets()
+            .as_ref()
+            .map(|x| x[index] as usize)
+            .unwrap_or(index)
+    }
+
+    /// Returns the index and slot of the field to select from `self.fields`.
+    pub fn index(&self, index: usize) -> (usize, usize) {
+        let type_ = self.types()[index];
+        let field_index = self
+            .fields_hash
+            .as_ref()
+            .map(|x| x[&type_].0)
+            .unwrap_or_else(|| type_ as usize);
+        let index = self.field_slot(index);
+        (field_index, index)
+    }
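+
+    // Worked example (hypothetical data, for illustration only): with children
+    // `[Int32, Utf8]` and `types = [0, 1, 0]`, a sparse union reads slot `i`
+    // from child `types[i]` at position `i`, while a dense union with
+    // `offsets = [0, 0, 1]` reads slot 2 from child 0 at position 1.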
+
+    /// Returns the slot `index` as a [`Scalar`].
+    pub fn value(&self, index: usize) -> Box<dyn Scalar> {
+        let type_ = self.types()[index];
+        let field = self.field(type_);
+        let index = self.field_slot(index);
+        new_scalar(field.as_ref(), index)
+    }
+
+    /// Returns a slice of this [`UnionArray`].
+    /// # Implementation
+    /// This operation is `O(F)` where `F` is the number of fields.
+    /// # Panic
+    /// This function panics iff `offset + length > self.len()`.
+    #[inline]
+    pub fn slice(&self, offset: usize, length: usize) -> Self {
+        Self {
+            data_type: self.data_type.clone(),
+            fields: self.fields.clone(),
+            fields_hash: self.fields_hash.clone(),
+            types: self.types.clone().slice(offset, length),
+            offsets: self.offsets.clone(),
+            offset: self.offset + offset,
+        }
+    }
+}
+
+impl Array for UnionArray {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn len(&self) -> usize {
+        self.types.len()
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn validity(&self) -> &Option<Bitmap> {
+        &None
+    }
+
+    fn slice(&self, offset: usize, length: usize) -> Box<dyn Array> {
+        Box::new(self.slice(offset, length))
+    }
+}
+
+impl UnionArray {
+    pub fn get_fields(data_type: &DataType) -> &[Field] {
+        if let DataType::Union(fields, _, _) = data_type {
+            fields
+        } else {
+            panic!("Wrong datatype passed to UnionArray.")
+        }
+    }
+
+    pub fn is_sparse(data_type: &DataType) -> bool {
+        if let DataType::Union(_, _, is_sparse) = data_type {
+            *is_sparse
+        } else {
+            panic!("Wrong datatype passed to UnionArray.")
+        }
+    }
+}
+
+impl std::fmt::Display for UnionArray {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let display = get_value_display(self);
+        let new_lines = false;
+        let head = "UnionArray";
+        let iter = self
+            .iter()
+            .enumerate()
+            .map(|(i, x)| if x.is_valid() { Some(display(i)) } else { None });
+        display_fmt(iter, head, f, new_lines)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{array::*, buffer::Buffer, datatypes::*, error::Result};
+
+    #[test]
+    fn display() -> Result<()> {
+        let fields = vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Utf8, true),
+        ];
+        let data_type = DataType::Union(fields, None, true);
+        let types = Buffer::from(&[0, 0, 1]);
+        let fields = vec![
+            Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
+            Arc::new(Utf8Array::<i32>::from(&[Some("a"), Some("b"), Some("c")])) as Arc<dyn Array>,
+        ];
+
+        let array = UnionArray::from_data(data_type, types, fields, None);
+
+        assert_eq!(format!("{}", array), "UnionArray[1, , c]");
+
+        Ok(())
+    }
+}
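The sparse case is exercised by the `display` test above; for the dense case, here is a minimal construction sketch using only the API added in this file. The data and the `main` wrapper are illustrative, not part of the patch:

```rust
use std::sync::Arc;

use arrow2::array::{Array, Int32Array, UnionArray, Utf8Array};
use arrow2::buffer::Buffer;
use arrow2::datatypes::{DataType, Field};

fn main() {
    let fields = vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ];
    // dense union: one offset per slot, indexing into the selected child
    let data_type = DataType::Union(fields, None, false);
    let types = Buffer::<i8>::from(&[0, 1, 0]);
    let offsets = Some(Buffer::<i32>::from(&[0, 0, 1]));
    let children = vec![
        Arc::new(Int32Array::from(&[Some(1), Some(2)])) as Arc<dyn Array>,
        Arc::new(Utf8Array::<i32>::from(&[Some("a")])) as Arc<dyn Array>,
    ];

    let array = UnionArray::from_data(data_type, types, children, offsets);

    // slot 1 selects child 1 ("b") at offset 0
    assert_eq!(array.index(1), (1, 0));

    // iterate all slots as boxed `Scalar`s
    for value in array.iter() {
        println!("{:?}", value);
    }
}
```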
diff --git a/src/compute/aggregate/memory.rs b/src/compute/aggregate/memory.rs
index 7b1eb239d00..d49411c3b34 100644
--- a/src/compute/aggregate/memory.rs
+++ b/src/compute/aggregate/memory.rs
@@ -109,7 +109,22 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize {
                 .sum::<usize>()
                 + validity_size(array.validity())
         }
-        Union(_) => unreachable!(),
+        Union(_, _, _) => {
+            let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
+            let types = array.types().len() * std::mem::size_of::<i8>();
+            let offsets = array
+                .offsets()
+                .as_ref()
+                .map(|x| x.len() * std::mem::size_of::<i32>())
+                .unwrap_or_default();
+            let fields = array
+                .fields()
+                .iter()
+                .map(|x| x.as_ref())
+                .map(estimated_bytes_size)
+                .sum::<usize>();
+            types + offsets + fields
+        }
         Dictionary(keys, _) => match keys.as_ref() {
             Int8 => dyn_dict!(array, i8),
             Int16 => dyn_dict!(array, i16),
diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs
index 205a9a041dc..141bb1e189b 100644
--- a/src/datatypes/field.rs
+++ b/src/datatypes/field.rs
@@ -200,8 +200,8 @@ impl Field {
                         ));
                     }
                 },
-            DataType::Union(nested_fields) => match &from.data_type {
-                DataType::Union(from_nested_fields) => {
+            DataType::Union(nested_fields, _, _) => match &from.data_type {
+                DataType::Union(from_nested_fields, _, _) => {
                     for from_field in from_nested_fields {
                         let mut is_new_field = true;
                         for self_field in nested_fields.iter_mut() {
diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs
index 6603ed80331..5b621982cf0 100644
--- a/src/datatypes/mod.rs
+++ b/src/datatypes/mod.rs
@@ -103,7 +103,8 @@ pub enum DataType {
     /// A nested datatype that contains a number of sub-fields.
     Struct(Vec<Field>),
     /// A nested datatype that can represent slots of differing types.
-    Union(Vec<Field>),
+    /// The second argument holds the optional type ids; the third is whether the union is sparse.
+    Union(Vec<Field>, Option<Vec<i32>>, bool),
     /// A dictionary encoded array (`key_type`, `value_type`), where
     /// each array element is an index of `key_type` into an
     /// associated dictionary of `value_type`.
diff --git a/src/ffi/array.rs b/src/ffi/array.rs
index 88575b60f50..bde98378b7f 100644
--- a/src/ffi/array.rs
+++ b/src/ffi/array.rs
@@ -77,6 +77,7 @@ pub fn try_from<A: ArrowArrayRef>(array: A) -> Result<Box<dyn Array>> {
             DataType::UInt64 => Box::new(DictionaryArray::<u64>::try_from_ffi(array)?),
             _ => unreachable!(),
         },
+        DataType::Union(_, _, _) => Box::new(UnionArray::try_from_ffi(array)?),
         data_type => {
             return Err(ArrowError::NotYetImplemented(format!(
                 "Reading DataType \"{}\" is not yet supported.",
diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs
index b9edf4d544d..926871af11e 100644
--- a/src/ffi/schema.rs
+++ b/src/ffi/schema.rs
@@ -69,6 +69,10 @@ impl Ffi_ArrowSchema {
                 .iter()
                 .map(|field| Ok(Box::new(Ffi_ArrowSchema::try_new(field.clone())?)))
                 .collect::<Result<Vec<_>>>()?,
+            DataType::Union(fields, _, _) => fields
+                .iter()
+                .map(|field| Ok(Box::new(Ffi_ArrowSchema::try_new(field.clone())?)))
+                .collect::<Result<Vec<_>>>()?,
             _ => vec![],
         };
         // note: this cannot be done along with the above because the above is fallible and this op leaks.
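For context, the Arrow C data interface spells union types as format strings: `+us:<ids>` for sparse and `+ud:<ids>` for dense, which the hunks below parse and emit. A minimal sketch of the emitting direction; `union_format` is a hypothetical helper mirroring the `to_format` arm in the next hunk, not part of the patch:

```rust
// Illustrative only: a sparse union over type ids [5, 6] has format "+us:5,6";
// the dense equivalent is "+ud:5,6".
fn union_format(ids: &[i32], is_sparse: bool) -> String {
    let mode = if is_sparse { 's' } else { 'd' };
    let ids = ids
        .iter()
        .map(|id| id.to_string())
        .collect::<Vec<_>>()
        .join(",");
    format!("+u{}:{}", mode, ids)
}

#[test]
fn union_format_examples() {
    assert_eq!(union_format(&[5, 6], true), "+us:5,6");
    assert_eq!(union_format(&[5, 6], false), "+ud:5,6");
}
```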
@@ -255,6 +259,21 @@ fn to_data_type(schema: &Ffi_ArrowSchema) -> Result { ArrowError::Ffi("Decimal scale is not a valid integer".to_string()) })?; DataType::Decimal(precision, scale) + } else if !parts.is_empty() && ((parts[0] == "+us") || (parts[0] == "+ud")) { + // union + let is_sparse = parts[0] == "+us"; + let type_ids = parts[1] + .split(',') + .map(|x| { + x.parse::().map_err(|_| { + ArrowError::Ffi("Union type id is not a valid integer".to_string()) + }) + }) + .collect::>>()?; + let fields = (0..schema.n_children as usize) + .map(|x| to_field(schema.child(x))) + .collect::>>()?; + DataType::Union(fields, Some(type_ids), is_sparse) } else { return Err(ArrowError::Ffi(format!( "The datatype \"{}\" is still not supported in Rust implementation", @@ -316,7 +335,19 @@ fn to_format(data_type: &DataType) -> Result { DataType::Struct(_) => "+s", DataType::FixedSizeBinary(size) => return Ok(format!("w{}", size)), DataType::FixedSizeList(_, size) => return Ok(format!("+w:{}", size)), - DataType::Union(_) => todo!(), + DataType::Union(f, ids, is_sparse) => { + let sparsness = if *is_sparse { 's' } else { 'd' }; + let mut r = format!("+u{}:", sparsness); + let ids = if let Some(ids) = ids { + ids.iter() + .fold(String::new(), |a, b| a + &b.to_string() + ",") + } else { + (0..f.len()).fold(String::new(), |a, b| a + &b.to_string() + ",") + }; + let ids = &ids[..ids.len() - 1]; // take away last "," + r.push_str(ids); + return Ok(r); + } DataType::Dictionary(index, _) => return to_format(index.as_ref()), _ => todo!(), } diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs index c5452db242d..07d94c5b81f 100644 --- a/src/io/ipc/convert.rs +++ b/src/io/ipc/convert.rs @@ -19,6 +19,7 @@ use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use crate::endianess::is_native_little_endian; +use crate::io::ipc::convert::ipc::UnionMode; mod ipc { pub use super::super::gen::File::*; @@ -276,6 +277,22 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT let fsb = field.type_as_decimal().unwrap(); DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize) } + ipc::Type::Union => { + let type_ = field.type_as_union().unwrap(); + + let is_sparse = type_.mode() == UnionMode::Sparse; + + let ids = type_.typeIds().map(|x| x.iter().collect()); + + let fields = if let Some(children) = field.children() { + (0..children.len()) + .map(|i| children.get(i).into()) + .collect() + } else { + vec![] + }; + DataType::Union(fields, ids, is_sparse) + } t => unimplemented!("Type {:?} not supported", t), } } @@ -604,7 +621,27 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } - t => unimplemented!("Type {:?} not supported", t), + Union(fields, ids, is_sparse) => { + let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect(); + + let ids = ids.as_ref().map(|ids| fbb.create_vector(ids)); + + let mut builder = ipc::UnionBuilder::new(fbb); + builder.add_mode(if *is_sparse { + UnionMode::Sparse + } else { + UnionMode::Dense + }); + + if let Some(ids) = ids { + builder.add_typeIds(ids); + } + FbFieldType { + type_type: ipc::Type::Union, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&children)), + } + } } } diff --git a/src/io/ipc/read/array/fixed_size_list.rs b/src/io/ipc/read/array/fixed_size_list.rs index a7416cd9ca3..8fb9b45cbcd 100644 --- a/src/io/ipc/read/array/fixed_size_list.rs +++ b/src/io/ipc/read/array/fixed_size_list.rs @@ -1,6 +1,8 @@ use 
std::collections::VecDeque; use std::io::{Read, Seek}; +use gen::Schema::MetadataVersion; + use crate::array::FixedSizeListArray; use crate::datatypes::DataType; use crate::error::Result; @@ -18,6 +20,7 @@ pub fn read_fixed_size_list( block_offset: u64, is_little_endian: bool, compression: Option, + version: MetadataVersion, ) -> Result { let field_node = field_nodes.pop_front().unwrap().0; @@ -40,6 +43,7 @@ pub fn read_fixed_size_list( block_offset, is_little_endian, compression, + version, )?; Ok(FixedSizeListArray::from_data(data_type, values, validity)) } diff --git a/src/io/ipc/read/array/list.rs b/src/io/ipc/read/array/list.rs index a876576fa27..61ff8b8612b 100644 --- a/src/io/ipc/read/array/list.rs +++ b/src/io/ipc/read/array/list.rs @@ -2,6 +2,8 @@ use std::collections::VecDeque; use std::convert::TryInto; use std::io::{Read, Seek}; +use gen::Schema::MetadataVersion; + use crate::array::{ListArray, Offset}; use crate::buffer::Buffer; use crate::datatypes::DataType; @@ -20,6 +22,7 @@ pub fn read_list( block_offset: u64, is_little_endian: bool, compression: Option, + version: MetadataVersion, ) -> Result> where Vec: TryInto, @@ -56,6 +59,7 @@ where block_offset, is_little_endian, compression, + version, )?; Ok(ListArray::from_data(data_type, offsets, values, validity)) } diff --git a/src/io/ipc/read/array/mod.rs b/src/io/ipc/read/array/mod.rs index 458c62123fb..0dd2610510e 100644 --- a/src/io/ipc/read/array/mod.rs +++ b/src/io/ipc/read/array/mod.rs @@ -18,3 +18,5 @@ mod null; pub use null::*; mod dictionary; pub use dictionary::*; +mod union; +pub use union::*; diff --git a/src/io/ipc/read/array/struct_.rs b/src/io/ipc/read/array/struct_.rs index 95274459731..c259849c37a 100644 --- a/src/io/ipc/read/array/struct_.rs +++ b/src/io/ipc/read/array/struct_.rs @@ -1,6 +1,8 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; +use gen::Schema::MetadataVersion; + use crate::array::StructArray; use crate::datatypes::DataType; use crate::error::Result; @@ -18,6 +20,7 @@ pub fn read_struct( block_offset: u64, is_little_endian: bool, compression: Option, + version: MetadataVersion, ) -> Result { let field_node = field_nodes.pop_front().unwrap().0; @@ -43,6 +46,7 @@ pub fn read_struct( block_offset, is_little_endian, compression, + version, ) }) .collect::>>()?; diff --git a/src/io/ipc/read/array/union.rs b/src/io/ipc/read/array/union.rs new file mode 100644 index 00000000000..adaac0f13cd --- /dev/null +++ b/src/io/ipc/read/array/union.rs @@ -0,0 +1,99 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use gen::Schema::MetadataVersion; + +use crate::array::UnionArray; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::ipc::gen::Message::BodyCompression; + +use super::super::super::gen; +use super::super::deserialize::{read, skip, Node}; +use super::super::read_basic::*; + +pub fn read_union( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque<&gen::Schema::Buffer>, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + version: MetadataVersion, +) -> Result { + let field_node = field_nodes.pop_front().unwrap().0; + + if version != MetadataVersion::V5 { + let _ = buffers.pop_front().unwrap(); + }; + + let types = read_buffer( + buffers, + field_node.length() as usize, + reader, + block_offset, + is_little_endian, + compression, + )?; + + let offsets = if let DataType::Union(_, _, is_sparse) = data_type { + if !is_sparse { + Some(read_buffer( + buffers, + field_node.length() as usize, + 
reader, + block_offset, + is_little_endian, + compression, + )?) + } else { + None + } + } else { + panic!() + }; + + let fields = UnionArray::get_fields(&data_type); + + let fields = fields + .iter() + .map(|field| { + read( + field_nodes, + field.data_type().clone(), + buffers, + reader, + block_offset, + is_little_endian, + compression, + version, + ) + }) + .collect::>>()?; + + Ok(UnionArray::from_data(data_type, types, fields, offsets)) +} + +pub fn skip_union( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque<&gen::Schema::Buffer>, +) { + let _ = field_nodes.pop_front().unwrap(); + + let _ = buffers.pop_front().unwrap(); + if let DataType::Union(_, _, is_sparse) = data_type { + if !*is_sparse { + let _ = buffers.pop_front().unwrap(); + } + } else { + panic!() + }; + + let fields = UnionArray::get_fields(data_type); + + fields + .iter() + .for_each(|field| skip(field_nodes, field.data_type(), buffers)) +} diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs index 0cc9fc468ea..e0524c558cd 100644 --- a/src/io/ipc/read/common.rs +++ b/src/io/ipc/read/common.rs @@ -19,6 +19,8 @@ use std::collections::{HashMap, VecDeque}; use std::io::{Read, Seek}; use std::sync::Arc; +use gen::Schema::MetadataVersion; + use crate::array::*; use crate::datatypes::{DataType, Field, Schema}; use crate::error::{ArrowError, Result}; @@ -96,6 +98,7 @@ pub fn read_record_batch( projection: Option<(&[usize], Arc)>, is_little_endian: bool, dictionaries: &[Option], + version: MetadataVersion, reader: &mut R, block_offset: u64, ) -> Result { @@ -130,6 +133,7 @@ pub fn read_record_batch( block_offset, is_little_endian, batch.compression(), + version, )), ProjectionResult::NotSelected(field) => { skip(&mut field_nodes, field.data_type(), &mut buffers); @@ -152,6 +156,7 @@ pub fn read_record_batch( block_offset, is_little_endian, batch.compression(), + version, ) }) .collect::>>()?; @@ -199,6 +204,7 @@ pub fn read_dictionary( None, is_little_endian, dictionaries_by_field, + MetadataVersion::V5, reader, block_offset, )?; diff --git a/src/io/ipc/read/deserialize.rs b/src/io/ipc/read/deserialize.rs index f244024bb0a..d45c7453342 100644 --- a/src/io/ipc/read/deserialize.rs +++ b/src/io/ipc/read/deserialize.rs @@ -9,6 +9,8 @@ use std::{ sync::Arc, }; +use gen::Schema::MetadataVersion; + use crate::datatypes::{DataType, IntervalUnit}; use crate::error::Result; use crate::io::ipc::gen::Message::BodyCompression; @@ -27,6 +29,7 @@ pub fn read( block_offset: u64, is_little_endian: bool, compression: Option, + version: MetadataVersion, ) -> Result> { match data_type { DataType::Null => { @@ -229,6 +232,7 @@ pub fn read( block_offset, is_little_endian, compression, + version, ) .map(|x| Arc::new(x) as Arc), DataType::LargeList(_) => read_list::( @@ -239,6 +243,7 @@ pub fn read( block_offset, is_little_endian, compression, + version, ) .map(|x| Arc::new(x) as Arc), DataType::FixedSizeList(_, _) => read_fixed_size_list( @@ -249,6 +254,7 @@ pub fn read( block_offset, is_little_endian, compression, + version, ) .map(|x| Arc::new(x) as Arc), DataType::Struct(_) => read_struct( @@ -259,6 +265,7 @@ pub fn read( block_offset, is_little_endian, compression, + version, ) .map(|x| Arc::new(x) as Arc), DataType::Dictionary(ref key_type, _) => match key_type.as_ref() { @@ -328,7 +335,17 @@ pub fn read( .map(|x| Arc::new(x) as Arc), _ => unreachable!(), }, - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => read_union( + field_nodes, + data_type, + buffers, + reader, + block_offset, 
+ is_little_endian, + compression, + version, + ) + .map(|x| Arc::new(x) as Arc), } } @@ -367,6 +384,6 @@ pub fn skip( DataType::FixedSizeList(_, _) => skip_fixed_size_list(field_nodes, data_type, buffers), DataType::Struct(_) => skip_struct(field_nodes, data_type, buffers), DataType::Dictionary(_, _) => skip_dictionary(field_nodes, buffers), - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => skip_union(field_nodes, data_type, buffers), } } diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs index e70498b26f3..e4656c1caff 100644 --- a/src/io/ipc/read/reader.rs +++ b/src/io/ipc/read/reader.rs @@ -217,6 +217,7 @@ pub fn read_batch( projection, metadata.is_little_endian, &metadata.dictionaries_by_field, + metadata.version, reader, block.offset() as u64 + block.metaDataLength() as u64, ) @@ -415,6 +416,16 @@ mod tests { test_file("1.0.0-bigendian", "generated_interval") } + #[test] + fn read_generated_100_union() -> Result<()> { + test_file("1.0.0-littleendian", "generated_union") + } + + #[test] + fn read_generated_017_union() -> Result<()> { + test_file("0.17.1", "generated_union") + } + #[test] fn read_generated_200_compression_lz4() -> Result<()> { test_file("2.0.0-compression", "generated_lz4") diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs index 01d10f54dfe..ddf432145dc 100644 --- a/src/io/ipc/read/stream.rs +++ b/src/io/ipc/read/stream.rs @@ -18,6 +18,8 @@ use std::io::Read; use std::sync::Arc; +use gen::Schema::MetadataVersion; + use crate::array::*; use crate::datatypes::Schema; use crate::error::{ArrowError, Result}; @@ -34,6 +36,8 @@ pub struct StreamMetadata { /// The schema that is read from the stream's first message schema: Arc, + version: MetadataVersion, + /// Whether the incoming stream is little-endian is_little_endian: bool, } @@ -57,6 +61,7 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { let message = gen::Message::root_as_message(meta_buffer.as_slice()) .map_err(|err| ArrowError::Ipc(format!("Unable to get root as message: {:?}", err)))?; + let version = message.version(); // message header is a Schema, so read it let ipc_schema: gen::Schema::Schema = message .header_as_schema() @@ -66,6 +71,7 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { Ok(StreamMetadata { schema, + version, is_little_endian, }) } @@ -134,6 +140,7 @@ pub fn read_next( None, metadata.is_little_endian, dictionaries_by_field, + metadata.version, &mut reader, 0, ) @@ -324,4 +331,9 @@ mod tests { fn read_generated_200_compression_zstd() -> Result<()> { test_file("2.0.0-compression", "generated_zstd") } + + #[test] + fn read_generated_017_union() -> Result<()> { + test_file("0.17.1", "generated_union") + } } diff --git a/src/io/ipc/write/serialize.rs b/src/io/ipc/write/serialize.rs index 60f529395d0..0fa669452cf 100644 --- a/src/io/ipc/write/serialize.rs +++ b/src/io/ipc/write/serialize.rs @@ -16,10 +16,7 @@ // under the License. 
use crate::{ - array::{ - Array, BinaryArray, BooleanArray, DictionaryArray, DictionaryKey, FixedSizeBinaryArray, - FixedSizeListArray, ListArray, Offset, PrimitiveArray, StructArray, Utf8Array, - }, + array::*, bitmap::Bitmap, datatypes::{DataType, IntervalUnit}, endianess::is_native_little_endian, @@ -236,6 +233,33 @@ pub fn write_struct( }); } +pub fn write_union( + array: &dyn Array, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, +) { + let array = array.as_any().downcast_ref::().unwrap(); + + write_buffer(array.types(), buffers, arrow_data, offset, is_little_endian); + + if let Some(offsets) = array.offsets() { + write_buffer(offsets, buffers, arrow_data, offset, is_little_endian); + } + array.fields().iter().for_each(|array| { + write( + array.as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + ) + }); +} + fn write_fixed_size_list( array: &dyn Array, buffers: &mut Vec, @@ -467,7 +491,9 @@ pub fn write( true, ); } - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, _) => { + write_union(array, buffers, arrow_data, nodes, offset, is_little_endian); + } } } diff --git a/src/io/ipc/write/writer.rs b/src/io/ipc/write/writer.rs index 1f108dd94d8..e642648914d 100644 --- a/src/io/ipc/write/writer.rs +++ b/src/io/ipc/write/writer.rs @@ -330,6 +330,17 @@ mod tests { test_file("1.0.0-bigendian", "generated_decimal") } + #[test] + fn write_100_union() -> Result<()> { + test_file("1.0.0-littleendian", "generated_union")?; + test_file("1.0.0-bigendian", "generated_union") + } + + #[test] + fn write_generated_017_union() -> Result<()> { + test_file("0.17.1", "generated_union") + } + #[test] fn write_sliced_utf8() -> Result<()> { use crate::array::{Array, Utf8Array}; diff --git a/src/io/json/read/deserialize.rs b/src/io/json/read/deserialize.rs index 8a02c95f7e9..18c7499861b 100644 --- a/src/io/json/read/deserialize.rs +++ b/src/io/json/read/deserialize.rs @@ -255,7 +255,6 @@ pub fn read(rows: &[&Value], data_type: DataType) -> Arc { /* DataType::FixedSizeBinary(_) => Box::new(FixedSizeBinaryArray::new_empty(data_type)), DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_empty(data_type)), - DataType::Union(_) => unimplemented!(), DataType::Decimal(_, _) => Box::new(PrimitiveArray::::new_empty(data_type)), */ } diff --git a/src/io/json_integration/mod.rs b/src/io/json_integration/mod.rs index be4880fefa8..29a015ce1b5 100644 --- a/src/io/json_integration/mod.rs +++ b/src/io/json_integration/mod.rs @@ -138,5 +138,7 @@ pub struct ArrowJsonColumn { pub data: Option>, #[serde(rename = "OFFSET")] pub offset: Option>, // leaving as Value as 64-bit offsets are strings + #[serde(rename = "TYPE_ID")] + pub type_id: Option>, // for union types pub children: Option>, } diff --git a/src/io/json_integration/read.rs b/src/io/json_integration/read.rs index 6063335a95b..0d0e4c6f968 100644 --- a/src/io/json_integration/read.rs +++ b/src/io/json_integration/read.rs @@ -22,7 +22,7 @@ use serde_json::Value; use crate::{ array::*, - bitmap::Bitmap, + bitmap::{Bitmap, MutableBitmap}, buffer::Buffer, datatypes::{DataType, Field, IntervalUnit, Schema}, error::{ArrowError, Result}, @@ -33,9 +33,12 @@ use crate::{ use super::{ArrowJsonBatch, ArrowJsonColumn, ArrowJsonDictionaryBatch}; fn to_validity(validity: &Option>) -> Option { - validity - .as_ref() - .map(|x| x.iter().map(|is_valid| *is_valid == 1).collect::()) + validity.as_ref().and_then(|x| { + x.iter() + .map(|is_valid| *is_valid == 1) + .collect::() + 
.into() + }) } fn to_offsets(offsets: Option<&Vec>) -> Buffer { @@ -174,10 +177,7 @@ fn to_list( data_type: DataType, dictionaries: &HashMap, ) -> Result> { - let validity = json_col - .validity - .as_ref() - .map(|x| x.iter().map(|is_valid| *is_valid == 1).collect::()); + let validity = to_validity(&json_col.validity); let child_field = ListArray::::get_child_field(&data_type); let children = &json_col.children.as_ref().unwrap()[0]; @@ -222,21 +222,15 @@ pub fn to_array( match data_type { DataType::Null => Ok(Arc::new(NullArray::from_data(json_col.count))), DataType::Boolean => { - let array = json_col - .validity + let validity = to_validity(&json_col.validity); + let values = json_col + .data .as_ref() .unwrap() .iter() - .zip(json_col.data.as_ref().unwrap()) - .map(|(is_valid, value)| { - if *is_valid == 1 { - Some(value.as_bool().unwrap()) - } else { - None - } - }) - .collect::(); - Ok(Arc::new(array)) + .map(|value| value.as_bool().unwrap()) + .collect::(); + Ok(Arc::new(BooleanArray::from_data(values, validity))) } DataType::Int8 => Ok(Arc::new(to_primitive::(json_col, data_type.clone()))), DataType::Int16 => Ok(Arc::new(to_primitive::(json_col, data_type.clone()))), @@ -322,9 +316,51 @@ pub fn to_array( _ => unreachable!(), }, DataType::Float16 => unreachable!(), - DataType::Union(_) => Err(ArrowError::NotYetImplemented( - "Union not supported".to_string(), - )), + DataType::Union(fields, _, _) => { + let fields = fields + .iter() + .zip(json_col.children.as_ref().unwrap()) + .map(|(field, col)| to_array(field, col, dictionaries)) + .collect::>>()?; + + let types = json_col + .type_id + .as_ref() + .map(|x| { + x.iter() + .map(|value| match value { + Value::Number(x) => { + x.as_i64().and_then(num::cast::cast::).unwrap() + } + Value::String(x) => x.parse::().ok().unwrap(), + _ => { + panic!() + } + }) + .collect() + }) + .unwrap_or_default(); + + let offsets = json_col + .offset + .as_ref() + .map(|x| { + Some( + x.iter() + .map(|value| match value { + Value::Number(x) => { + x.as_i64().and_then(num::cast::cast::).unwrap() + } + _ => panic!(), + }) + .collect(), + ) + }) + .unwrap_or_default(); + + let array = UnionArray::from_data(data_type.clone(), types, fields, offsets); + Ok(Arc::new(array)) + } } } diff --git a/src/io/json_integration/schema.rs b/src/io/json_integration/schema.rs index 5ed21db1085..b6c85bc22df 100644 --- a/src/io/json_integration/schema.rs +++ b/src/io/json_integration/schema.rs @@ -56,7 +56,7 @@ impl ToJson for DataType { json!({"name": "fixedsizebinary", "byteWidth": byte_width}) } DataType::Struct(_) => json!({"name": "struct"}), - DataType::Union(_) => json!({"name": "union"}), + DataType::Union(_, _, _) => json!({"name": "union"}), DataType::List(_) => json!({ "name": "list"}), DataType::LargeList(_) => json!({ "name": "largelist"}), DataType::FixedSizeList(_, length) => { @@ -333,6 +333,19 @@ fn to_data_type(item: &Value, mut children: Vec) -> Result { } } "struct" => DataType::Struct(children), + "union" => { + let is_sparse = if let Some(Value::String(mode)) = item.get("mode") { + mode == "SPARSE" + } else { + return Err(ArrowError::Schema("union requires mode".to_string())); + }; + let ids = if let Some(Value::Array(ids)) = item.get("typeIds") { + Some(ids.iter().map(|x| x.as_i64().unwrap() as i32).collect()) + } else { + return Err(ArrowError::Schema("union requires ids".to_string())); + }; + DataType::Union(children, ids, is_sparse) + } other => { return Err(ArrowError::Schema(format!( "invalid json value type \"{}\"", diff --git 
a/src/io/json_integration/write.rs b/src/io/json_integration/write.rs index c99ec404326..fcb1ebffec3 100644 --- a/src/io/json_integration/write.rs +++ b/src/io/json_integration/write.rs @@ -25,6 +25,7 @@ pub fn from_record_batch(batch: &RecordBatch) -> ArrowJsonBatch { validity: Some(validity), data: Some(data), offset: None, + type_id: None, children: None, } } @@ -34,6 +35,7 @@ pub fn from_record_batch(batch: &RecordBatch) -> ArrowJsonBatch { validity: None, data: None, offset: None, + type_id: None, children: None, }, }; diff --git a/src/io/print.rs b/src/io/print.rs index 6a7b2c533de..be65b955980 100644 --- a/src/io/print.rs +++ b/src/io/print.rs @@ -67,7 +67,7 @@ fn create_table(results: &[RecordBatch]) -> Table { #[cfg(test)] mod tests { - use crate::{array::*, bitmap::Bitmap, datatypes::*, error::Result}; + use crate::{array::*, bitmap::Bitmap, buffer::Buffer, datatypes::*, error::Result}; use super::*; use std::sync::Arc; @@ -426,4 +426,36 @@ mod tests { Ok(()) } + + #[test] + fn test_write_union() -> Result<()> { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, true); + let types = Buffer::from(&[0, 0, 1]); + let fields = vec![ + Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc, + Arc::new(Utf8Array::::from(&[Some("a"), Some("b"), Some("c")])) as Arc, + ]; + + let array = UnionArray::from_data(data_type, types, fields, None); + + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?; + + let table = write(&[batch]); + + let expected = vec![ + "+---+", "| a |", "+---+", "| 1 |", "| |", "| c |", "+---+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{}", table); + + Ok(()) + } } diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs index 9ec6417ad66..149703d15c2 100644 --- a/src/scalar/mod.rs +++ b/src/scalar/mod.rs @@ -130,7 +130,7 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { } FixedSizeBinary(_) => todo!(), FixedSizeList(_, _) => todo!(), - Union(_) => todo!(), + Union(_, _, _) => todo!(), Dictionary(_, _) => todo!(), } }
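Finally, a sketch of exercising the new IPC path end to end, assuming the crate's file IPC API at this time (`write::FileWriter`, `read::read_file_metadata`, `read::FileReader`); the exact signatures may differ from what is shown and should be checked against the crate:

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow2::array::{Array, Int32Array, UnionArray, Utf8Array};
use arrow2::buffer::Buffer;
use arrow2::datatypes::{DataType, Field, Schema};
use arrow2::error::Result;
use arrow2::io::ipc::{read, write};
use arrow2::record_batch::RecordBatch;

fn union_ipc_round_trip() -> Result<()> {
    // same sparse union as the `print` test above
    let fields = vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ];
    let data_type = DataType::Union(fields, None, true);
    let types = Buffer::<i8>::from(&[0, 0, 1]);
    let children = vec![
        Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
        Arc::new(Utf8Array::<i32>::from(&[Some("a"), Some("b"), Some("c")])) as Arc<dyn Array>,
    ];
    let array = UnionArray::from_data(data_type, types, children, None);

    let schema = Schema::new(vec![Field::new("u", array.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?;

    // write an IPC file into memory
    let mut file = Cursor::new(vec![]);
    {
        let mut writer = write::FileWriter::try_new(&mut file, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // read it back and compare the union column
    file.set_position(0);
    let metadata = read::read_file_metadata(&mut file)?;
    let mut reader = read::FileReader::new(&mut file, metadata, None);
    let read_batch = reader.next().unwrap()?;
    assert_eq!(batch.columns()[0].as_ref(), read_batch.columns()[0].as_ref());
    Ok(())
}
```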