Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UnionFields (#3955) #3981

Merged
merged 4 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef,
DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef,
DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef,
DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
DataType::FixedSizeList(_, _) => {
Arc::new(FixedSizeListArray::from(data)) as ArrayRef
}
Expand Down Expand Up @@ -740,7 +740,7 @@ mod tests {
use crate::cast::{as_union_array, downcast_array};
use crate::downcast_run_array;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_schema::{Field, Fields, UnionMode};
use arrow_schema::{Field, Fields, UnionFields, UnionMode};

#[test]
fn test_empty_primitive() {
Expand Down Expand Up @@ -874,11 +874,13 @@ mod tests {
fn test_null_union() {
for mode in [UnionMode::Sparse, UnionMode::Dense] {
let data_type = DataType::Union(
vec![
Field::new("foo", DataType::Int32, true),
Field::new("bar", DataType::Int64, true),
],
vec![2, 1],
UnionFields::new(
vec![2, 1],
vec![
Field::new("foo", DataType::Int32, true),
Field::new("bar", DataType::Int64, true),
],
),
mode,
);
let array = new_null_array(&data_type, 4);
Expand Down
53 changes: 26 additions & 27 deletions arrow-array/src/array/union_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{make_array, Array, ArrayRef};
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::{Buffer, ScalarBuffer};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, UnionMode};
use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
/// Contains the `UnionArray` type.
///
use std::any::Any;
Expand Down Expand Up @@ -145,8 +145,7 @@ impl UnionArray {
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
) -> Self {
let (field_types, field_values): (Vec<_>, Vec<_>) =
child_arrays.into_iter().unzip();
let (fields, field_values): (Vec<_>, Vec<_>) = child_arrays.into_iter().unzip();
let len = type_ids.len();

let mode = if value_offsets.is_some() {
Expand All @@ -156,8 +155,7 @@ impl UnionArray {
};

let builder = ArrayData::builder(DataType::Union(
field_types,
Vec::from(field_type_ids),
UnionFields::new(field_type_ids.iter().copied(), fields),
mode,
))
.add_buffer(type_ids)
Expand Down Expand Up @@ -282,9 +280,9 @@ impl UnionArray {
/// Returns the names of the types in the union.
pub fn type_names(&self) -> Vec<&str> {
match self.data.data_type() {
DataType::Union(fields, _, _) => fields
DataType::Union(fields, _) => fields
.iter()
.map(|f| f.name().as_str())
.map(|(_, f)| f.name().as_str())
.collect::<Vec<&str>>(),
_ => unreachable!("Union array's data type is not a union!"),
}
Expand All @@ -293,7 +291,7 @@ impl UnionArray {
/// Returns whether the `UnionArray` is dense (or sparse if `false`).
fn is_dense(&self) -> bool {
match self.data.data_type() {
DataType::Union(_, _, mode) => mode == &UnionMode::Dense,
DataType::Union(_, mode) => mode == &UnionMode::Dense,
_ => unreachable!("Union array's data type is not a union!"),
}
}
Expand All @@ -307,8 +305,8 @@ impl UnionArray {

impl From<ArrayData> for UnionArray {
fn from(data: ArrayData) -> Self {
let (field_ids, mode) = match data.data_type() {
DataType::Union(_, ids, mode) => (ids, *mode),
let (fields, mode) = match data.data_type() {
DataType::Union(fields, mode) => (fields, *mode),
d => panic!("UnionArray expected ArrayData with type Union got {d}"),
};
let (type_ids, offsets) = match mode {
Expand All @@ -326,10 +324,10 @@ impl From<ArrayData> for UnionArray {
),
};

let max_id = field_ids.iter().copied().max().unwrap_or_default() as usize;
let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
let mut boxed_fields = vec![None; max_id + 1];
for (cd, field_id) in data.child_data().iter().zip(field_ids) {
boxed_fields[*field_id as usize] = Some(make_array(cd.clone()));
for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
}
Self {
data,
Expand Down Expand Up @@ -402,19 +400,18 @@ impl std::fmt::Debug for UnionArray {
writeln!(f, "-- type id buffer:")?;
writeln!(f, "{:?}", self.type_ids)?;

let (fields, ids) = match self.data_type() {
DataType::Union(f, ids, _) => (f, ids),
_ => unreachable!(),
};

if let Some(offsets) = &self.offsets {
writeln!(f, "-- offsets buffer:")?;
writeln!(f, "{:?}", offsets)?;
}

assert_eq!(fields.len(), ids.len());
for (field, type_id) in fields.iter().zip(ids) {
let child = self.child(*type_id);
let fields = match self.data_type() {
DataType::Union(fields, _) => fields,
_ => unreachable!(),
};

for (type_id, field) in fields.iter() {
let child = self.child(type_id);
writeln!(
f,
"-- child {}: \"{}\" ({:?})",
Expand Down Expand Up @@ -1058,12 +1055,14 @@ mod tests {
#[test]
fn test_custom_type_ids() {
let data_type = DataType::Union(
vec![
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Field::new("floats", DataType::Float64, false),
],
vec![8, 4, 9],
UnionFields::new(
vec![8, 4, 9],
vec![
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Field::new("floats", DataType::Float64, false),
],
),
UnionMode::Dense,
);

Expand Down
2 changes: 1 addition & 1 deletion arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ mod tests {
let record_batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
.unwrap();
assert_eq!(record_batch.get_array_memory_size(), 628);
assert_eq!(record_batch.get_array_memory_size(), 564);
}

fn check_batch(record_batch: RecordBatch, num_rows: usize) {
Expand Down
14 changes: 7 additions & 7 deletions arrow-cast/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ fn make_formatter<'a>(
}
DataType::Struct(_) => array_format(as_struct_array(array), options),
DataType::Map(_, _) => array_format(as_map_array(array), options),
DataType::Union(_, _, _) => array_format(as_union_array(array), options),
DataType::Union(_, _) => array_format(as_union_array(array), options),
d => Err(ArrowError::NotYetImplemented(format!("formatting {d} is not yet supported"))),
}
}
Expand Down Expand Up @@ -801,16 +801,16 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray {
);

fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, ArrowError> {
let (fields, type_ids, mode) = match (*self).data_type() {
DataType::Union(fields, type_ids, mode) => (fields, type_ids, mode),
let (fields, mode) = match (*self).data_type() {
DataType::Union(fields, mode) => (fields, mode),
_ => unreachable!(),
};

let max_id = type_ids.iter().copied().max().unwrap_or_default() as usize;
let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default() as usize;
let mut out: Vec<Option<FieldDisplay>> = (0..max_id + 1).map(|_| None).collect();
for (i, field) in type_ids.iter().zip(fields) {
let formatter = make_formatter(self.child(*i).as_ref(), options)?;
out[*i as usize] = Some((field.name().as_str(), formatter))
for (i, field) in fields.iter() {
let formatter = make_formatter(self.child(i).as_ref(), options)?;
out[i as usize] = Some((field.name().as_str(), formatter))
}
Ok((out, *mode))
}
Expand Down
42 changes: 25 additions & 17 deletions arrow-cast/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,11 +703,13 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
),
UnionMode::Dense,
),
false,
Expand Down Expand Up @@ -743,11 +745,13 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
),
UnionMode::Sparse,
),
false,
Expand Down Expand Up @@ -785,11 +789,13 @@ mod tests {
let inner_field = Field::new(
"European Union",
DataType::Union(
vec![
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Float64, false),
],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Float64, false),
],
),
UnionMode::Dense,
),
false,
Expand All @@ -809,8 +815,10 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
vec![Field::new("a", DataType::Int32, true), inner_field],
vec![0, 1],
UnionFields::new(
vec![0, 1],
vec![Field::new("a", DataType::Int32, true), inner_field],
),
UnionMode::Sparse,
),
false,
Expand Down
25 changes: 13 additions & 12 deletions arrow-data/src/data/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<u8>()),
empty_buffer,
],
DataType::Union(_, _, mode) => {
DataType::Union(_, mode) => {
let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
match mode {
UnionMode::Sparse => [type_ids, empty_buffer],
Expand All @@ -162,7 +162,7 @@ pub(crate) fn into_buffers(
| DataType::Binary
| DataType::LargeUtf8
| DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
DataType::Union(_, _, mode) => {
DataType::Union(_, mode) => {
match mode {
// Based on Union's DataTypeLayout
UnionMode::Sparse => vec![buffer1.into()],
Expand Down Expand Up @@ -621,8 +621,9 @@ impl ArrayData {
vec![ArrayData::new_empty(v.as_ref())],
true,
),
DataType::Union(f, i, mode) => {
let ids = Buffer::from_iter(std::iter::repeat(i[0]).take(len));
DataType::Union(f, mode) => {
let (id, _) = f.iter().next().unwrap();
let ids = Buffer::from_iter(std::iter::repeat(id).take(len));
let buffers = match mode {
UnionMode::Sparse => vec![ids],
UnionMode::Dense => {
Expand All @@ -634,7 +635,7 @@ impl ArrayData {
let children = f
.iter()
.enumerate()
.map(|(idx, f)| match idx {
.map(|(idx, (_, f))| match idx {
0 => Self::new_null(f.data_type(), len),
_ => Self::new_empty(f.data_type()),
})
Expand Down Expand Up @@ -986,10 +987,10 @@ impl ArrayData {
}
Ok(())
}
DataType::Union(fields, _, mode) => {
DataType::Union(fields, mode) => {
self.validate_num_child_data(fields.len())?;

for (i, field) in fields.iter().enumerate() {
for (i, (_, field)) in fields.iter().enumerate() {
let field_data = self.get_valid_child_data(i, field.data_type())?;

if mode == &UnionMode::Sparse
Expand Down Expand Up @@ -1255,7 +1256,7 @@ impl ArrayData {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len)
}
DataType::Union(_, _, _) => {
DataType::Union(_, _) => {
// Validate Union Array as part of implementing new Union semantics
// See comments in `ArrayData::validate()`
// https://github.com/apache/arrow-rs/issues/85
Expand Down Expand Up @@ -1568,7 +1569,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout {
DataType::LargeList(_) => DataTypeLayout::new_fixed_width(size_of::<i64>()),
DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child data,
DataType::RunEndEncoded(_, _) => DataTypeLayout::new_empty(), // all in child data,
DataType::Union(_, _, mode) => {
DataType::Union(_, mode) => {
let type_ids = BufferSpec::FixedWidth {
byte_width: size_of::<i8>(),
};
Expand Down Expand Up @@ -1823,7 +1824,7 @@ impl From<ArrayData> for ArrayDataBuilder {
#[cfg(test)]
mod tests {
use super::*;
use arrow_schema::Field;
use arrow_schema::{Field, UnionFields};

// See arrow/tests/array_data_validation.rs for test of array validation

Expand Down Expand Up @@ -2072,8 +2073,8 @@ mod tests {
#[test]
fn test_into_buffers() {
let data_types = vec![
DataType::Union(vec![], vec![], UnionMode::Dense),
DataType::Union(vec![], vec![], UnionMode::Sparse),
DataType::Union(UnionFields::empty(), UnionMode::Dense),
DataType::Union(UnionFields::empty(), UnionMode::Sparse),
];

for data_type in data_types {
Expand Down
2 changes: 1 addition & 1 deletion arrow-data/src/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn equal_values(
fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
}
DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Dictionary(data_type, _) => match data_type.as_ref() {
DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int16 => {
Expand Down
Loading