Skip to content

Commit

Permalink
add lt, gt, lt_eq, gt_eq
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Oct 24, 2021
1 parent b1a6bb2 commit 84250ae
Showing 1 changed file with 97 additions and 1 deletion.
98 changes: 97 additions & 1 deletion arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,54 @@ macro_rules! typed_cmp {
}

macro_rules! typed_compares {
($LEFT: expr, $RIGHT: expr, $OP_PRIM: ident, $OP_STR: ident) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
(DataType::Int8, DataType::Int8) => {
typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type)
}
(DataType::Int16, DataType::Int16) => {
typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type)
}
(DataType::Int32, DataType::Int32) => {
typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type)
}
(DataType::Int64, DataType::Int64) => {
typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type)
}
(DataType::UInt8, DataType::UInt8) => {
typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type)
}
(DataType::UInt16, DataType::UInt16) => {
typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type)
}
(DataType::UInt32, DataType::UInt32) => {
typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type)
}
(DataType::UInt64, DataType::UInt64) => {
typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type)
}
(DataType::Float32, DataType::Float32) => {
typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type)
}
(DataType::Float64, DataType::Float64) => {
typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type)
}
(DataType::Utf8, DataType::Utf8) => {
typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32)
}
(DataType::LargeUtf8, DataType::LargeUtf8) => {
typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64)
}
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
"Comparing arrays of type {} is not yet implemented",
t1
))),
(t1, t2) => Err(ArrowError::CastError(format!(
"Cannot compare two arrays of different types ({} and {})",
t1, t2
))),
}
}};
($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
(DataType::Boolean, DataType::Boolean) => {
Expand Down Expand Up @@ -1065,6 +1113,38 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, neq_bool, neq, neq_utf8)
}

/// Perform `left < right` operation on two (dynamic) arrays.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a downcast_ref error.
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt, lt_utf8)
}

/// Perform `left <= right` operation on two (dynamic) arrays.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a downcast_ref error.
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt_eq, lt_eq_utf8)
}

/// Perform `left > right` operation on two (dynamic) arrays.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a downcast_ref error.
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt, gt_utf8)
}

/// Perform `left >= right` operation on two (dynamic) arrays.
///
/// Only when two arrays are of the same type the comparison will happen otherwise it will err
/// with a downcast_ref error.
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt_eq, gt_eq_utf8)
}

/// Perform `left == right` operation on two primitive arrays.
pub fn eq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
Expand Down Expand Up @@ -1351,11 +1431,17 @@ mod tests {
/// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
/// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
macro_rules! cmp_i64 {
($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
let a = Int64Array::from($A_VEC);
let b = Int64Array::from($B_VEC);
let c = $KERNEL(&a, &b).unwrap();
assert_eq!(BooleanArray::from($EXPECTED), c);

// slice and test if the dynamic array works
let a = a.slice(0, a.len());
let b = b.slice(0, b.len());
let c = $DYN_KERNEL(a.as_ref(), b.as_ref()).unwrap();
assert_eq!(BooleanArray::from($EXPECTED), c);
};
}

Expand All @@ -1375,6 +1461,7 @@ mod tests {
fn test_primitive_array_eq() {
cmp_i64!(
eq,
eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, true, false, false, false, false, true, false, false]
Expand Down Expand Up @@ -1421,6 +1508,7 @@ mod tests {
fn test_primitive_array_neq() {
cmp_i64!(
neq,
neq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, false, true, true, true, true, false, true, true]
Expand Down Expand Up @@ -1502,6 +1590,7 @@ mod tests {
fn test_primitive_array_lt() {
cmp_i64!(
lt,
lt_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, false, true, true, false, false, false, true, true]
Expand All @@ -1522,6 +1611,7 @@ mod tests {
fn test_primitive_array_lt_nulls() {
cmp_i64!(
lt,
lt_dyn,
vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),],
vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),],
vec![None, None, None, Some(false), None, None, None, Some(true)]
Expand All @@ -1542,6 +1632,7 @@ mod tests {
fn test_primitive_array_lt_eq() {
cmp_i64!(
lt_eq,
lt_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, true, true, true, false, false, true, true, true]
Expand All @@ -1562,6 +1653,7 @@ mod tests {
fn test_primitive_array_lt_eq_nulls() {
cmp_i64!(
lt_eq,
lt_dyn,
vec![None, None, Some(1), None, None, Some(1), None, None, Some(1)],
vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None, Some(3)],
vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
Expand All @@ -1582,6 +1674,7 @@ mod tests {
fn test_primitive_array_gt() {
cmp_i64!(
gt,
gt_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, false, false, false, true, true, false, false, false]
Expand All @@ -1602,6 +1695,7 @@ mod tests {
fn test_primitive_array_gt_nulls() {
cmp_i64!(
gt,
gt_dyn,
vec![None, None, Some(1), None, None, Some(2), None, None, Some(3)],
vec![None, Some(1), Some(1), None, Some(1), Some(1), None, Some(1), Some(1)],
vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
Expand All @@ -1622,6 +1716,7 @@ mod tests {
fn test_primitive_array_gt_eq() {
cmp_i64!(
gt_eq,
gt_eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, true, false, false, true, true, true, false, false]
Expand All @@ -1642,6 +1737,7 @@ mod tests {
fn test_primitive_array_gt_eq_nulls() {
cmp_i64!(
gt_eq,
gt_eq_dyn,
vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)],
vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)],
vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)]
Expand Down

0 comments on commit 84250ae

Please sign in to comment.