From a10d525414a2adb4199adf7eade55baa0f76a211 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 12 Mar 2022 22:38:08 +0000 Subject: [PATCH] Simpler internals --- src/compute/aggregate/min_max.rs | 137 ++++++++++++++------------- src/compute/aggregate/simd/mod.rs | 44 ++++++++- src/compute/aggregate/simd/native.rs | 39 +------- src/compute/aggregate/sum.rs | 96 ++++++++----------- 4 files changed, 155 insertions(+), 161 deletions(-) diff --git a/src/compute/aggregate/min_max.rs b/src/compute/aggregate/min_max.rs index 8b78614a1a7..2c6478cff48 100644 --- a/src/compute/aggregate/min_max.rs +++ b/src/compute/aggregate/min_max.rs @@ -1,5 +1,5 @@ use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; -use crate::datatypes::{DataType, IntervalUnit}; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; use crate::error::{ArrowError, Result}; use crate::scalar::*; use crate::types::simd::*; @@ -348,19 +348,6 @@ pub fn max_boolean(array: &BooleanArray) -> Option<bool> { .or(Some(false)) } -macro_rules! dyn_primitive { - ($ty:ty, $array:expr, $f:ident) => {{ - let array = $array - .as_any() - .downcast_ref::<PrimitiveArray<$ty>>() - .unwrap(); - Box::new(PrimitiveScalar::<$ty>::new( - $array.data_type().clone(), - $f::<$ty>(array), - )) - }}; -} - macro_rules! dyn_generic { ($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{ let array = $array.as_any().downcast_ref::<$array_ty>().unwrap(); @@ -368,38 +355,48 @@ macro_rules! dyn_generic { }}; } +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! 
{ u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => return Err(ArrowError::InvalidArgumentError(format!( + "`min` and `max` operator do not support primitive `{:?}`", + $key_type, + ))), + } +})} + /// Returns the maximum of [`Array`]. The scalar is null when all elements are null. /// # Error /// Errors iff the type does not support this operation. pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> { - Ok(match array.data_type() { - DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean), - DataType::Int8 => dyn_primitive!(i8, array, max_primitive), - DataType::Int16 => dyn_primitive!(i16, array, max_primitive), - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_primitive!(i32, array, max_primitive) + Ok(match array.data_type().to_physical_type() { + PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean), + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::<$T>::new(data_type, max_primitive::<$T>(array))) + }), + PhysicalType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string), + PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string), + PhysicalType::Binary => { + dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, max_binary) } - DataType::Int64 - | DataType::Date64 - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) => dyn_primitive!(i64, array, max_primitive), - DataType::UInt8 => dyn_primitive!(u8, array, max_primitive), - DataType::UInt16 => dyn_primitive!(u16, array, max_primitive), - DataType::UInt32 => dyn_primitive!(u32, array, max_primitive), - DataType::UInt64 => dyn_primitive!(u64, array, max_primitive), - DataType::Float16 => unreachable!(), - DataType::Float32 => 
dyn_primitive!(f32, array, max_primitive), - DataType::Float64 => dyn_primitive!(f64, array, max_primitive), - DataType::Decimal(_, _) => dyn_primitive!(i128, array, max_primitive), - DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string), - DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string), - DataType::Binary => dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, max_binary), - DataType::LargeBinary => { - dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, max_binary) + PhysicalType::LargeBinary => { + dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, max_binary) } _ => { return Err(ArrowError::InvalidArgumentError(format!( @@ -414,33 +411,19 @@ pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> { /// # Error /// Errors iff the type does not support this operation. pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> { - Ok(match array.data_type() { - DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean), - DataType::Int8 => dyn_primitive!(i8, array, min_primitive), - DataType::Int16 => dyn_primitive!(i16, array, min_primitive), - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_primitive!(i32, array, min_primitive) + Ok(match array.data_type().to_physical_type() { + PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean), + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::<$T>::new(data_type, min_primitive::<$T>(array))) + }), + PhysicalType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string), + PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string), + PhysicalType::Binary => { + dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, min_binary) } - DataType::Int64 - | DataType::Date64 - | DataType::Time64(_) - | 
DataType::Timestamp(_, _) - | DataType::Duration(_) => dyn_primitive!(i64, array, min_primitive), - DataType::UInt8 => dyn_primitive!(u8, array, min_primitive), - DataType::UInt16 => dyn_primitive!(u16, array, min_primitive), - DataType::UInt32 => dyn_primitive!(u32, array, min_primitive), - DataType::UInt64 => dyn_primitive!(u64, array, min_primitive), - DataType::Float16 => unreachable!(), - DataType::Float32 => dyn_primitive!(f32, array, min_primitive), - DataType::Float64 => dyn_primitive!(f64, array, min_primitive), - DataType::Decimal(_, _) => dyn_primitive!(i128, array, min_primitive), - DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string), - DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string), - DataType::Binary => dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, min_binary), - DataType::LargeBinary => { + PhysicalType::LargeBinary => { dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, min_binary) } _ => { @@ -451,3 +434,23 @@ pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> { } }) } + +/// Whether [`min`] supports `data_type` +pub fn can_min(data_type: &DataType) -> bool { + let physical = data_type.to_physical_type(); + if let PhysicalType::Primitive(primitive) = physical { + use PrimitiveType::*; + matches!( + primitive, + Int8 | Int16 | Int32 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 + ) + } else { + use PhysicalType::*; + matches!(physical, Boolean | Utf8 | LargeUtf8 | Binary | LargeBinary) + } +} + +/// Whether [`max`] supports `data_type` +pub fn can_max(data_type: &DataType) -> bool { + can_min(data_type) +} diff --git a/src/compute/aggregate/simd/mod.rs b/src/compute/aggregate/simd/mod.rs index f8cbb07e711..14b4625aa4a 100644 --- a/src/compute/aggregate/simd/mod.rs +++ b/src/compute/aggregate/simd/mod.rs @@ -1,6 +1,46 @@ -use super::SimdOrd; +use std::ops::Add; + use crate::types::simd::{i128x8, NativeSimd}; +use super::{SimdOrd, Sum}; + +macro_rules! 
simd_add { + ($simd:tt, $type:ty, $lanes:expr, $add:tt) => { + impl std::ops::AddAssign for $simd { + #[inline] + fn add_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self[i] = <$type>::$add(self[i], rhs[i]); + } + } + } + + impl std::ops::Add for $simd { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + let mut result = Self::default(); + for i in 0..$lanes { + result[i] = <$type>::$add(self[i], rhs[i]); + } + result + } + } + + impl Sum<$type> for $simd { + #[inline] + fn simd_sum(self) -> $type { + let mut reduced = <$type>::default(); + (0..<$simd>::LANES).for_each(|i| { + reduced += self[i]; + }); + reduced + } + } + }; +} + macro_rules! simd_ord_int { ($simd:tt, $type:ty) => { impl SimdOrd<$type> for $simd { @@ -54,8 +94,10 @@ macro_rules! simd_ord_int { }; } +pub(super) use simd_add; pub(super) use simd_ord_int; +simd_add!(i128x8, i128, 8, add); simd_ord_int!(i128x8, i128); #[cfg(not(feature = "simd"))] diff --git a/src/compute/aggregate/simd/native.rs b/src/compute/aggregate/simd/native.rs index 1da285afe99..a4161048c9c 100644 --- a/src/compute/aggregate/simd/native.rs +++ b/src/compute/aggregate/simd/native.rs @@ -4,44 +4,7 @@ use crate::types::simd::*; use super::super::min_max::SimdOrd; use super::super::sum::Sum; -use super::simd_ord_int; - -macro_rules! 
simd_add { - ($simd:tt, $type:ty, $lanes:expr, $add:tt) => { - impl std::ops::AddAssign for $simd { - #[inline] - fn add_assign(&mut self, rhs: Self) { - for i in 0..$lanes { - self[i] = <$type>::$add(self[i], rhs[i]); - } - } - } - - impl std::ops::Add for $simd { - type Output = Self; - - #[inline] - fn add(self, rhs: Self) -> Self::Output { - let mut result = Self::default(); - for i in 0..$lanes { - result[i] = <$type>::$add(self[i], rhs[i]); - } - result - } - } - - impl Sum<$type> for $simd { - #[inline] - fn simd_sum(self) -> $type { - let mut reduced = <$type>::default(); - (0..<$simd>::LANES).for_each(|i| { - reduced += self[i]; - }); - reduced - } - } - }; -} +use super::{simd_add, simd_ord_int}; simd_add!(u8x64, u8, 64, wrapping_add); simd_add!(u16x32, u16, 32, wrapping_add); diff --git a/src/compute/aggregate/sum.rs b/src/compute/aggregate/sum.rs index 578ea6bd42d..609d47254f8 100644 --- a/src/compute/aggregate/sum.rs +++ b/src/compute/aggregate/sum.rs @@ -3,7 +3,7 @@ use std::ops::Add; use multiversion::multiversion; use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; -use crate::datatypes::{DataType, IntervalUnit}; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; use crate::error::{ArrowError, Result}; use crate::scalar::*; use crate::types::simd::*; @@ -104,68 +104,54 @@ where } } -macro_rules! 
dyn_sum { - ($ty:ty, $array:expr) => {{ - let array = $array - .as_any() - .downcast_ref::<PrimitiveArray<$ty>>() - .unwrap(); - Box::new(PrimitiveScalar::<$ty>::new( - $array.data_type().clone(), - sum_primitive::<$ty>(array), - )) - }}; -} - -/// Whether [`sum`] is valid for `data_type` +/// Whether [`sum`] supports `data_type` pub fn can_sum(data_type: &DataType) -> bool { - use DataType::*; - matches!( - data_type, - Int8 | Int16 - | Date32 - | Time32(_) - | Interval(IntervalUnit::YearMonth) - | Int64 - | Date64 - | Time64(_) - | Timestamp(_, _) - | Duration(_) - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - ) + if let PhysicalType::Primitive(primitive) = data_type.to_physical_type() { + use PrimitiveType::*; + matches!( + primitive, + Int8 | Int16 | Int32 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 + ) + } else { + false + } } +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => return Err(ArrowError::InvalidArgumentError(format!( + "`sum` operator do not support primitive `{:?}`", + $key_type, + ))), + } +})} + /// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical /// and logical types as `array`. /// # Error /// Errors iff the operation is not supported. 
pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> { - Ok(match array.data_type() { - DataType::Int8 => dyn_sum!(i8, array), - DataType::Int16 => dyn_sum!(i16, array), - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_sum!(i32, array) - } - DataType::Int64 - | DataType::Date64 - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) => dyn_sum!(i64, array), - DataType::UInt8 => dyn_sum!(u8, array), - DataType::UInt16 => dyn_sum!(u16, array), - DataType::UInt32 => dyn_sum!(u32, array), - DataType::UInt64 => dyn_sum!(u64, array), - DataType::Float16 => unreachable!(), - DataType::Float32 => dyn_sum!(f32, array), - DataType::Float64 => dyn_sum!(f64, array), + Ok(match array.data_type().to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array))) + }), _ => { return Err(ArrowError::InvalidArgumentError(format!( "The `sum` operator does not support type `{:?}`",