diff --git a/Cargo.toml b/Cargo.toml index 1803bea47ec..1b02c0a6010 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -158,7 +158,7 @@ compute_bitwise = [] compute_boolean = [] compute_boolean_kleene = [] compute_cast = ["lexical-core", "compute_take"] -compute_comparison = [] +compute_comparison = ["compute_take"] compute_concatenate = [] compute_contains = [] compute_filter = [] diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index 836eab497b4..655567e876c 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -56,6 +56,7 @@ pub mod utf8; mod simd; pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; +use super::take::take_boolean; pub(crate) use primitive::{ compare_values_op as primitive_compare_values_op, compare_values_op_scalar as primitive_compare_values_op_scalar, @@ -166,6 +167,11 @@ pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { compare!(lhs, rhs, eq, match_eq) } +/// Returns whether a [`DataType`] is comparable is supported by [`eq`]. +pub fn can_eq(data_type: &DataType) -> bool { + can_partial_eq(data_type) +} + /// `!=` between two [`Array`]s. /// Use [`can_neq`] to check whether the operation is valid /// # Panic @@ -177,6 +183,11 @@ pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { compare!(lhs, rhs, neq, match_eq) } +/// Returns whether a [`DataType`] is comparable is supported by [`neq`]. +pub fn can_neq(data_type: &DataType) -> bool { + can_partial_eq(data_type) +} + /// `<` between two [`Array`]s. /// Use [`can_lt`] to check whether the operation is valid /// # Panic @@ -188,6 +199,11 @@ pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { compare!(lhs, rhs, lt, match_eq_ord) } +/// Returns whether a [`DataType`] is comparable is supported by [`lt`]. +pub fn can_lt(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + /// `<=` between two [`Array`]s. /// Use [`can_lt_eq`] to check whether the operation is valid /// # Panic @@ -199,6 +215,11 @@ pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { compare!(lhs, rhs, lt_eq, match_eq_ord) } +/// Returns whether a [`DataType`] is comparable is supported by [`lt`]. +pub fn can_lt_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + /// `>` between two [`Array`]s. /// Use [`can_gt`] to check whether the operation is valid /// # Panic @@ -210,6 +231,11 @@ pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { compare!(lhs, rhs, gt, match_eq_ord) } +/// Returns whether a [`DataType`] is comparable is supported by [`gt`]. +pub fn can_gt(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + /// `>=` between two [`Array`]s. /// Use [`can_gt_eq`] to check whether the operation is valid /// # Panic @@ -221,6 +247,11 @@ pub fn gt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { compare!(lhs, rhs, gt_eq, match_eq_ord) } +/// Returns whether a [`DataType`] is comparable is supported by [`gt_eq`]. +pub fn can_gt_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + macro_rules! compare_scalar { ($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{ let lhs = $lhs; @@ -266,13 +297,21 @@ macro_rules! compare_scalar { let rhs = rhs.as_any().downcast_ref::>().unwrap(); binary::$op::(lhs, rhs.value().unwrap()) } + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let values = $op(lhs.values().as_ref(), rhs); + + take_boolean(&values, lhs.keys()) + }) + } _ => todo!("Comparisons of {:?} are not yet supported", lhs.data_type()), } }}; } /// `==` between an [`Array`] and a [`Scalar`]. -/// Use [`can_eq`] to check whether the operation is valid +/// Use [`can_eq_scalar`] to check whether the operation is valid /// # Panic /// Panics iff either: /// * they do not have have the same logical type @@ -281,8 +320,13 @@ pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { compare_scalar!(lhs, rhs, eq_scalar, match_eq) } +/// Returns whether a [`DataType`] is supported by [`eq_scalar`]. +pub fn can_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_scalar(data_type) +} + /// `!=` between an [`Array`] and a [`Scalar`]. -/// Use [`can_neq`] to check whether the operation is valid +/// Use [`can_neq_scalar`] to check whether the operation is valid /// # Panic /// Panics iff either: /// * they do not have have the same logical type @@ -291,8 +335,13 @@ pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { compare_scalar!(lhs, rhs, neq_scalar, match_eq) } +/// Returns whether a [`DataType`] is supported by [`neq_scalar`]. +pub fn can_neq_scalar(data_type: &DataType) -> bool { + can_partial_eq_scalar(data_type) +} + /// `<` between an [`Array`] and a [`Scalar`]. -/// Use [`can_lt`] to check whether the operation is valid +/// Use [`can_lt_scalar`] to check whether the operation is valid /// # Panic /// Panics iff either: /// * they do not have have the same logical type @@ -301,8 +350,13 @@ pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { compare_scalar!(lhs, rhs, lt_scalar, match_eq_ord) } +/// Returns whether a [`DataType`] is supported by [`lt_scalar`]. +pub fn can_lt_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + /// `<=` between an [`Array`] and a [`Scalar`]. -/// Use [`can_lt_eq`] to check whether the operation is valid +/// Use [`can_lt_eq_scalar`] to check whether the operation is valid /// # Panic /// Panics iff either: /// * they do not have have the same logical type @@ -311,8 +365,13 @@ pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { compare_scalar!(lhs, rhs, lt_eq_scalar, match_eq_ord) } +/// Returns whether a [`DataType`] is supported by [`lt_eq_scalar`]. +pub fn can_lt_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + /// `>` between an [`Array`] and a [`Scalar`]. -/// Use [`can_gt`] to check whether the operation is valid +/// Use [`can_gt_scalar`] to check whether the operation is valid /// # Panic /// Panics iff either: /// * they do not have have the same logical type @@ -321,8 +380,13 @@ pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { compare_scalar!(lhs, rhs, gt_scalar, match_eq_ord) } +/// Returns whether a [`DataType`] is supported by [`gt_scalar`]. +pub fn can_gt_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + /// `>=` between an [`Array`] and a [`Scalar`]. -/// Use [`can_gt_eq`] to check whether the operation is valid +/// Use [`can_gt_eq_scalar`] to check whether the operation is valid /// # Panic /// Panics iff either: /// * they do not have have the same logical type @@ -331,33 +395,16 @@ pub fn gt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { compare_scalar!(lhs, rhs, gt_eq_scalar, match_eq_ord) } -/// Returns whether a [`DataType`] is comparable (either array or scalar). -pub fn can_eq(data_type: &DataType) -> bool { - can_partial_eq(data_type) -} - -/// Returns whether a [`DataType`] is comparable (either array or scalar). -pub fn can_neq(data_type: &DataType) -> bool { - can_partial_eq(data_type) +/// Returns whether a [`DataType`] is supported by [`gt_eq_scalar`]. +pub fn can_gt_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) } -/// Returns whether a [`DataType`] is comparable (either array or scalar). -pub fn can_lt(data_type: &DataType) -> bool { - can_partial_eq_and_ord(data_type) -} - -/// Returns whether a [`DataType`] is comparable (either array or scalar). -pub fn can_lt_eq(data_type: &DataType) -> bool { - can_partial_eq_and_ord(data_type) -} - -/// Returns whether a [`DataType`] is comparable (either array or scalar). -pub fn can_gt(data_type: &DataType) -> bool { - can_partial_eq_and_ord(data_type) -} - -/// Returns whether a [`DataType`] is comparable (either array or scalar). -pub fn can_gt_eq(data_type: &DataType) -> bool { +// The list of operations currently supported. +fn can_partial_eq_and_ord_scalar(data_type: &DataType) -> bool { + if let DataType::Dictionary(_, values, _) = data_type.to_logical_type() { + return can_partial_eq_and_ord_scalar(values.as_ref()); + } can_partial_eq_and_ord(data_type) } @@ -400,3 +447,13 @@ fn can_partial_eq(data_type: &DataType) -> bool { | DataType::Interval(IntervalUnit::MonthDayNano) ) } + +// The list of operations currently supported. +fn can_partial_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) + || matches!( + data_type.to_logical_type(), + DataType::Interval(IntervalUnit::DayTime) + | DataType::Interval(IntervalUnit::MonthDayNano) + ) +} diff --git a/src/compute/take/mod.rs b/src/compute/take/mod.rs index b127ddd5c4a..3d71e098d0a 100644 --- a/src/compute/take/mod.rs +++ b/src/compute/take/mod.rs @@ -33,6 +33,8 @@ mod primitive; mod structure; mod utf8; +pub(crate) use boolean::take as take_boolean; + /// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls. /// The returned array has a length equal to `indices.len()`. pub fn take(values: &dyn Array, indices: &PrimitiveArray) -> Result> { diff --git a/src/types/index.rs b/src/types/index.rs index ae040ec7709..b44b3957e79 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -11,10 +11,10 @@ pub trait Index: + std::ops::AddAssign + std::ops::Sub + num_traits::One - + PartialOrd + num_traits::Num - + Ord + num_traits::CheckedAdd + + PartialOrd + + Ord { /// Convert itself to [`usize`]. fn to_usize(&self) -> usize; @@ -32,53 +32,30 @@ pub trait Index: } } -impl Index for i32 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } -} - -impl Index for i64 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } -} - -impl Index for u32 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } +macro_rules! index { + ($t:ty) => { + impl Index for $t { + #[inline] + fn to_usize(&self) -> usize { + *self as usize + } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } + } + }; } -impl Index for u64 { - #[inline] - fn to_usize(&self) -> usize { - *self as usize - } - - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } -} +index!(i8); +index!(i16); +index!(i32); +index!(i64); +index!(u8); +index!(u16); +index!(u32); +index!(u64); /// Range of [`Index`], equivalent to `(a..b)`. /// `Step` is unstable in Rust, which does not allow us to implement (a..b) for [`Index`]. diff --git a/tests/it/compute/comparison.rs b/tests/it/compute/comparison.rs index 1c89cc84620..225a3b17855 100644 --- a/tests/it/compute/comparison.rs +++ b/tests/it/compute/comparison.rs @@ -1,7 +1,7 @@ use arrow2::array::*; use arrow2::compute::comparison::boolean::*; -use arrow2::datatypes::TimeUnit; use arrow2::datatypes::{DataType::*, IntervalUnit}; +use arrow2::datatypes::{IntegerType, TimeUnit}; use arrow2::scalar::new_scalar; #[test] @@ -41,6 +41,7 @@ fn consistency() { Duration(TimeUnit::Millisecond), Duration(TimeUnit::Microsecond), Duration(TimeUnit::Nanosecond), + Dictionary(IntegerType::Int32, Box::new(LargeBinary), false), ]; // array <> array @@ -58,10 +59,10 @@ fn consistency() { datatypes.into_iter().for_each(|d1| { let array = new_null_array(d1.clone(), 10); let scalar = new_scalar(array.as_ref(), 0); - if can_eq(&d1) { + if can_eq_scalar(&d1) { eq_scalar(array.as_ref(), scalar.as_ref()); } - if can_lt_eq(&d1) { + if can_lt_eq_scalar(&d1) { lt_eq(array.as_ref(), array.as_ref()); } });