Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
address review feedback + rebase + add isolated test for union scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
ncpenke committed Apr 29, 2022
1 parent 45935ed commit a9ca398
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 27 deletions.
13 changes: 12 additions & 1 deletion src/array/union/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ impl UnionArray {
&self.types
}

/// Returns the child array associated with the type id `type_`.
///
/// When `fields_hash` is populated the child is looked up by `type_` in it;
/// otherwise `type_` is used directly as a position in `fields`.
/// NOTE(review): both lookups use indexing, so an unknown `type_` presumably
/// panics — confirm against the containers' `Index` implementations.
#[inline]
fn field(&self, type_: i8) -> &Arc<dyn Array> {
    if let Some(by_type) = self.fields_hash.as_ref() {
        &by_type[&type_].1
    } else {
        &self.fields[type_ as usize]
    }
}

#[inline]
fn field_slot(&self, index: usize) -> usize {
self.offsets()
Expand All @@ -264,7 +272,10 @@ impl UnionArray {

/// Returns the slot `index` as a [`Scalar`].
///
/// The scalar is built from the child array selected by the type id stored
/// at `index`, read at the position that `field_slot` maps `index` to.
pub fn value(&self, index: usize) -> Box<dyn Scalar> {
// NOTE(review): this view is a diff rendered without +/- markers; the next
// line looks like the pre-change body that the lines after it replace —
// confirm against the actual file.
new_scalar(self, index)
// Type id recorded for this slot; it selects which child array holds the value.
let type_ = self.types()[index];
let field = self.field(type_);
// Translate the union-level index into the position inside the child array.
let index = self.field_slot(index);
new_scalar(field.as_ref(), index)
}
}

Expand Down
7 changes: 2 additions & 5 deletions src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,10 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box<dyn Scalar> {
}
Union => {
let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
let type_id = array.types()[index];
let (field_index, index) = array.index(index);
let field_value = new_scalar(&*array.fields()[field_index], index);
Box::new(UnionScalar::new(
array.data_type().clone(),
type_id,
field_value.into(),
array.types()[index],
array.value(index).into(),
))
}
Map => todo!(),
Expand Down
6 changes: 6 additions & 0 deletions src/scalar/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ impl UnionScalar {
pub fn value(&self) -> &Arc<dyn Scalar> {
&self.value
}

/// Returns the type id stored in this union scalar.
///
/// NOTE(review): when the scalar is produced by `new_scalar`, this appears to
/// be the entry of the union array's types buffer for the slot — confirm.
#[inline]
pub fn type_(&self) -> i8 {
self.type_
}
}

impl Scalar for UnionScalar {
Expand Down
87 changes: 66 additions & 21 deletions tests/it/array/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use arrow2::{
buffer::Buffer,
datatypes::*,
error::Result,
scalar::{PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar},
scalar::{new_scalar, PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar},
};

fn next_unchecked<T, I>(iter: &mut I) -> T
Expand All @@ -16,29 +16,11 @@ where
iter.next()
.unwrap()
.as_any()
.downcast_ref::<UnionScalar>()
.unwrap()
.value()
.as_any()
.downcast_ref::<T>()
.unwrap()
.clone()
}

/// Pulls the next item from `iter`, expects it to be a [`UnionScalar`], and
/// asserts that the scalar it wraps is not valid.
///
/// Panics if the iterator is exhausted or the item is not a [`UnionScalar`].
fn assert_next_is_none<I>(iter: &mut I)
where
    I: Iterator<Item = Box<dyn Scalar>>,
{
    let item = iter.next().unwrap();
    let union_scalar = item.as_any().downcast_ref::<UnionScalar>().unwrap();
    let inner_is_valid = union_scalar.value().is_valid();
    assert!(!inner_is_valid)
}

#[test]
fn sparse_debug() -> Result<()> {
let fields = vec![
Expand Down Expand Up @@ -128,7 +110,10 @@ fn iter_sparse() -> Result<()> {
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
Some(1)
);
assert_next_is_none(&mut iter);
assert_eq!(
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
None
);
assert_eq!(
next_unchecked::<Utf8Scalar<i32>, _>(&mut iter).value(),
Some("c")
Expand Down Expand Up @@ -159,7 +144,10 @@ fn iter_dense() -> Result<()> {
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
Some(1)
);
assert_next_is_none(&mut iter);
assert_eq!(
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
None
);
assert_eq!(
next_unchecked::<Utf8Scalar<i32>, _>(&mut iter).value(),
Some("c")
Expand Down Expand Up @@ -221,3 +209,60 @@ fn iter_dense_slice() -> Result<()> {

Ok(())
}

/// Checks that [`new_scalar`] on a dense [`UnionArray`] yields a [`UnionScalar`]
/// carrying both the slot's type id and the value read from the child array.
#[test]
fn scalar() -> Result<()> {
    let data_type = DataType::Union(
        vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ],
        None,
        UnionMode::Dense,
    );
    // Slots 0 and 1 read child 0 at offsets 0 and 1; slot 2 reads child 1 at offset 0.
    let types = Buffer::from_slice([0, 0, 1]);
    let offsets = Buffer::<i32>::from_slice([0, 1, 0]);
    let children = vec![
        Arc::new(Int32Array::from(&[Some(1), None])) as Arc<dyn Array>,
        Arc::new(Utf8Array::<i32>::from(&[Some("c")])) as Arc<dyn Array>,
    ];

    // The children are not used again, so they are moved in rather than cloned.
    let array = UnionArray::from_data(data_type, types, children, Some(offsets));

    // Slot 0: a valid Int32 value from child 0.
    let first = new_scalar(&array, 0);
    let first = first.as_any().downcast_ref::<UnionScalar>().unwrap();
    assert_eq!(
        first
            .value()
            .as_any()
            .downcast_ref::<PrimitiveScalar<i32>>()
            .unwrap()
            .value(),
        Some(1)
    );
    assert_eq!(first.type_(), 0);

    // Slot 1: the null entry of child 0; the type id is still reported.
    let second = new_scalar(&array, 1);
    let second = second.as_any().downcast_ref::<UnionScalar>().unwrap();
    assert_eq!(
        second
            .value()
            .as_any()
            .downcast_ref::<PrimitiveScalar<i32>>()
            .unwrap()
            .value(),
        None
    );
    assert_eq!(second.type_(), 0);

    // Slot 2: a Utf8 value from child 1.
    let third = new_scalar(&array, 2);
    let third = third.as_any().downcast_ref::<UnionScalar>().unwrap();
    assert_eq!(
        third
            .value()
            .as_any()
            .downcast_ref::<Utf8Scalar<i32>>()
            .unwrap()
            .value(),
        Some("c")
    );
    assert_eq!(third.type_(), 1);

    Ok(())
}

0 comments on commit a9ca398

Please sign in to comment.