diff --git a/src/array/union/mod.rs b/src/array/union/mod.rs index 32620b1a0e4..f3bcf3726cb 100644 --- a/src/array/union/mod.rs +++ b/src/array/union/mod.rs @@ -242,6 +242,14 @@ impl UnionArray { &self.types } + #[inline] + fn field(&self, type_: i8) -> &Arc { + self.fields_hash + .as_ref() + .map(|x| &x[&type_].1) + .unwrap_or_else(|| &self.fields[type_ as usize]) + } + #[inline] fn field_slot(&self, index: usize) -> usize { self.offsets() @@ -264,7 +272,10 @@ impl UnionArray { /// Returns the slot `index` as a [`Scalar`]. pub fn value(&self, index: usize) -> Box { - new_scalar(self, index) + let type_ = self.types()[index]; + let field = self.field(type_); + let index = self.field_slot(index); + new_scalar(field.as_ref(), index) } } diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs index ba9a6259788..c7394af0097 100644 --- a/src/scalar/mod.rs +++ b/src/scalar/mod.rs @@ -148,13 +148,10 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { } Union => { let array = array.as_any().downcast_ref::().unwrap(); - let type_id = array.types()[index]; - let (field_index, index) = array.index(index); - let field_value = new_scalar(&*array.fields()[field_index], index); Box::new(UnionScalar::new( array.data_type().clone(), - type_id, - field_value.into(), + array.types()[index], + array.value(index).into(), )) } Map => todo!(), diff --git a/src/scalar/union.rs b/src/scalar/union.rs index a381f73f552..df625c7aac9 100644 --- a/src/scalar/union.rs +++ b/src/scalar/union.rs @@ -28,6 +28,12 @@ impl UnionScalar { pub fn value(&self) -> &Arc { &self.value } + + /// Returns the type of the union scalar + #[inline] + pub fn type_(&self) -> i8 { + self.type_ + } } impl Scalar for UnionScalar { diff --git a/tests/it/array/union.rs b/tests/it/array/union.rs index 3295ef7b59b..b52b5ff4fd7 100644 --- a/tests/it/array/union.rs +++ b/tests/it/array/union.rs @@ -5,7 +5,7 @@ use arrow2::{ buffer::Buffer, datatypes::*, error::Result, - scalar::{PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar}, + scalar::{new_scalar, PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar}, }; fn next_unchecked(iter: &mut I) -> T @@ -16,29 +16,11 @@ where iter.next() .unwrap() .as_any() - .downcast_ref::() - .unwrap() - .value() - .as_any() .downcast_ref::() .unwrap() .clone() } -fn assert_next_is_none(iter: &mut I) -where - I: Iterator>, -{ - assert!(!iter - .next() - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .value() - .is_valid()) -} - #[test] fn sparse_debug() -> Result<()> { let fields = vec![ @@ -128,7 +110,10 @@ fn iter_sparse() -> Result<()> { next_unchecked::, _>(&mut iter).value(), Some(1) ); - assert_next_is_none(&mut iter); + assert_eq!( + next_unchecked::, _>(&mut iter).value(), + None + ); assert_eq!( next_unchecked::, _>(&mut iter).value(), Some("c") @@ -159,7 +144,10 @@ fn iter_dense() -> Result<()> { next_unchecked::, _>(&mut iter).value(), Some(1) ); - assert_next_is_none(&mut iter); + assert_eq!( + next_unchecked::, _>(&mut iter).value(), + None + ); assert_eq!( next_unchecked::, _>(&mut iter).value(), Some("c") @@ -221,3 +209,60 @@ fn iter_dense_slice() -> Result<()> { Ok(()) } + +#[test] +fn scalar() -> Result<()> { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from_slice([0, 0, 1]); + let offsets = Buffer::::from_slice([0, 1, 0]); + let fields = vec![ + Arc::new(Int32Array::from(&[Some(1), None])) as Arc, + Arc::new(Utf8Array::::from(&[Some("c")])) as Arc, + ]; + + let array = UnionArray::from_data(data_type, types, fields.clone(), Some(offsets)); + + let scalar = new_scalar(&array, 0); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + Some(1) + ); + assert_eq!(union_scalar.type_(), 0); + let scalar = new_scalar(&array, 1); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + None + ); + assert_eq!(union_scalar.type_(), 0); + + let scalar = new_scalar(&array, 2); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + Some("c") + ); + assert_eq!(union_scalar.type_(), 1); + + Ok(()) +}