Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Introduce UnionMode enum
Browse files Browse the repository at this point in the history
I feel like seeing `UnionMode::Sparse` is easier than `true`
  • Loading branch information
simonvandel committed Oct 30, 2021
1 parent 5fc843d commit bbb7c9c
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 43 deletions.
22 changes: 10 additions & 12 deletions src/array/union/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
array::{display::get_value_display, display_fmt, new_empty_array, new_null_array, Array},
bitmap::Bitmap,
buffer::Buffer,
datatypes::{DataType, Field},
datatypes::{DataType, Field, UnionMode},
scalar::{new_scalar, Scalar},
};

Expand Down Expand Up @@ -37,13 +37,13 @@ pub struct UnionArray {
impl UnionArray {
/// Creates a new null [`UnionArray`].
pub fn new_null(data_type: DataType, length: usize) -> Self {
if let DataType::Union(f, _, is_sparse) = &data_type {
if let DataType::Union(f, _, mode) = &data_type {
let fields = f
.iter()
.map(|x| new_null_array(x.data_type().clone(), length).into())
.collect();

let offsets = if *is_sparse {
let offsets = if mode.is_sparse() {
None
} else {
Some((0..length as i32).collect::<Buffer<i32>>())
Expand All @@ -60,13 +60,13 @@ impl UnionArray {

/// Creates a new empty [`UnionArray`].
pub fn new_empty(data_type: DataType) -> Self {
if let DataType::Union(f, _, is_sparse) = &data_type {
if let DataType::Union(f, _, mode) = &data_type {
let fields = f
.iter()
.map(|x| new_empty_array(x.data_type().clone()).into())
.collect();

let offsets = if *is_sparse {
let offsets = if mode.is_sparse() {
None
} else {
Some(Buffer::new())
Expand All @@ -92,7 +92,7 @@ impl UnionArray {
fields: Vec<Arc<dyn Array>>,
offsets: Option<Buffer<i32>>,
) -> Self {
let (f, ids, is_sparse) = Self::get_all(&data_type);
let (f, ids, mode) = Self::get_all(&data_type);

if f.len() != fields.len() {
panic!("The number of `fields` must equal the number of fields in the Union DataType")
Expand All @@ -104,7 +104,7 @@ impl UnionArray {
if !same_data_types {
panic!("All fields' datatype in the union must equal the datatypes on the fields.")
}
if offsets.is_none() != is_sparse {
if offsets.is_none() != mode.is_sparse() {
panic!("Sparsness flag must equal to noness of offsets in UnionArray")
}
let fields_hash = ids.as_ref().map(|ids| {
Expand Down Expand Up @@ -244,11 +244,9 @@ impl Array for UnionArray {
}

impl UnionArray {
fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, bool) {
fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, UnionMode) {
match data_type.to_logical_type() {
DataType::Union(fields, ids, is_sparse) => {
(fields, ids.as_ref().map(|x| x.as_ref()), *is_sparse)
}
DataType::Union(fields, ids, mode) => (fields, ids.as_ref().map(|x| x.as_ref()), *mode),
_ => panic!("Wrong datatype passed to UnionArray."),
}
}
Expand All @@ -264,7 +262,7 @@ impl UnionArray {
/// # Panic
/// Panics iff `data_type`'s logical type is not [`DataType::Union`].
pub fn is_sparse(data_type: &DataType) -> bool {
Self::get_all(data_type).2
Self::get_all(data_type).2.is_sparse()
}
}

Expand Down
35 changes: 33 additions & 2 deletions src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ pub enum DataType {
/// A nested datatype that contains a number of sub-fields.
Struct(Vec<Field>),
/// A nested datatype that can represent slots of differing types.
/// Third argument represents sparsness
Union(Vec<Field>, Option<Vec<i32>>, bool),
/// Third argument represents mode
Union(Vec<Field>, Option<Vec<i32>>, UnionMode),
/// A nested type that is represented as
///
/// List<entries: Struct<key: K, value: V>>
Expand Down Expand Up @@ -144,6 +144,37 @@ impl std::fmt::Display for DataType {
}
}

/// Mode of [`DataType::Union`]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionMode {
/// Dense union
Dense,
/// Sparse union
Sparse,
}

impl UnionMode {
/// Constructs a [`UnionMode::Sparse`] if the input bool is true,
/// or otherwise constructs a [`UnionMode::Dense`]
pub fn sparse(is_sparse: bool) -> Self {
if is_sparse {
Self::Sparse
} else {
Self::Dense
}
}

/// Returns whether the mode is sparse
pub fn is_sparse(&self) -> bool {
matches!(self, Self::Sparse)
}

/// Returns whether the mode is dense
pub fn is_dense(&self) -> bool {
matches!(self, Self::Dense)
}
}

/// The time units defined in Arrow.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimeUnit {
Expand Down
10 changes: 5 additions & 5 deletions src/ffi/schema.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::BTreeMap, convert::TryInto, ffi::CStr, ffi::CString, ptr};

use crate::{
datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit},
datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit, UnionMode},
error::{ArrowError, Result},
};

Expand Down Expand Up @@ -314,7 +314,7 @@ unsafe fn to_data_type(schema: &Ffi_ArrowSchema) -> Result<DataType> {
DataType::Decimal(precision, scale)
} else if !parts.is_empty() && ((parts[0] == "+us") || (parts[0] == "+ud")) {
// union
let is_sparse = parts[0] == "+us";
let mode = UnionMode::sparse(parts[0] == "+us");
let type_ids = parts[1]
.split(',')
.map(|x| {
Expand All @@ -326,7 +326,7 @@ unsafe fn to_data_type(schema: &Ffi_ArrowSchema) -> Result<DataType> {
let fields = (0..schema.n_children as usize)
.map(|x| to_field(schema.child(x)))
.collect::<Result<Vec<_>>>()?;
DataType::Union(fields, Some(type_ids), is_sparse)
DataType::Union(fields, Some(type_ids), mode)
} else {
return Err(ArrowError::Ffi(format!(
"The datatype \"{}\" is still not supported in Rust implementation",
Expand Down Expand Up @@ -397,8 +397,8 @@ fn to_format(data_type: &DataType) -> String {
DataType::Struct(_) => "+s".to_string(),
DataType::FixedSizeBinary(size) => format!("w{}", size),
DataType::FixedSizeList(_, size) => format!("+w:{}", size),
DataType::Union(f, ids, is_sparse) => {
let sparsness = if *is_sparse { 's' } else { 'd' };
DataType::Union(f, ids, mode) => {
let sparsness = if mode.is_sparse() { 's' } else { 'd' };
let mut r = format!("+u{}:", sparsness);
let ids = if let Some(ids) = ids {
ids.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/io/avro/read/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ fn schema_to_field(
.iter()
.map(|s| schema_to_field(s, None, has_nullable, None))
.collect::<Result<Vec<Field>>>()?;
DataType::Union(fields, None, false)
DataType::Union(fields, None, UnionMode::Dense)
}
}
AvroSchema::Record { name, fields, .. } => {
Expand Down
10 changes: 5 additions & 5 deletions src/io/ipc/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod ipc {
}

use crate::datatypes::{
get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit,
get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode,
};
use crate::io::ipc::endianess::is_native_little_endian;

Expand Down Expand Up @@ -292,7 +292,7 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo
ipc::Type::Union => {
let type_ = field.type_as_union().unwrap();

let is_sparse = type_.mode() == ipc::UnionMode::Sparse;
let mode = UnionMode::sparse(type_.mode() == ipc::UnionMode::Sparse);

let ids = type_.typeIds().map(|x| x.iter().collect());

Expand All @@ -303,7 +303,7 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo
} else {
vec![]
};
DataType::Union(fields, ids, is_sparse)
DataType::Union(fields, ids, mode)
}
ipc::Type::Map => {
let map = field.type_as_map().unwrap();
Expand Down Expand Up @@ -704,13 +704,13 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(fbb.create_vector(&empty_fields[..])),
}
}
Union(fields, ids, is_sparse) => {
Union(fields, ids, mode) => {
let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect();

let ids = ids.as_ref().map(|ids| fbb.create_vector(ids));

let mut builder = ipc::UnionBuilder::new(fbb);
builder.add_mode(if *is_sparse {
builder.add_mode(if mode.is_sparse() {
ipc::UnionMode::Sparse
} else {
ipc::UnionMode::Dense
Expand Down
11 changes: 5 additions & 6 deletions src/io/ipc/read/array/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use arrow_format::ipc;

use crate::array::UnionArray;
use crate::datatypes::DataType;
use crate::datatypes::UnionMode::Dense;
use crate::error::Result;

use super::super::deserialize::{read, skip, Node};
Expand Down Expand Up @@ -36,8 +37,8 @@ pub fn read_union<R: Read + Seek>(
compression,
)?;

let offsets = if let DataType::Union(_, _, is_sparse) = data_type {
if !is_sparse {
let offsets = if let DataType::Union(_, _, mode) = data_type {
if !mode.is_sparse() {
Some(read_buffer(
buffers,
field_node.length() as usize,
Expand Down Expand Up @@ -82,10 +83,8 @@ pub fn skip_union(
let _ = field_nodes.pop_front().unwrap();

let _ = buffers.pop_front().unwrap();
if let DataType::Union(_, _, is_sparse) = data_type {
if !*is_sparse {
let _ = buffers.pop_front().unwrap();
}
if let DataType::Union(_, _, Dense) = data_type {
let _ = buffers.pop_front().unwrap();
} else {
panic!()
};
Expand Down
11 changes: 7 additions & 4 deletions src/io/json_integration/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use std::{
use serde_derive::Deserialize;
use serde_json::{json, Value};

use crate::error::{ArrowError, Result};
use crate::{
datatypes::UnionMode,
error::{ArrowError, Result},
};

use crate::datatypes::{get_extension, DataType, Field, IntervalUnit, Schema, TimeUnit};

Expand Down Expand Up @@ -395,8 +398,8 @@ fn to_data_type(item: &Value, mut children: Vec<Field>) -> Result<DataType> {
}
"struct" => DataType::Struct(children),
"union" => {
let is_sparse = if let Some(Value::String(mode)) = item.get("mode") {
mode == "SPARSE"
let mode = if let Some(Value::String(mode)) = item.get("mode") {
UnionMode::sparse(mode == "SPARSE")
} else {
return Err(ArrowError::Schema("union requires mode".to_string()));
};
Expand All @@ -405,7 +408,7 @@ fn to_data_type(item: &Value, mut children: Vec<Field>) -> Result<DataType> {
} else {
return Err(ArrowError::Schema("union requires ids".to_string()));
};
DataType::Union(children, ids, is_sparse)
DataType::Union(children, ids, mode)
}
"map" => {
let sorted_keys = if let Some(Value::Bool(sorted_keys)) = item.get("keysSorted") {
Expand Down
26 changes: 21 additions & 5 deletions tests/it/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod utf8;

use arrow2::array::{clone, new_empty_array, new_null_array, Array, PrimitiveArray};
use arrow2::bitmap::Bitmap;
use arrow2::datatypes::{DataType, Field};
use arrow2::datatypes::{DataType, Field, UnionMode};

#[test]
fn nulls() {
Expand All @@ -31,8 +31,16 @@ fn nulls() {

// unions' null count is always 0
let datatypes = vec![
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, false),
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, true),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Dense,
),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Sparse,
),
];
let a = datatypes
.into_iter()
Expand All @@ -48,8 +56,16 @@ fn empty() {
DataType::Utf8,
DataType::Binary,
DataType::List(Box::new(Field::new("a", DataType::Binary, true))),
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, true),
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, false),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Sparse,
),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Dense,
),
];
let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0);
assert!(a);
Expand Down
4 changes: 2 additions & 2 deletions tests/it/array/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn display() -> Result<()> {
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let data_type = DataType::Union(fields, None, UnionMode::Sparse);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Expand All @@ -28,7 +28,7 @@ fn slice() -> Result<()> {
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let data_type = DataType::Union(fields, None, UnionMode::Sparse);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Expand Down
2 changes: 1 addition & 1 deletion tests/it/io/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ fn write_union() -> Result<()> {
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let data_type = DataType::Union(fields, None, UnionMode::Sparse);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Expand Down

0 comments on commit bbb7c9c

Please sign in to comment.