From cdcf1506dd2d2a0ebaa266cda3639ee57a081c92 Mon Sep 17 00:00:00 2001 From: Simon Vandel Sillesen Date: Sat, 30 Oct 2021 22:53:02 +0200 Subject: [PATCH] Introduced `UnionMode` enum (#557) --- src/array/union/mod.rs | 22 +++++++++---------- src/datatypes/mod.rs | 35 +++++++++++++++++++++++++++++-- src/ffi/schema.rs | 10 ++++----- src/io/avro/read/schema.rs | 2 +- src/io/ipc/convert.rs | 10 ++++----- src/io/ipc/read/array/union.rs | 11 +++++----- src/io/json_integration/schema.rs | 11 ++++++---- tests/it/array/mod.rs | 26 ++++++++++++++++++----- tests/it/array/union.rs | 4 ++-- tests/it/io/print.rs | 2 +- 10 files changed, 90 insertions(+), 43 deletions(-) diff --git a/src/array/union/mod.rs b/src/array/union/mod.rs index b5a6a1d2605..77574b1706f 100644 --- a/src/array/union/mod.rs +++ b/src/array/union/mod.rs @@ -4,7 +4,7 @@ use crate::{ array::{display::get_value_display, display_fmt, new_empty_array, new_null_array, Array}, bitmap::Bitmap, buffer::Buffer, - datatypes::{DataType, Field}, + datatypes::{DataType, Field, UnionMode}, scalar::{new_scalar, Scalar}, }; @@ -37,13 +37,13 @@ pub struct UnionArray { impl UnionArray { /// Creates a new null [`UnionArray`]. pub fn new_null(data_type: DataType, length: usize) -> Self { - if let DataType::Union(f, _, is_sparse) = &data_type { + if let DataType::Union(f, _, mode) = &data_type { let fields = f .iter() .map(|x| new_null_array(x.data_type().clone(), length).into()) .collect(); - let offsets = if *is_sparse { + let offsets = if mode.is_sparse() { None } else { Some((0..length as i32).collect::>()) @@ -60,13 +60,13 @@ impl UnionArray { /// Creates a new empty [`UnionArray`]. pub fn new_empty(data_type: DataType) -> Self { - if let DataType::Union(f, _, is_sparse) = &data_type { + if let DataType::Union(f, _, mode) = &data_type { let fields = f .iter() .map(|x| new_empty_array(x.data_type().clone()).into()) .collect(); - let offsets = if *is_sparse { + let offsets = if mode.is_sparse() { None } else { Some(Buffer::new()) @@ -92,7 +92,7 @@ impl UnionArray { fields: Vec>, offsets: Option>, ) -> Self { - let (f, ids, is_sparse) = Self::get_all(&data_type); + let (f, ids, mode) = Self::get_all(&data_type); if f.len() != fields.len() { panic!("The number of `fields` must equal the number of fields in the Union DataType") @@ -104,7 +104,7 @@ impl UnionArray { if !same_data_types { panic!("All fields' datatype in the union must equal the datatypes on the fields.") } - if offsets.is_none() != is_sparse { + if offsets.is_none() != mode.is_sparse() { panic!("Sparsness flag must equal to noness of offsets in UnionArray") } let fields_hash = ids.as_ref().map(|ids| { @@ -244,11 +244,9 @@ impl Array for UnionArray { } impl UnionArray { - fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, bool) { + fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, UnionMode) { match data_type.to_logical_type() { - DataType::Union(fields, ids, is_sparse) => { - (fields, ids.as_ref().map(|x| x.as_ref()), *is_sparse) - } + DataType::Union(fields, ids, mode) => (fields, ids.as_ref().map(|x| x.as_ref()), *mode), _ => panic!("Wrong datatype passed to UnionArray."), } } @@ -264,7 +262,7 @@ impl UnionArray { /// # Panic /// Panics iff `data_type`'s logical type is not [`DataType::Union`]. pub fn is_sparse(data_type: &DataType) -> bool { - Self::get_all(data_type).2 + Self::get_all(data_type).2.is_sparse() } } diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index 9386f9c7c25..81ab5f71b96 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -90,8 +90,8 @@ pub enum DataType { /// A nested datatype that contains a number of sub-fields. Struct(Vec), /// A nested datatype that can represent slots of differing types. - /// Third argument represents sparsness - Union(Vec, Option>, bool), + /// Third argument represents mode + Union(Vec, Option>, UnionMode), /// A nested type that is represented as /// /// List> @@ -144,6 +144,37 @@ impl std::fmt::Display for DataType { } } +/// Mode of [`DataType::Union`] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum UnionMode { + /// Dense union + Dense, + /// Sparse union + Sparse, +} + +impl UnionMode { + /// Constructs a [`UnionMode::Sparse`] if the input bool is true, + /// or otherwise constructs a [`UnionMode::Dense`] + pub fn sparse(is_sparse: bool) -> Self { + if is_sparse { + Self::Sparse + } else { + Self::Dense + } + } + + /// Returns whether the mode is sparse + pub fn is_sparse(&self) -> bool { + matches!(self, Self::Sparse) + } + + /// Returns whether the mode is dense + pub fn is_dense(&self) -> bool { + matches!(self, Self::Dense) + } +} + /// The time units defined in Arrow. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TimeUnit { diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs index 578d5a49a20..6f53d44586d 100644 --- a/src/ffi/schema.rs +++ b/src/ffi/schema.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, convert::TryInto, ffi::CStr, ffi::CString, ptr}; use crate::{ - datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit}, + datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit, UnionMode}, error::{ArrowError, Result}, }; @@ -314,7 +314,7 @@ unsafe fn to_data_type(schema: &Ffi_ArrowSchema) -> Result { DataType::Decimal(precision, scale) } else if !parts.is_empty() && ((parts[0] == "+us") || (parts[0] == "+ud")) { // union - let is_sparse = parts[0] == "+us"; + let mode = UnionMode::sparse(parts[0] == "+us"); let type_ids = parts[1] .split(',') .map(|x| { @@ -326,7 +326,7 @@ unsafe fn to_data_type(schema: &Ffi_ArrowSchema) -> Result { let fields = (0..schema.n_children as usize) .map(|x| to_field(schema.child(x))) .collect::>>()?; - DataType::Union(fields, Some(type_ids), is_sparse) + DataType::Union(fields, Some(type_ids), mode) } else { return Err(ArrowError::Ffi(format!( "The datatype \"{}\" is still not supported in Rust implementation", @@ -397,8 +397,8 @@ fn to_format(data_type: &DataType) -> String { DataType::Struct(_) => "+s".to_string(), DataType::FixedSizeBinary(size) => format!("w{}", size), DataType::FixedSizeList(_, size) => format!("+w:{}", size), - DataType::Union(f, ids, is_sparse) => { - let sparsness = if *is_sparse { 's' } else { 'd' }; + DataType::Union(f, ids, mode) => { + let sparsness = if mode.is_sparse() { 's' } else { 'd' }; let mut r = format!("+u{}:", sparsness); let ids = if let Some(ids) = ids { ids.iter() diff --git a/src/io/avro/read/schema.rs b/src/io/avro/read/schema.rs index ba7e2b494f1..7e0e8551c1a 100644 --- a/src/io/avro/read/schema.rs +++ b/src/io/avro/read/schema.rs @@ -142,7 +142,7 @@ fn schema_to_field( .iter() .map(|s| schema_to_field(s, None, has_nullable, None)) .collect::>>()?; - DataType::Union(fields, None, false) + DataType::Union(fields, None, UnionMode::Dense) } } AvroSchema::Record { name, fields, .. } => { diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs index 69e81f1d46a..ac78bdf7a67 100644 --- a/src/io/ipc/convert.rs +++ b/src/io/ipc/convert.rs @@ -28,7 +28,7 @@ mod ipc { } use crate::datatypes::{ - get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit, + get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode, }; use crate::io::ipc::endianess::is_native_little_endian; @@ -292,7 +292,7 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo ipc::Type::Union => { let type_ = field.type_as_union().unwrap(); - let is_sparse = type_.mode() == ipc::UnionMode::Sparse; + let mode = UnionMode::sparse(type_.mode() == ipc::UnionMode::Sparse); let ids = type_.typeIds().map(|x| x.iter().collect()); @@ -303,7 +303,7 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo } else { vec![] }; - DataType::Union(fields, ids, is_sparse) + DataType::Union(fields, ids, mode) } ipc::Type::Map => { let map = field.type_as_map().unwrap(); @@ -704,13 +704,13 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } - Union(fields, ids, is_sparse) => { + Union(fields, ids, mode) => { let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect(); let ids = ids.as_ref().map(|ids| fbb.create_vector(ids)); let mut builder = ipc::UnionBuilder::new(fbb); - builder.add_mode(if *is_sparse { + builder.add_mode(if mode.is_sparse() { ipc::UnionMode::Sparse } else { ipc::UnionMode::Dense diff --git a/src/io/ipc/read/array/union.rs b/src/io/ipc/read/array/union.rs index 6e70b4045d8..66cab751c42 100644 --- a/src/io/ipc/read/array/union.rs +++ b/src/io/ipc/read/array/union.rs @@ -5,6 +5,7 @@ use arrow_format::ipc; use crate::array::UnionArray; use crate::datatypes::DataType; +use crate::datatypes::UnionMode::Dense; use crate::error::Result; use super::super::deserialize::{read, skip, Node}; @@ -36,8 +37,8 @@ pub fn read_union( compression, )?; - let offsets = if let DataType::Union(_, _, is_sparse) = data_type { - if !is_sparse { + let offsets = if let DataType::Union(_, _, mode) = data_type { + if !mode.is_sparse() { Some(read_buffer( buffers, field_node.length() as usize, @@ -82,10 +83,8 @@ pub fn skip_union( let _ = field_nodes.pop_front().unwrap(); let _ = buffers.pop_front().unwrap(); - if let DataType::Union(_, _, is_sparse) = data_type { - if !*is_sparse { - let _ = buffers.pop_front().unwrap(); - } + if let DataType::Union(_, _, Dense) = data_type { + let _ = buffers.pop_front().unwrap(); } else { panic!() }; diff --git a/src/io/json_integration/schema.rs b/src/io/json_integration/schema.rs index b57b04ef9dc..b461408c3c1 100644 --- a/src/io/json_integration/schema.rs +++ b/src/io/json_integration/schema.rs @@ -23,7 +23,10 @@ use std::{ use serde_derive::Deserialize; use serde_json::{json, Value}; -use crate::error::{ArrowError, Result}; +use crate::{ + datatypes::UnionMode, + error::{ArrowError, Result}, +}; use crate::datatypes::{get_extension, DataType, Field, IntervalUnit, Schema, TimeUnit}; @@ -395,8 +398,8 @@ fn to_data_type(item: &Value, mut children: Vec) -> Result { } "struct" => DataType::Struct(children), "union" => { - let is_sparse = if let Some(Value::String(mode)) = item.get("mode") { - mode == "SPARSE" + let mode = if let Some(Value::String(mode)) = item.get("mode") { + UnionMode::sparse(mode == "SPARSE") } else { return Err(ArrowError::Schema("union requires mode".to_string())); }; @@ -405,7 +408,7 @@ fn to_data_type(item: &Value, mut children: Vec) -> Result { } else { return Err(ArrowError::Schema("union requires ids".to_string())); }; - DataType::Union(children, ids, is_sparse) + DataType::Union(children, ids, mode) } "map" => { let sorted_keys = if let Some(Value::Bool(sorted_keys)) = item.get("keysSorted") { diff --git a/tests/it/array/mod.rs b/tests/it/array/mod.rs index 6feb6b905c0..d9430594942 100644 --- a/tests/it/array/mod.rs +++ b/tests/it/array/mod.rs @@ -13,7 +13,7 @@ mod utf8; use arrow2::array::{clone, new_empty_array, new_null_array, Array, PrimitiveArray}; use arrow2::bitmap::Bitmap; -use arrow2::datatypes::{DataType, Field}; +use arrow2::datatypes::{DataType, Field, UnionMode}; #[test] fn nulls() { @@ -31,8 +31,16 @@ fn nulls() { // unions' null count is always 0 let datatypes = vec![ - DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, false), - DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, true), + DataType::Union( + vec![Field::new("a", DataType::Binary, true)], + None, + UnionMode::Dense, + ), + DataType::Union( + vec![Field::new("a", DataType::Binary, true)], + None, + UnionMode::Sparse, + ), ]; let a = datatypes .into_iter() @@ -48,8 +56,16 @@ fn empty() { DataType::Utf8, DataType::Binary, DataType::List(Box::new(Field::new("a", DataType::Binary, true))), - DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, true), - DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, false), + DataType::Union( + vec![Field::new("a", DataType::Binary, true)], + None, + UnionMode::Sparse, + ), + DataType::Union( + vec![Field::new("a", DataType::Binary, true)], + None, + UnionMode::Dense, + ), ]; let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0); assert!(a); diff --git a/tests/it/array/union.rs b/tests/it/array/union.rs index 7e0dfa73a12..9ed49a272b6 100644 --- a/tests/it/array/union.rs +++ b/tests/it/array/union.rs @@ -8,7 +8,7 @@ fn display() -> Result<()> { Field::new("a", DataType::Int32, true), Field::new("b", DataType::Utf8, true), ]; - let data_type = DataType::Union(fields, None, true); + let data_type = DataType::Union(fields, None, UnionMode::Sparse); let types = Buffer::from(&[0, 0, 1]); let fields = vec![ Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc, @@ -28,7 +28,7 @@ fn slice() -> Result<()> { Field::new("a", DataType::Int32, true), Field::new("b", DataType::Utf8, true), ]; - let data_type = DataType::Union(fields, None, true); + let data_type = DataType::Union(fields, None, UnionMode::Sparse); let types = Buffer::from(&[0, 0, 1]); let fields = vec![ Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc, diff --git a/tests/it/io/print.rs b/tests/it/io/print.rs index bce6750ee7f..b8fa3f6451b 100644 --- a/tests/it/io/print.rs +++ b/tests/it/io/print.rs @@ -391,7 +391,7 @@ fn write_union() -> Result<()> { Field::new("a", DataType::Int32, true), Field::new("b", DataType::Utf8, true), ]; - let data_type = DataType::Union(fields, None, true); + let data_type = DataType::Union(fields, None, UnionMode::Sparse); let types = Buffer::from(&[0, 0, 1]); let fields = vec![ Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc,