From ce2497b39d2a1ebee3a2506645993d66b61e1f86 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Mon, 13 Dec 2021 19:27:00 +0000 Subject: [PATCH] DRY via macro. --- src/compute/comparison/mod.rs | 160 +++++++++------------------------- 1 file changed, 39 insertions(+), 121 deletions(-) diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index f194d9f6775..ffd603ccb3a 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -45,7 +45,7 @@ //! ``` use crate::array::*; -use crate::datatypes::{DataType, IntervalUnit}; +use crate::datatypes::DataType; use crate::scalar::*; pub mod binary; @@ -58,6 +58,28 @@ pub use simd::{Simd8, Simd8Lanes}; pub(crate) use primitive::compare_values_op as primitive_compare_values_op; +macro_rules! with_match_primitive_cmp {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + DaysMs => todo!(), + MonthDayNano => todo!(), + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + macro_rules! compare { ($lhs:expr, $rhs:expr, $op:tt) => {{ let lhs = $lhs; @@ -67,65 +89,18 @@ macro_rules! compare { rhs.data_type().to_logical_type() ); - use DataType::*; - let data_type = lhs.data_type().to_logical_type(); - match data_type { + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { Boolean => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); boolean::$op(lhs, rhs) } - Int8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - Int16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - UInt8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - UInt16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - UInt32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - UInt64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - Float16 => unreachable!(), - Float32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } - Float64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } + primitive::$op::<$T>(lhs, rhs) + }), Utf8 => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); @@ -136,11 +111,6 @@ macro_rules! compare { let rhs = rhs.as_any().downcast_ref().unwrap(); utf8::$op::(lhs, rhs) } - Decimal(_, _) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::$op::(lhs, rhs) - } Binary => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); @@ -151,7 +121,10 @@ macro_rules! compare { let rhs = rhs.as_any().downcast_ref().unwrap(); binary::$op::(lhs, rhs) } - _ => todo!("Comparisons of {:?} are not yet supported", data_type), + _ => todo!( + "Comparison between {:?} are not yet supported", + lhs.data_type() + ), } }}; } @@ -234,66 +207,19 @@ macro_rules! compare_scalar { return BooleanArray::new_null(DataType::Boolean, lhs.len()); } - use DataType::*; - let data_type = lhs.data_type().to_logical_type(); - match data_type { + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { Boolean => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref::().unwrap(); // validity checked above boolean::$op(lhs, rhs.value().unwrap()) } - Int8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - Int16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - UInt8 => { + Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| { let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - UInt16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - UInt32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - UInt64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - Float16 => unreachable!(), - Float32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } - Float64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::<$T>(lhs, rhs.value().unwrap()) + }), Utf8 => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref::>().unwrap(); @@ -304,14 +230,6 @@ macro_rules! compare_scalar { let rhs = rhs.as_any().downcast_ref::>().unwrap(); utf8::$op::(lhs, rhs.value().unwrap()) } - Decimal(_, _) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs - .as_any() - .downcast_ref::>() - .unwrap(); - primitive::$op::(lhs, rhs.value().unwrap()) - } Binary => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref::>().unwrap(); @@ -322,7 +240,7 @@ macro_rules! compare_scalar { let rhs = rhs.as_any().downcast_ref::>().unwrap(); binary::$op::(lhs, rhs.value().unwrap()) } - _ => todo!("Comparisons of {:?} are not yet supported", data_type), + _ => todo!("Comparisons of {:?} are not yet supported", lhs.data_type()), } }}; }