Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
DRY via macro.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Dec 13, 2021
1 parent 89921d3 commit ce2497b
Showing 1 changed file with 39 additions and 121 deletions.
160 changes: 39 additions & 121 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
//! ```
use crate::array::*;
use crate::datatypes::{DataType, IntervalUnit};
use crate::datatypes::DataType;
use crate::scalar::*;

pub mod binary;
Expand All @@ -58,6 +58,28 @@ pub use simd::{Simd8, Simd8Lanes};

pub(crate) use primitive::compare_values_op as primitive_compare_values_op;

macro_rules! with_match_primitive_cmp {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
DaysMs => todo!(),
MonthDayNano => todo!(),
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
})}

macro_rules! compare {
($lhs:expr, $rhs:expr, $op:tt) => {{
let lhs = $lhs;
Expand All @@ -67,65 +89,18 @@ macro_rules! compare {
rhs.data_type().to_logical_type()
);

use DataType::*;
let data_type = lhs.data_type().to_logical_type();
match data_type {
use crate::datatypes::PhysicalType::*;
match lhs.data_type().to_physical_type() {
Boolean => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
boolean::$op(lhs, rhs)
}
Int8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<i8>(lhs, rhs)
}
Int16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<i16>(lhs, rhs)
}
Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<i32>(lhs, rhs)
}
Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<i64>(lhs, rhs)
}
UInt8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<u8>(lhs, rhs)
}
UInt16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<u16>(lhs, rhs)
}
UInt32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<u32>(lhs, rhs)
}
UInt64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<u64>(lhs, rhs)
}
Float16 => unreachable!(),
Float32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<f32>(lhs, rhs)
}
Float64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<f64>(lhs, rhs)
}
primitive::$op::<$T>(lhs, rhs)
}),
Utf8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Expand All @@ -136,11 +111,6 @@ macro_rules! compare {
let rhs = rhs.as_any().downcast_ref().unwrap();
utf8::$op::<i64>(lhs, rhs)
}
Decimal(_, _) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<i128>(lhs, rhs)
}
Binary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Expand All @@ -151,7 +121,10 @@ macro_rules! compare {
let rhs = rhs.as_any().downcast_ref().unwrap();
binary::$op::<i64>(lhs, rhs)
}
_ => todo!("Comparisons of {:?} are not yet supported", data_type),
_ => todo!(
"Comparison between {:?} are not yet supported",
lhs.data_type()
),
}
}};
}
Expand Down Expand Up @@ -234,66 +207,19 @@ macro_rules! compare_scalar {
return BooleanArray::new_null(DataType::Boolean, lhs.len());
}

use DataType::*;
let data_type = lhs.data_type().to_logical_type();
match data_type {
use crate::datatypes::PhysicalType::*;
match lhs.data_type().to_physical_type() {
Boolean => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<BooleanScalar>().unwrap();
// validity checked above
boolean::$op(lhs, rhs.value().unwrap())
}
Int8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i8>>().unwrap();
primitive::$op::<i8>(lhs, rhs.value().unwrap())
}
Int16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i16>>().unwrap();
primitive::$op::<i16>(lhs, rhs.value().unwrap())
}
Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i32>>().unwrap();
primitive::$op::<i32>(lhs, rhs.value().unwrap())
}
Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
primitive::$op::<i64>(lhs, rhs.value().unwrap())
}
UInt8 => {
Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u8>>().unwrap();
primitive::$op::<u8>(lhs, rhs.value().unwrap())
}
UInt16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u16>>().unwrap();
primitive::$op::<u16>(lhs, rhs.value().unwrap())
}
UInt32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u32>>().unwrap();
primitive::$op::<u32>(lhs, rhs.value().unwrap())
}
UInt64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u64>>().unwrap();
primitive::$op::<u64>(lhs, rhs.value().unwrap())
}
Float16 => unreachable!(),
Float32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<f32>>().unwrap();
primitive::$op::<f32>(lhs, rhs.value().unwrap())
}
Float64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<f64>>().unwrap();
primitive::$op::<f64>(lhs, rhs.value().unwrap())
}
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<$T>>().unwrap();
primitive::$op::<$T>(lhs, rhs.value().unwrap())
}),
Utf8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<Utf8Scalar<i32>>().unwrap();
Expand All @@ -304,14 +230,6 @@ macro_rules! compare_scalar {
let rhs = rhs.as_any().downcast_ref::<Utf8Scalar<i64>>().unwrap();
utf8::$op::<i64>(lhs, rhs.value().unwrap())
}
Decimal(_, _) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs
.as_any()
.downcast_ref::<PrimitiveScalar<i128>>()
.unwrap();
primitive::$op::<i128>(lhs, rhs.value().unwrap())
}
Binary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<BinaryScalar<i32>>().unwrap();
Expand All @@ -322,7 +240,7 @@ macro_rules! compare_scalar {
let rhs = rhs.as_any().downcast_ref::<BinaryScalar<i64>>().unwrap();
binary::$op::<i64>(lhs, rhs.value().unwrap())
}
_ => todo!("Comparisons of {:?} are not yet supported", data_type),
_ => todo!("Comparisons of {:?} are not yet supported", lhs.data_type()),
}
}};
}
Expand Down

0 comments on commit ce2497b

Please sign in to comment.