diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index ffd603ccb3a..8da298375bf 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -1,7 +1,7 @@ -//! Basic comparison kernels. +//! Contains comparison operators //! -//! The module contains functions that compare either an array and a scalar -//! or two arrays of the same [`DataType`]. The scalar-oriented functions are +//! The module contains functions that compare either an [`Array`] and a [`Scalar`] +//! or two [`Array`]s (of the same [`DataType`]). The scalar-oriented functions are //! suffixed with `_scalar`. //! //! The functions are organized in two variants: @@ -45,7 +45,7 @@ //! ``` use crate::array::*; -use crate::datatypes::DataType; +use crate::datatypes::{DataType, IntervalUnit}; use crate::scalar::*; pub mod binary; @@ -54,11 +54,11 @@ pub mod primitive; pub mod utf8; mod simd; -pub use simd::{Simd8, Simd8Lanes}; +pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; pub(crate) use primitive::compare_values_op as primitive_compare_values_op; -macro_rules! with_match_primitive_cmp {( +macro_rules! match_eq_ord {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* ) => ({ macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} @@ -80,8 +80,31 @@ macro_rules! with_match_primitive_cmp {( } })} +macro_rules! match_eq {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! 
{ u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + macro_rules! compare { - ($lhs:expr, $rhs:expr, $op:tt) => {{ + ($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{ let lhs = $lhs; let rhs = $rhs; assert_eq!( @@ -96,7 +119,7 @@ macro_rules! compare { let rhs = rhs.as_any().downcast_ref().unwrap(); boolean::$op(lhs, rhs) } - Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| { + Primitive(primitive) => $p!(primitive, |$T| { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); primitive::$op::<$T>(lhs, rhs) @@ -137,7 +160,7 @@ macro_rules! compare { /// * the arrays do not have the same length /// * the operation is not supported for the logical type pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { - compare!(lhs, rhs, eq) + compare!(lhs, rhs, eq, match_eq) } /// `!=` between two [`Array`]s. @@ -148,7 +171,7 @@ pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { /// * the arrays do not have the same length /// * the operation is not supported for the logical type pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { - compare!(lhs, rhs, neq) + compare!(lhs, rhs, neq, match_eq) } /// `<` between two [`Array`]s. @@ -159,7 +182,7 @@ pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { /// * the arrays do not have the same length /// * the operation is not supported for the logical type pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { - compare!(lhs, rhs, lt) + compare!(lhs, rhs, lt, match_eq_ord) } /// `<=` between two [`Array`]s. @@ -170,7 +193,7 @@ pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { /// * the arrays do not have the same length /// * the operation is not supported for the logical type pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { - compare!(lhs, rhs, lt_eq) + compare!(lhs, rhs, lt_eq, match_eq_ord) } /// `>` between two [`Array`]s. 
@@ -181,7 +204,7 @@ pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { /// * the arrays do not have the same length /// * the operation is not supported for the logical type pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { - compare!(lhs, rhs, gt) + compare!(lhs, rhs, gt, match_eq_ord) } /// `>=` between two [`Array`]s. @@ -192,11 +215,11 @@ pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { /// * the arrays do not have the same length /// * the operation is not supported for the logical type pub fn gt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { - compare!(lhs, rhs, gt_eq) + compare!(lhs, rhs, gt_eq, match_eq_ord) } macro_rules! compare_scalar { - ($lhs:expr, $rhs:expr, $op:tt) => {{ + ($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{ let lhs = $lhs; let rhs = $rhs; assert_eq!( @@ -215,7 +238,7 @@ macro_rules! compare_scalar { // validity checked above boolean::$op(lhs, rhs.value().unwrap()) } - Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| { + Primitive(primitive) => $p!(primitive, |$T| { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref::>().unwrap(); primitive::$op::<$T>(lhs, rhs.value().unwrap()) @@ -252,7 +275,7 @@ macro_rules! compare_scalar { /// * they do not have have the same logical type /// * the operation is not supported for the logical type pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { - compare_scalar!(lhs, rhs, eq_scalar) + compare_scalar!(lhs, rhs, eq_scalar, match_eq) } /// `!=` between an [`Array`] and a [`Scalar`]. @@ -262,7 +285,7 @@ pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { /// * they do not have have the same logical type /// * the operation is not supported for the logical type pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { - compare_scalar!(lhs, rhs, neq_scalar) + compare_scalar!(lhs, rhs, neq_scalar, match_eq) } /// `<` between an [`Array`] and a [`Scalar`]. 
@@ -272,7 +295,7 @@ pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { /// * they do not have have the same logical type /// * the operation is not supported for the logical type pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { - compare_scalar!(lhs, rhs, lt_scalar) + compare_scalar!(lhs, rhs, lt_scalar, match_eq_ord) } /// `<=` between an [`Array`] and a [`Scalar`]. @@ -282,7 +305,7 @@ pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { /// * they do not have have the same logical type /// * the operation is not supported for the logical type pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { - compare_scalar!(lhs, rhs, lt_eq_scalar) + compare_scalar!(lhs, rhs, lt_eq_scalar, match_eq_ord) } /// `>` between an [`Array`] and a [`Scalar`]. @@ -292,7 +315,7 @@ pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { /// * they do not have have the same logical type /// * the operation is not supported for the logical type pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { - compare_scalar!(lhs, rhs, gt_scalar) + compare_scalar!(lhs, rhs, gt_scalar, match_eq_ord) } /// `>=` between an [`Array`] and a [`Scalar`]. @@ -302,41 +325,41 @@ pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { /// * they do not have have the same logical type /// * the operation is not supported for the logical type pub fn gt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { - compare_scalar!(lhs, rhs, gt_eq_scalar) + compare_scalar!(lhs, rhs, gt_eq_scalar, match_eq_ord) } -/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +/// Returns whether a [`DataType`] is comparable (either array or scalar). pub fn can_eq(data_type: &DataType) -> bool { - can_compare(data_type) + can_partial_eq(data_type) } -/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. 
+/// Returns whether a [`DataType`] is comparable (either array or scalar). pub fn can_neq(data_type: &DataType) -> bool { - can_compare(data_type) + can_partial_eq(data_type) } -/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +/// Returns whether a [`DataType`] is comparable (either array or scalar). pub fn can_lt(data_type: &DataType) -> bool { - can_compare(data_type) + can_partial_eq_and_ord(data_type) } -/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +/// Returns whether a [`DataType`] is comparable (either array or scalar). pub fn can_lt_eq(data_type: &DataType) -> bool { - can_compare(data_type) + can_partial_eq_and_ord(data_type) } -/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +/// Returns whether a [`DataType`] is comparable (either array or scalar). pub fn can_gt(data_type: &DataType) -> bool { - can_compare(data_type) + can_partial_eq_and_ord(data_type) } -/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +/// Returns whether a [`DataType`] is comparable (either array or scalar). pub fn can_gt_eq(data_type: &DataType) -> bool { - can_compare(data_type) + can_partial_eq_and_ord(data_type) } // The list of operations currently supported. -fn can_compare(data_type: &DataType) -> bool { +fn can_partial_eq_and_ord(data_type: &DataType) -> bool { matches!( data_type, DataType::Boolean @@ -345,7 +368,7 @@ fn can_compare(data_type: &DataType) -> bool { | DataType::Int32 | DataType::Date32 | DataType::Time32(_) - | DataType::Interval(_) + | DataType::Interval(IntervalUnit::YearMonth) | DataType::Int64 | DataType::Timestamp(_, _) | DataType::Date64 @@ -364,3 +387,13 @@ fn can_compare(data_type: &DataType) -> bool { | DataType::LargeBinary ) } + +// The list of operations currently supported. 
+fn can_partial_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) + || matches!( + data_type.to_logical_type(), + DataType::Interval(IntervalUnit::DayTime) + | DataType::Interval(IntervalUnit::MonthDayNano) + ) +} diff --git a/src/compute/comparison/primitive.rs b/src/compute/comparison/primitive.rs index 257db4ba4c7..c7949379634 100644 --- a/src/compute/comparison/primitive.rs +++ b/src/compute/comparison/primitive.rs @@ -7,7 +7,7 @@ use crate::{ }; use super::super::utils::combine_validities; -use super::simd::{Simd8, Simd8Lanes}; +use super::simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; pub(crate) fn compare_values_op(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap where @@ -87,6 +87,7 @@ where pub fn eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialEq, { compare_op(lhs, rhs, |a, b| a.eq(b)) } @@ -95,6 +96,7 @@ where pub fn eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialEq, { compare_op_scalar(lhs, rhs, |a, b| a.eq(b)) } @@ -103,6 +105,7 @@ where pub fn neq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialEq, { compare_op(lhs, rhs, |a, b| a.neq(b)) } @@ -111,6 +114,7 @@ where pub fn neq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialEq, { compare_op_scalar(lhs, rhs, |a, b| a.neq(b)) } @@ -119,6 +123,7 @@ where pub fn lt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, { compare_op(lhs, rhs, |a, b| a.lt(b)) } @@ -127,6 +132,7 @@ where pub fn lt_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, { compare_op_scalar(lhs, rhs, |a, b| a.lt(b)) } @@ -135,6 +141,7 @@ where pub fn lt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + 
Simd8, + T::Simd: Simd8PartialOrd, { compare_op(lhs, rhs, |a, b| a.lt_eq(b)) } @@ -144,6 +151,7 @@ where pub fn lt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, { compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b)) } @@ -153,6 +161,7 @@ where pub fn gt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, { compare_op(lhs, rhs, |a, b| a.gt(b)) } @@ -162,6 +171,7 @@ where pub fn gt_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, { compare_op_scalar(lhs, rhs, |a, b| a.gt(b)) } @@ -171,6 +181,7 @@ where pub fn gt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, { compare_op(lhs, rhs, |a, b| a.gt_eq(b)) } @@ -180,6 +191,7 @@ where pub fn gt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray where T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, { compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b)) } diff --git a/src/compute/comparison/simd/mod.rs b/src/compute/comparison/simd/mod.rs index ec85c53f9d0..30d9773cd4c 100644 --- a/src/compute/comparison/simd/mod.rs +++ b/src/compute/comparison/simd/mod.rs @@ -12,10 +12,18 @@ pub trait Simd8Lanes: Copy { fn from_chunk(v: &[T]) -> Self; /// loads an incomplete chunk, filling the remaining items with `remaining`. fn from_incomplete_chunk(v: &[T], remaining: T) -> Self; +} + +/// Trait implemented by implementors of [`Simd8Lanes`] whose [`Simd8`] implements [PartialEq]. +pub trait Simd8PartialEq: Copy { /// Equal fn eq(self, other: Self) -> u8; /// Not equal fn neq(self, other: Self) -> u8; +} + +/// Trait implemented by implementors of [`Simd8Lanes`] whose [`Simd8`] implements [PartialOrd]. 
+pub trait Simd8PartialOrd: Copy { /// Less than or equal to fn lt_eq(self, other: Self) -> u8; /// Less than @@ -38,6 +46,7 @@ pub(super) fn set bool>(lhs: [T; 8], rhs: [T; 8], op: F) byte } +/// Types that implement Simd8 macro_rules! simd8_native { ($type:ty) => { impl Simd8 for $type { @@ -56,7 +65,14 @@ macro_rules! simd8_native { a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); a } + } + }; +} +/// Types that implement PartialEq +macro_rules! simd8_native_partial_eq { + ($type:ty) => { + impl Simd8PartialEq for [$type; 8] { #[inline] fn eq(self, other: Self) -> u8 { set(self, other, |x, y| x == y) @@ -67,7 +83,14 @@ macro_rules! simd8_native { #[allow(clippy::float_cmp)] set(self, other, |x, y| x != y) } + } + }; +} +/// Types that implement PartialOrd +macro_rules! simd8_native_partial_ord { + ($type:ty) => { + impl Simd8PartialOrd for [$type; 8] { #[inline] fn lt_eq(self, other: Self) -> u8 { set(self, other, |x, y| x <= y) @@ -91,6 +114,15 @@ macro_rules! simd8_native { }; } +/// Types that implement simd8, PartialEq and PartialOrd +macro_rules! simd8_native_all { + ($type:ty) => { + simd8_native! {$type} + simd8_native_partial_eq! {$type} + simd8_native_partial_ord! 
{$type} + }; +} + #[cfg(not(feature = "simd"))] mod native; #[cfg(not(feature = "simd"))] diff --git a/src/compute/comparison/simd/native.rs b/src/compute/comparison/simd/native.rs index bc3e820e60e..a4a760bf1e9 100644 --- a/src/compute/comparison/simd/native.rs +++ b/src/compute/comparison/simd/native.rs @@ -1,15 +1,20 @@ use std::convert::TryInto; -use super::{set, Simd8, Simd8Lanes}; +use super::{set, Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; +use crate::types::{days_ms, months_days_ns}; -simd8_native!(u8); -simd8_native!(u16); -simd8_native!(u32); -simd8_native!(u64); -simd8_native!(i8); -simd8_native!(i16); -simd8_native!(i32); -simd8_native!(i128); -simd8_native!(i64); -simd8_native!(f32); -simd8_native!(f64); +simd8_native_all!(u8); +simd8_native_all!(u16); +simd8_native_all!(u32); +simd8_native_all!(u64); +simd8_native_all!(i8); +simd8_native_all!(i16); +simd8_native_all!(i32); +simd8_native_all!(i128); +simd8_native_all!(i64); +simd8_native_all!(f32); +simd8_native_all!(f64); +simd8_native!(days_ms); +simd8_native_partial_eq!(days_ms); +simd8_native!(months_days_ns); +simd8_native_partial_eq!(months_days_ns); diff --git a/src/compute/comparison/simd/packed.rs b/src/compute/comparison/simd/packed.rs index 7974e10490c..f12e0b3b78f 100644 --- a/src/compute/comparison/simd/packed.rs +++ b/src/compute/comparison/simd/packed.rs @@ -1,9 +1,11 @@ use std::convert::TryInto; -use super::{set, Simd8, Simd8Lanes}; - use packed_simd::*; +use crate::types::{days_ms, months_days_ns}; + +use super::*; + macro_rules! simd8 { ($type:ty, $md:ty) => { impl Simd8 for $type { @@ -22,7 +24,9 @@ macro_rules! 
simd8 { a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); Self::from_chunk(a.as_ref()) } + } + + impl Simd8PartialEq for $md { #[inline] fn eq(self, other: Self) -> u8 { self.eq(other).bitmask() @@ -64,6 +68,10 @@ simd8!(i8, i8x8); simd8!(i16, i16x8); simd8!(i32, i32x8); simd8!(i64, i64x8); -simd8_native!(i128); +simd8_native_all!(i128); simd8!(f32, f32x8); simd8!(f64, f64x8); +simd8_native!(days_ms); +simd8_native_partial_eq!(days_ms); +simd8_native!(months_days_ns); +simd8_native_partial_eq!(months_days_ns); diff --git a/src/compute/nullif.rs b/src/compute/nullif.rs index ffeec517d86..a83e096806c 100644 --- a/src/compute/nullif.rs +++ b/src/compute/nullif.rs @@ -1,7 +1,7 @@ //! Contains the operator [`nullif`]. use crate::array::PrimitiveArray; use crate::bitmap::Bitmap; -use crate::compute::comparison::{primitive_compare_values_op, Simd8, Simd8Lanes}; +use crate::compute::comparison::{primitive_compare_values_op, Simd8, Simd8PartialEq}; use crate::compute::utils::check_same_type; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; @@ -32,10 +32,14 @@ use super::utils::combine_validities; /// This function errors iff /// * The arguments do not have the same logical type /// * The arguments do not have the same length -pub fn nullif_primitive( +pub fn nullif_primitive( lhs: &PrimitiveArray, rhs: &PrimitiveArray, -) -> Result> { +) -> Result> +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ check_same_type(lhs, rhs)?; let equal = primitive_compare_values_op(lhs.values(), rhs.values(), |lhs, rhs| lhs.neq(rhs)); diff --git a/tests/it/compute/comparison.rs b/tests/it/compute/comparison.rs index 0fdb3e35532..1c89cc84620 100644 --- a/tests/it/compute/comparison.rs +++ b/tests/it/compute/comparison.rs @@ -1,7 +1,7 @@ use arrow2::array::*; use arrow2::compute::comparison::boolean::*; -use arrow2::datatypes::DataType::*; use arrow2::datatypes::TimeUnit; +use arrow2::datatypes::{DataType::*, IntervalUnit}; use 
arrow2::scalar::new_scalar; #[test] @@ -20,6 +20,9 @@ fn consistency() { Int64, Float32, Float64, + Interval(IntervalUnit::YearMonth), + Interval(IntervalUnit::MonthDayNano), + Interval(IntervalUnit::DayTime), Timestamp(TimeUnit::Second, None), Timestamp(TimeUnit::Millisecond, None), Timestamp(TimeUnit::Microsecond, None), @@ -46,6 +49,9 @@ fn consistency() { if can_eq(&d1) { eq(array.as_ref(), array.as_ref()); } + if can_lt_eq(&d1) { + lt_eq(array.as_ref(), array.as_ref()); + } }); // array <> scalar @@ -55,6 +61,9 @@ fn consistency() { if can_eq(&d1) { eq_scalar(array.as_ref(), scalar.as_ref()); } + if can_lt_eq(&d1) { + lt_eq_scalar(array.as_ref(), scalar.as_ref()); + } }); }