Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added Union.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Aug 12, 2021
1 parent 00200e4 commit d12961f
Show file tree
Hide file tree
Showing 31 changed files with 555 additions and 48 deletions.
2 changes: 2 additions & 0 deletions arrow-flight/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow2::{
datatypes::*,
error::{ArrowError, Result},
io::ipc,
io::ipc::gen::Schema::MetadataVersion,
io::ipc::read::read_record_batch,
io::ipc::write,
io::ipc::write::common::{encoded_batch, DictionaryTracker, EncodedData, IpcWriteOptions},
Expand Down Expand Up @@ -168,6 +169,7 @@ pub fn flight_data_to_arrow_batch(
None,
is_little_endian,
&dictionaries_by_field,
MetadataVersion::V5,
&mut reader,
0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow2::{
datatypes::*,
io::ipc,
io::ipc::gen::Message::{Message, MessageHeader},
io::ipc::gen::Schema::MetadataVersion,
record_batch::RecordBatch,
};
use arrow_flight::flight_descriptor::*;
Expand Down Expand Up @@ -295,6 +296,7 @@ async fn record_batch_from_message(
None,
true,
&dictionaries_by_field,
MetadataVersion::V5,
&mut reader,
0,
);
Expand Down
9 changes: 6 additions & 3 deletions src/array/equal/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::unimplemented;

use crate::{
datatypes::{DataType, IntervalUnit},
types::{days_ms, NativeType},
Expand All @@ -19,6 +17,7 @@ mod list;
mod null;
mod primitive;
mod struct_;
mod union;
mod utf8;

impl PartialEq for dyn Array {
Expand Down Expand Up @@ -323,7 +322,11 @@ pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool {
let rhs = rhs.as_any().downcast_ref().unwrap();
fixed_size_list::equal(lhs, rhs)
}
DataType::Union(_) => unimplemented!(),
DataType::Union(_, _, _) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
union::equal(lhs, rhs)
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions src/array/equal/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use crate::array::{Array, UnionArray};

/// Compares two [`UnionArray`]s for equality.
///
/// The arrays are equal when they share the same data type, have the same
/// length and yield pairwise-equal scalars when iterated slot by slot.
pub(super) fn equal(lhs: &UnionArray, rhs: &UnionArray) -> bool {
    // Cheap structural checks first; the slot-wise comparison only runs when
    // both of them pass.
    if lhs.data_type() != rhs.data_type() {
        return false;
    }
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter().eq(rhs.iter())
}
2 changes: 1 addition & 1 deletion src/array/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub fn buffers_children_dictionary(array: &dyn Array) -> BuffersChildren {
DataType::LargeList(_) => ffi_dyn!(array, ListArray::<i64>),
DataType::FixedSizeList(_, _) => ffi_dyn!(array, FixedSizeListArray),
DataType::Struct(_) => ffi_dyn!(array, StructArray),
DataType::Union(_) => unimplemented!(),
DataType::Union(_, _, _) => unimplemented!(),
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
DataType::Int8 => ffi_dict_dyn!(array, DictionaryArray::<i8>),
DataType::Int16 => ffi_dict_dyn!(array, DictionaryArray::<i16>),
Expand Down
2 changes: 1 addition & 1 deletion src/array/growable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ pub fn make_growable<'a>(
))
}
DataType::FixedSizeList(_, _) => todo!(),
DataType::Union(_) => todo!(),
DataType::Union(_, _, _) => todo!(),
DataType::Dictionary(key, _) => match key.as_ref() {
DataType::UInt8 => dyn_dict_growable!(u8, arrays, use_validity, capacity),
DataType::UInt16 => dyn_dict_growable!(u16, arrays, use_validity, capacity),
Expand Down
13 changes: 9 additions & 4 deletions src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ pub trait Array: std::fmt::Debug + Send + Sync {
/// This is `O(1)`.
#[inline]
fn null_count(&self) -> usize {
if self.data_type() == &DataType::Null {
return self.len();
};
self.validity()
.as_ref()
.map(|x| x.null_count())
Expand Down Expand Up @@ -185,7 +188,7 @@ impl Display for dyn Array {
DataType::LargeList(_) => fmt_dyn!(self, ListArray::<i64>, f),
DataType::FixedSizeList(_, _) => fmt_dyn!(self, FixedSizeListArray, f),
DataType::Struct(_) => fmt_dyn!(self, StructArray, f),
DataType::Union(_) => unimplemented!(),
DataType::Union(_, _, _) => unimplemented!(),
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
DataType::Int8 => fmt_dyn!(self, DictionaryArray::<i8>, f),
DataType::Int16 => fmt_dyn!(self, DictionaryArray::<i16>, f),
Expand Down Expand Up @@ -239,7 +242,7 @@ pub fn new_empty_array(data_type: DataType) -> Box<dyn Array> {
DataType::LargeList(_) => Box::new(ListArray::<i64>::new_empty(data_type)),
DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_empty(data_type)),
DataType::Struct(fields) => Box::new(StructArray::new_empty(&fields)),
DataType::Union(_) => unimplemented!(),
DataType::Union(_, _, _) => unimplemented!(),
DataType::Dictionary(key_type, value_type) => match key_type.as_ref() {
DataType::Int8 => Box::new(DictionaryArray::<i8>::new_empty(*value_type)),
DataType::Int16 => Box::new(DictionaryArray::<i16>::new_empty(*value_type)),
Expand Down Expand Up @@ -293,7 +296,7 @@ pub fn new_null_array(data_type: DataType, length: usize) -> Box<dyn Array> {
DataType::LargeList(_) => Box::new(ListArray::<i64>::new_null(data_type, length)),
DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new_null(data_type, length)),
DataType::Struct(fields) => Box::new(StructArray::new_null(&fields, length)),
DataType::Union(_) => unimplemented!(),
DataType::Union(_, _, _) => unimplemented!(),
DataType::Dictionary(key_type, value_type) => match key_type.as_ref() {
DataType::Int8 => Box::new(DictionaryArray::<i8>::new_null(*value_type, length)),
DataType::Int16 => Box::new(DictionaryArray::<i16>::new_null(*value_type, length)),
Expand Down Expand Up @@ -354,7 +357,7 @@ pub fn clone(array: &dyn Array) -> Box<dyn Array> {
DataType::LargeList(_) => clone_dyn!(array, ListArray::<i64>),
DataType::FixedSizeList(_, _) => clone_dyn!(array, FixedSizeListArray),
DataType::Struct(_) => clone_dyn!(array, StructArray),
DataType::Union(_) => unimplemented!(),
DataType::Union(_, _, _) => unimplemented!(),
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
DataType::Int8 => clone_dyn!(array, DictionaryArray::<i8>),
DataType::Int16 => clone_dyn!(array, DictionaryArray::<i16>),
Expand All @@ -380,6 +383,7 @@ mod null;
mod primitive;
mod specification;
mod struct_;
mod union;
mod utf8;

mod equal;
Expand All @@ -399,6 +403,7 @@ pub use null::NullArray;
pub use primitive::*;
pub use specification::Offset;
pub use struct_::StructArray;
pub use union::UnionArray;
pub use utf8::{MutableUtf8Array, Utf8Array, Utf8ValuesIter};

pub(crate) use self::ffi::buffers_children_dictionary;
Expand Down
55 changes: 55 additions & 0 deletions src/array/union/iterator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use super::{Array, UnionArray};
use crate::{scalar::Scalar, trusted_len::TrustedLen};

/// Iterator over the slots of a [`UnionArray`], yielding every slot as a
/// boxed [`Scalar`].
#[derive(Debug, Clone)]
pub struct UnionIter<'a> {
// The array whose slots are being iterated.
array: &'a UnionArray,
// Index of the next slot to yield; iteration ends once it reaches `array.len()`.
current: usize,
}

impl<'a> UnionIter<'a> {
pub fn new(array: &'a UnionArray) -> Self {
Self { array, current: 0 }
}
}

impl<'a> Iterator for UnionIter<'a> {
    type Item = Box<dyn Scalar>;

    /// Yields the scalar at the current slot and advances, or `None` once
    /// every slot of the array has been visited.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current == self.array.len() {
            return None;
        }
        // Advance the cursor before reading the slot it pointed at.
        let slot = self.current;
        self.current = slot + 1;
        Some(self.array.value(slot))
    }

    /// Reports the exact number of slots that are still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.array.len() - self.current;
        (remaining, Some(remaining))
    }
}

impl<'a> IntoIterator for &'a UnionArray {
type Item = Box<dyn Scalar>;
type IntoIter = UnionIter<'a>;

#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

impl<'a> UnionArray {
/// constructs a new iterator
#[inline]
pub fn iter(&'a self) -> UnionIter<'a> {
UnionIter::new(self)
}
}

// `size_hint` returns identical lower and upper bounds (the number of slots
// left), so the default `ExactSizeIterator::len` can be used as is.
impl<'a> std::iter::ExactSizeIterator for UnionIter<'a> {}

// SAFETY: `size_hint` returns `(n, Some(n))` with `n = array.len() - current`,
// which is exactly the number of items `next` still yields.
// NOTE(review): this assumes `TrustedLen` only requires an exact `size_hint`;
// confirm against the trait's documented contract.
unsafe impl<'a> TrustedLen for UnionIter<'a> {}
148 changes: 148 additions & 0 deletions src/array/union/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use std::{collections::HashMap, sync::Arc};

use crate::{
bitmap::Bitmap,
buffer::Buffer,
datatypes::{DataType, Field},
scalar::{new_scalar, Scalar},
};

use super::Array;

mod iterator;

/// A union array: every slot holds a value taken from one of several child
/// arrays, selected by the slot's type id.
// How to read the value at slot `i` (of an unsliced array):
// ```
// let type_id = self.types()[i];
// let field = &self.fields_hash[&type_id];
// let offset = self.offsets().as_ref().map(|x| x[i] as usize).unwrap_or(i);
// let field = field.as_any().downcast to correct type;
// let value = field.value(offset);
// ```
#[derive(Debug, Clone)]
pub struct UnionArray {
// Type id of every slot; each id is a key of `fields_hash`.
types: Buffer<i8>,
// Child arrays keyed by type id (built in `from_data` from the ids of the
// data type, or `0..fields.len()` when the data type declares none).
fields_hash: HashMap<i8, Arc<dyn Array>>,
// Child arrays, in the same order as the fields of the `DataType::Union`.
fields: Vec<Arc<dyn Array>>,
// `Some` for dense unions (per-slot index into the selected child array),
// `None` for sparse unions.
offsets: Option<Buffer<i32>>,
// Always a `DataType::Union`; enforced by `from_data`.
data_type: DataType,
// Number of leading slots removed by `slice`, accumulated across calls.
offset: usize,
}

impl UnionArray {
    /// Creates a new [`UnionArray`] from its parts.
    ///
    /// * `data_type` - must be a [`DataType::Union`];
    /// * `types` - the type id of every slot;
    /// * `fields` - the child arrays, in the order of the fields of `data_type`;
    /// * `offsets` - `Some` for dense unions, `None` for sparse unions.
    ///
    /// # Panics
    /// Panics iff any of the following holds:
    /// * `data_type` is not a [`DataType::Union`];
    /// * the number of `fields` differs from the number of fields in `data_type`;
    /// * the data type of a child array differs from its declared field;
    /// * `offsets.is_none()` differs from the sparseness flag of `data_type`.
    ///
    /// The contents of `types` and `offsets` are not validated.
    pub fn from_data(
        data_type: DataType,
        types: Buffer<i8>,
        fields: Vec<Arc<dyn Array>>,
        offsets: Option<Buffer<i32>>,
    ) -> Self {
        let fields_hash = if let DataType::Union(f, ids, is_sparse) = &data_type {
            // Type ids default to `0..number of fields` when none are declared.
            let ids: Vec<i8> = ids
                .as_ref()
                .map(|x| x.iter().map(|x| *x as i8).collect())
                .unwrap_or_else(|| (0..f.len() as i8).collect());
            if f.len() != fields.len() {
                panic!(
                    "The number of `fields` must equal the number of fields in the Union DataType"
                )
            };
            let same_data_types = f
                .iter()
                .zip(fields.iter())
                .all(|(f, array)| f.data_type() == array.data_type());
            if !same_data_types {
                panic!("All fields' datatype in the union must equal the datatypes on the fields.")
            }
            // Sparse unions carry no offsets; dense unions must carry them.
            if offsets.is_none() != *is_sparse {
                panic!("Sparsness flag must equal to noness of offsets in UnionArray")
            }
            ids.into_iter().zip(fields.iter().cloned()).collect()
        } else {
            panic!("Union struct must be created with the corresponding Union DataType")
        };
        // not validated:
        // * `offsets` is valid
        // * max id < fields.len()
        Self {
            data_type,
            fields_hash,
            fields,
            offsets,
            types,
            offset: 0,
        }
    }

    /// The per-slot offsets into the child arrays: `Some` for dense unions,
    /// `None` for sparse unions.
    pub fn offsets(&self) -> &Option<Buffer<i32>> {
        &self.offsets
    }

    /// The child arrays, in the order of the fields of the data type.
    pub fn fields(&self) -> &Vec<Arc<dyn Array>> {
        &self.fields
    }

    /// The type id of every slot.
    pub fn types(&self) -> &Buffer<i8> {
        &self.types
    }

    /// Returns the value at slot `index` as a boxed [`Scalar`].
    ///
    /// # Panics
    /// Panics if `index` is out of bounds or if the type id of the slot has no
    /// corresponding child array.
    pub fn value(&self, index: usize) -> Box<dyn Scalar> {
        let type_id = self.types()[index];
        let field = self.fields_hash[&type_id].as_ref();
        // Dense: `offsets` is sliced together with `types`, so `index`
        // addresses it directly and the stored value is the child slot.
        // Sparse: the children are never sliced, so the child slot is the
        // position in the original (unsliced) array.
        let slot = self
            .offsets()
            .as_ref()
            .map(|x| x[index] as usize)
            .unwrap_or(index + self.offset);
        new_scalar(field, slot)
    }

    /// Returns a slice of this [`UnionArray`].
    /// # Implementation
    /// This operation is `O(F)` where `F` is the number of fields.
    /// # Panic
    /// Bounds are checked by the underlying buffers' `slice`; the range
    /// `offset..offset + length` is expected to lie within `0..self.len()`.
    #[inline]
    pub fn slice(&self, offset: usize, length: usize) -> Self {
        Self {
            data_type: self.data_type.clone(),
            fields: self.fields.clone(),
            fields_hash: self.fields_hash.clone(),
            types: self.types.clone().slice(offset, length),
            // `types` and `offsets` are parallel buffers and must be sliced
            // together, otherwise slot `i` would read the offset of another slot.
            offsets: self.offsets.clone().map(|x| x.slice(offset, length)),
            offset: self.offset + offset,
        }
    }
}

impl Array for UnionArray {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn len(&self) -> usize {
self.types.len()
}

fn data_type(&self) -> &DataType {
&self.data_type
}

fn validity(&self) -> &Option<Bitmap> {
&None
}

fn slice(&self, offset: usize, length: usize) -> Box<dyn Array> {
Box::new(self.slice(offset, length))
}
}

impl UnionArray {
    /// Returns the fields declared by a [`DataType::Union`].
    ///
    /// # Panics
    /// Panics if `data_type` is not a [`DataType::Union`].
    pub fn get_fields(data_type: &DataType) -> &[Field] {
        match data_type {
            DataType::Union(fields, _, _) => fields.as_slice(),
            // The message names the type this accessor belongs to (it
            // previously said "Struct").
            _ => panic!("Wrong datatype passed to Union."),
        }
    }
}
2 changes: 1 addition & 1 deletion src/compute/aggregate/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize {
.sum::<usize>()
+ validity_size(array.validity())
}
Union(_) => unreachable!(),
Union(_, _, _) => unreachable!(),
Dictionary(keys, _) => match keys.as_ref() {
Int8 => dyn_dict!(array, i8),
Int16 => dyn_dict!(array, i16),
Expand Down
4 changes: 2 additions & 2 deletions src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ impl Field {
));
}
},
DataType::Union(nested_fields) => match &from.data_type {
DataType::Union(from_nested_fields) => {
DataType::Union(nested_fields, _, _) => match &from.data_type {
DataType::Union(from_nested_fields, _, _) => {
for from_field in from_nested_fields {
let mut is_new_field = true;
for self_field in nested_fields.iter_mut() {
Expand Down
3 changes: 2 additions & 1 deletion src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ pub enum DataType {
/// A nested datatype that contains a number of sub-fields.
Struct(Vec<Field>),
/// A nested datatype that can represent slots of differing types.
Union(Vec<Field>),
/// The second argument holds the optional type ids of the fields;
/// the third argument is `true` when the union is sparse.
Union(Vec<Field>, Option<Vec<i32>>, bool),
/// A dictionary encoded array (`key_type`, `value_type`), where
/// each array element is an index of `key_type` into an
/// associated dictionary of `value_type`.
Expand Down
2 changes: 1 addition & 1 deletion src/ffi/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ fn to_format(data_type: &DataType) -> Result<String> {
DataType::Struct(_) => "+s",
DataType::FixedSizeBinary(size) => return Ok(format!("w{}", size)),
DataType::FixedSizeList(_, size) => return Ok(format!("+w:{}", size)),
DataType::Union(_) => todo!(),
DataType::Union(_, _, _) => todo!(),
DataType::Dictionary(index, _) => return to_format(index.as_ref()),
_ => todo!(),
}
Expand Down
Loading

0 comments on commit d12961f

Please sign in to comment.