From a1b9bf1f5c65ad8e96e775f74ad16f7745f18896 Mon Sep 17 00:00:00 2001 From: baishen Date: Wed, 8 Mar 2023 02:44:37 +0800 Subject: [PATCH] Added MapScalar (#1428) --- src/scalar/equal.rs | 2 +- src/scalar/map.rs | 66 ++++++++++++++++++++++++++++++++++++++++ src/scalar/mod.rs | 12 +++++++- tests/it/scalar/map.rs | 68 ++++++++++++++++++++++++++++++++++++++++++ tests/it/scalar/mod.rs | 1 + 5 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 src/scalar/map.rs create mode 100644 tests/it/scalar/map.rs diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs index 3e7e2a620b9..dcb3c836be5 100644 --- a/src/scalar/equal.rs +++ b/src/scalar/equal.rs @@ -54,6 +54,6 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs), FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), Union => dyn_eq!(UnionScalar, lhs, rhs), - Map => unimplemented!("{:?}", Map), + Map => dyn_eq!(MapScalar, lhs, rhs), } } diff --git a/src/scalar/map.rs b/src/scalar/map.rs new file mode 100644 index 00000000000..3901dd0a18f --- /dev/null +++ b/src/scalar/map.rs @@ -0,0 +1,66 @@ +use std::any::Any; + +use crate::{array::*, datatypes::DataType}; + +use super::Scalar; + +/// The scalar equivalent of [`MapArray`]. Like [`MapArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct MapScalar { + values: Box, + is_valid: bool, + data_type: DataType, +} + +impl PartialEq for MapScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl MapScalar { + /// returns a new [`MapScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `Map` + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let inner_field = MapArray::try_get_field(&data_type).unwrap(); + let inner_data_type = inner_field.data_type(); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_data_type, values.data_type()); + (true, values) + } + None => (false, new_empty_array(inner_data_type.clone())), + }; + Self { + values, + is_valid, + data_type, + } + } + + /// The values of the [`MapScalar`] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl Scalar for MapScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs index 41d32aaa957..e3404e4eaa2 100644 --- a/src/scalar/mod.rs +++ b/src/scalar/mod.rs @@ -17,6 +17,8 @@ mod boolean; pub use boolean::*; mod list; pub use list::*; +mod map; +pub use map::*; mod null; pub use null::*; mod struct_; @@ -156,7 +158,15 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { array.value(index), )) } - Map => todo!(), + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(MapScalar::new(array.data_type().clone(), value)) + } Dictionary(key_type) => match_integer_type!(key_type, |$T| { let array = array .as_any() diff --git a/tests/it/scalar/map.rs b/tests/it/scalar/map.rs new file mode 100644 index 00000000000..1a232a5049c --- /dev/null +++ b/tests/it/scalar/map.rs @@ -0,0 +1,68 @@ +use arrow2::{ + array::{BooleanArray, StructArray, Utf8Array}, + datatypes::{DataType, Field}, + scalar::{MapScalar, Scalar}, +}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let kv_dt = DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Boolean, true), + ]); + let kv_array1 = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + let kv_array2 = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k3")]).boxed(), + BooleanArray::from_slice([true, true]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = DataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array1))); + let b = MapScalar::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = MapScalar::new(dt, Some(Box::new(kv_array2))); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let kv_dt = DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Boolean, true), + ]); + let kv_array = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = DataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array.clone()))); + + assert_eq!(kv_array, a.values().as_ref()); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/tests/it/scalar/mod.rs b/tests/it/scalar/mod.rs index 9f89d41f863..5dd1568f6d1 100644 --- a/tests/it/scalar/mod.rs +++ b/tests/it/scalar/mod.rs @@ -3,6 +3,7 @@ mod boolean; mod fixed_size_binary; mod fixed_size_list; mod list; +mod map; mod null; mod primitive; mod struct_;