Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Added support to compare intervals #746

Merged
merged 1 commit into from
Jan 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 69 additions & 36 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Basic comparison kernels.
//! Contains comparison operators
//!
//! The module contains functions that compare either an array and a scalar
//! or two arrays of the same [`DataType`]. The scalar-oriented functions are
//! The module contains functions that compare either an [`Array`] and a [`Scalar`]
//! or two [`Array`]s (of the same [`DataType`]). The scalar-oriented functions are
//! suffixed with `_scalar`.
//!
//! The functions are organized in two variants:
Expand Down Expand Up @@ -45,7 +45,7 @@
//! ```
use crate::array::*;
use crate::datatypes::DataType;
use crate::datatypes::{DataType, IntervalUnit};
use crate::scalar::*;

pub mod binary;
Expand All @@ -54,11 +54,11 @@ pub mod primitive;
pub mod utf8;

mod simd;
pub use simd::{Simd8, Simd8Lanes};
pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};

pub(crate) use primitive::compare_values_op as primitive_compare_values_op;

macro_rules! with_match_primitive_cmp {(
macro_rules! match_eq_ord {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
Expand All @@ -80,8 +80,31 @@ macro_rules! with_match_primitive_cmp {(
}
})}

macro_rules! match_eq {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
DaysMs => __with_ty__! { days_ms },
MonthDayNano => __with_ty__! { months_days_ns },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
})}

macro_rules! compare {
($lhs:expr, $rhs:expr, $op:tt) => {{
($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{
let lhs = $lhs;
let rhs = $rhs;
assert_eq!(
Expand All @@ -96,7 +119,7 @@ macro_rules! compare {
let rhs = rhs.as_any().downcast_ref().unwrap();
boolean::$op(lhs, rhs)
}
Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| {
Primitive(primitive) => $p!(primitive, |$T| {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<$T>(lhs, rhs)
Expand Down Expand Up @@ -137,7 +160,7 @@ macro_rules! compare {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, eq)
compare!(lhs, rhs, eq, match_eq)
}

/// `!=` between two [`Array`]s.
Expand All @@ -148,7 +171,7 @@ pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, neq)
compare!(lhs, rhs, neq, match_eq)
}

/// `<` between two [`Array`]s.
Expand All @@ -159,7 +182,7 @@ pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, lt)
compare!(lhs, rhs, lt, match_eq_ord)
}

/// `<=` between two [`Array`]s.
Expand All @@ -170,7 +193,7 @@ pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, lt_eq)
compare!(lhs, rhs, lt_eq, match_eq_ord)
}

/// `>` between two [`Array`]s.
Expand All @@ -181,7 +204,7 @@ pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, gt)
compare!(lhs, rhs, gt, match_eq_ord)
}

/// `>=` between two [`Array`]s.
Expand All @@ -192,11 +215,11 @@ pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn gt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, gt_eq)
compare!(lhs, rhs, gt_eq, match_eq_ord)
}

macro_rules! compare_scalar {
($lhs:expr, $rhs:expr, $op:tt) => {{
($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{
let lhs = $lhs;
let rhs = $rhs;
assert_eq!(
Expand All @@ -215,7 +238,7 @@ macro_rules! compare_scalar {
// validity checked above
boolean::$op(lhs, rhs.value().unwrap())
}
Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| {
Primitive(primitive) => $p!(primitive, |$T| {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<$T>>().unwrap();
primitive::$op::<$T>(lhs, rhs.value().unwrap())
Expand Down Expand Up @@ -252,7 +275,7 @@ macro_rules! compare_scalar {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, eq_scalar)
compare_scalar!(lhs, rhs, eq_scalar, match_eq)
}

/// `!=` between an [`Array`] and a [`Scalar`].
Expand All @@ -262,7 +285,7 @@ pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, neq_scalar)
compare_scalar!(lhs, rhs, neq_scalar, match_eq)
}

/// `<` between an [`Array`] and a [`Scalar`].
Expand All @@ -272,7 +295,7 @@ pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, lt_scalar)
compare_scalar!(lhs, rhs, lt_scalar, match_eq_ord)
}

/// `<=` between an [`Array`] and a [`Scalar`].
Expand All @@ -282,7 +305,7 @@ pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, lt_eq_scalar)
compare_scalar!(lhs, rhs, lt_eq_scalar, match_eq_ord)
}

/// `>` between an [`Array`] and a [`Scalar`].
Expand All @@ -292,7 +315,7 @@ pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, gt_scalar)
compare_scalar!(lhs, rhs, gt_scalar, match_eq_ord)
}

/// `>=` between an [`Array`] and a [`Scalar`].
Expand All @@ -302,41 +325,41 @@ pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn gt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, gt_eq_scalar)
compare_scalar!(lhs, rhs, gt_eq_scalar, match_eq_ord)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_eq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_neq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_lt(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_lt_eq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_gt(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_gt_eq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

// The list of operations currently supported.
fn can_compare(data_type: &DataType) -> bool {
fn can_partial_eq_and_ord(data_type: &DataType) -> bool {
matches!(
data_type,
DataType::Boolean
Expand All @@ -345,7 +368,7 @@ fn can_compare(data_type: &DataType) -> bool {
| DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(_)
| DataType::Interval(IntervalUnit::YearMonth)
| DataType::Int64
| DataType::Timestamp(_, _)
| DataType::Date64
Expand All @@ -364,3 +387,13 @@ fn can_compare(data_type: &DataType) -> bool {
| DataType::LargeBinary
)
}

// The list of operations currently supported.
fn can_partial_eq(data_type: &DataType) -> bool {
can_partial_eq_and_ord(data_type)
|| matches!(
data_type.to_logical_type(),
DataType::Interval(IntervalUnit::DayTime)
| DataType::Interval(IntervalUnit::MonthDayNano)
)
}
14 changes: 13 additions & 1 deletion src/compute/comparison/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
};

use super::super::utils::combine_validities;
use super::simd::{Simd8, Simd8Lanes};
use super::simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};

pub(crate) fn compare_values_op<T, F>(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap
where
Expand Down Expand Up @@ -87,6 +87,7 @@ where
pub fn eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op(lhs, rhs, |a, b| a.eq(b))
}
Expand All @@ -95,6 +96,7 @@ where
pub fn eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op_scalar(lhs, rhs, |a, b| a.eq(b))
}
Expand All @@ -103,6 +105,7 @@ where
pub fn neq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op(lhs, rhs, |a, b| a.neq(b))
}
Expand All @@ -111,6 +114,7 @@ where
pub fn neq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op_scalar(lhs, rhs, |a, b| a.neq(b))
}
Expand All @@ -119,6 +123,7 @@ where
pub fn lt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.lt(b))
}
Expand All @@ -127,6 +132,7 @@ where
pub fn lt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.lt(b))
}
Expand All @@ -135,6 +141,7 @@ where
pub fn lt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.lt_eq(b))
}
Expand All @@ -144,6 +151,7 @@ where
pub fn lt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b))
}
Expand All @@ -153,6 +161,7 @@ where
pub fn gt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.gt(b))
}
Expand All @@ -162,6 +171,7 @@ where
pub fn gt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.gt(b))
}
Expand All @@ -171,6 +181,7 @@ where
pub fn gt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.gt_eq(b))
}
Expand All @@ -180,6 +191,7 @@ where
pub fn gt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b))
}
Expand Down
Loading