Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Support sum, min and max for extension and decimal #907

Merged
merged 1 commit into from
Mar 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 70 additions & 67 deletions src/compute/aggregate/min_max.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::datatypes::{DataType, PhysicalType, PrimitiveType};
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
Expand Down Expand Up @@ -348,58 +348,55 @@ pub fn max_boolean(array: &BooleanArray) -> Option<bool> {
.or(Some(false))
}

// Applies a primitive aggregation kernel to a type-erased array.
//
// * `$ty`: the native type of the primitive array.
// * `$array`: an expression exposing `as_any()` and `data_type()`.
// * `$f`: the kernel, invoked as `$f::<$ty>(&PrimitiveArray<$ty>)`.
//
// Expands to a `Box<PrimitiveScalar<$ty>>` that carries a clone of the array's data type.
// Panics (through `unwrap`) when `$array` is not a `PrimitiveArray<$ty>`.
// Note that `$array` is expanded twice, so it should be a side-effect-free expression.
macro_rules! dyn_primitive {
($ty:ty, $array:expr, $f:ident) => {{
// Recover the concrete array type; a mismatch is treated as a caller bug.
let array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$ty>>()
.unwrap();
Box::new(PrimitiveScalar::<$ty>::new(
// Second expansion of `$array`: reuse its data type for the resulting scalar.
$array.data_type().clone(),
$f::<$ty>(array),
))
}};
}

// Applies an aggregation kernel to a type-erased, non-primitive array.
//
// * `$array_ty`: the concrete array type to downcast to.
// * `$scalar_ty`: the scalar type built from the kernel's result via `<$scalar_ty>::new`.
// * `$array`: an expression exposing `as_any()`.
// * `$f`: the kernel, invoked with a reference to the downcasted array.
//
// Expands to a `Box<$scalar_ty>`. Panics (through `unwrap`) when `$array` is not a `$array_ty`.
macro_rules! dyn_generic {
($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{
let array = $array.as_any().downcast_ref::<$array_ty>().unwrap();
Box::new(<$scalar_ty>::new($f(array)))
}};
}

// Runs `$body` with `$T` bound to the native Rust type matching the `PrimitiveType`
// value `$key_type`.
//
// Invoked as `with_match_primitive_type!(primitive, |$T| { ... })`: the `$_:tt` fragment
// captures the literal `$` token so that the helper macro defined below can declare its
// own metavariable named after `$T`.
//
// For any primitive that is not listed, the expansion performs an early
// `return Err(ArrowError::InvalidArgumentError(..))`, so it can only be used inside a
// function whose return type accepts that value.
macro_rules! with_match_primitive_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
// Helper that substitutes the concrete type for `$T` inside `$body`.
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
// Every remaining primitive is rejected by returning from the enclosing function.
_ => return Err(ArrowError::InvalidArgumentError(format!(
"`min` and `max` operator do not support primitive `{:?}`",
$key_type,
))),
}
})}

/// Returns the maximum of [`Array`]. The scalar is null when all elements are null.
/// # Error
/// Errors iff the type does not support this operation.
pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> {
Ok(match array.data_type() {
DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean),
DataType::Int8 => dyn_primitive!(i8, array, max_primitive),
DataType::Int16 => dyn_primitive!(i16, array, max_primitive),
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_primitive!(i32, array, max_primitive)
Ok(match array.data_type().to_physical_type() {
PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean),
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let data_type = array.data_type().clone();
let array = array.as_any().downcast_ref().unwrap();
Box::new(PrimitiveScalar::<$T>::new(data_type, max_primitive::<$T>(array)))
}),
PhysicalType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string),
PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string),
PhysicalType::Binary => {
dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, max_binary)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => dyn_primitive!(i64, array, max_primitive),
DataType::UInt8 => dyn_primitive!(u8, array, max_primitive),
DataType::UInt16 => dyn_primitive!(u16, array, max_primitive),
DataType::UInt32 => dyn_primitive!(u32, array, max_primitive),
DataType::UInt64 => dyn_primitive!(u64, array, max_primitive),
DataType::Float16 => unreachable!(),
DataType::Float32 => dyn_primitive!(f32, array, max_primitive),
DataType::Float64 => dyn_primitive!(f64, array, max_primitive),
DataType::Decimal(_, _) => dyn_primitive!(i128, array, max_primitive),
DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string),
DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string),
DataType::Binary => dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, max_binary),
DataType::LargeBinary => {
dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, max_binary)
// `max` must dispatch to the maximum kernel, exactly like the `Binary` arm does;
// calling `min_binary` here would return the minimum of large binary arrays.
PhysicalType::LargeBinary => {
    dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, max_binary)
}
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -414,33 +411,19 @@ pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> {
/// # Error
/// Errors iff the type does not support this operation.
pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> {
Ok(match array.data_type() {
DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean),
DataType::Int8 => dyn_primitive!(i8, array, min_primitive),
DataType::Int16 => dyn_primitive!(i16, array, min_primitive),
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_primitive!(i32, array, min_primitive)
Ok(match array.data_type().to_physical_type() {
PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean),
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let data_type = array.data_type().clone();
let array = array.as_any().downcast_ref().unwrap();
Box::new(PrimitiveScalar::<$T>::new(data_type, min_primitive::<$T>(array)))
}),
PhysicalType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string),
PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string),
PhysicalType::Binary => {
dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, min_binary)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => dyn_primitive!(i64, array, min_primitive),
DataType::UInt8 => dyn_primitive!(u8, array, min_primitive),
DataType::UInt16 => dyn_primitive!(u16, array, min_primitive),
DataType::UInt32 => dyn_primitive!(u32, array, min_primitive),
DataType::UInt64 => dyn_primitive!(u64, array, min_primitive),
DataType::Float16 => unreachable!(),
DataType::Float32 => dyn_primitive!(f32, array, min_primitive),
DataType::Float64 => dyn_primitive!(f64, array, min_primitive),
DataType::Decimal(_, _) => dyn_primitive!(i128, array, min_primitive),
DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string),
DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string),
DataType::Binary => dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, min_binary),
DataType::LargeBinary => {
PhysicalType::LargeBinary => {
dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, min_binary)
}
_ => {
Expand All @@ -451,3 +434,23 @@ pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> {
}
})
}

/// Whether [`min`] supports `data_type`.
///
/// The answer is derived from the physical type of `data_type`, so a logical type is
/// supported whenever its physical representation is.
pub fn can_min(data_type: &DataType) -> bool {
    let physical = data_type.to_physical_type();
    if let PhysicalType::Primitive(primitive) = physical {
        use PrimitiveType::*;
        // Mirrors the primitives dispatched by `with_match_primitive_type!`; `Int32` is
        // included so that 32-bit integer backed types are not reported as unsupported.
        matches!(
            primitive,
            Int8 | Int16
                | Int32
                | Int64
                | Int128
                | UInt8
                | UInt16
                | UInt32
                | UInt64
                | Float32
                | Float64
        )
    } else {
        use PhysicalType::*;
        matches!(physical, Boolean | Utf8 | LargeUtf8 | Binary | LargeBinary)
    }
}

/// Whether [`max`] supports `data_type`.
///
/// Delegates to [`can_min`], so the answer is always the same as for [`min`].
pub fn can_max(data_type: &DataType) -> bool {
can_min(data_type)
}
44 changes: 43 additions & 1 deletion src/compute/aggregate/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
use super::SimdOrd;
use std::ops::Add;

use crate::types::simd::{i128x8, NativeSimd};

use super::{SimdOrd, Sum};

// Implements lane-wise `AddAssign` and `Add`, plus the horizontal `Sum` reduction, for
// an indexable SIMD-like type.
//
// * `$simd`: the vector type; it must be indexable by lane, implement `Default` and
//   expose a `LANES` constant.
// * `$type`: the element type of each lane.
// * `$lanes`: the number of lanes iterated by the lane-wise operations.
// * `$add`: the name of the `$type` function combining two values (e.g. `add` or
//   `wrapping_add`); it is called as `<$type>::$add(a, b)`.
macro_rules! simd_add {
    ($simd:tt, $type:ty, $lanes:expr, $add:tt) => {
        impl std::ops::AddAssign for $simd {
            #[inline]
            fn add_assign(&mut self, rhs: Self) {
                for i in 0..$lanes {
                    self[i] = <$type>::$add(self[i], rhs[i]);
                }
            }
        }

        impl std::ops::Add for $simd {
            type Output = Self;

            #[inline]
            fn add(self, rhs: Self) -> Self::Output {
                let mut result = Self::default();
                for i in 0..$lanes {
                    result[i] = <$type>::$add(self[i], rhs[i]);
                }
                result
            }
        }

        impl Sum<$type> for $simd {
            #[inline]
            fn simd_sum(self) -> $type {
                // Reduce with the same `$add` that combines the lanes, so that a
                // `wrapping_add` instantiation wraps here too instead of tripping the
                // overflow check of a plain `+=`.
                (0..<$simd>::LANES).fold(<$type>::default(), |reduced, i| {
                    <$type>::$add(reduced, self[i])
                })
            }
        }
    };
}

macro_rules! simd_ord_int {
($simd:tt, $type:ty) => {
impl SimdOrd<$type> for $simd {
Expand Down Expand Up @@ -54,8 +94,10 @@ macro_rules! simd_ord_int {
};
}

pub(super) use simd_add;
pub(super) use simd_ord_int;

simd_add!(i128x8, i128, 8, add);
simd_ord_int!(i128x8, i128);

#[cfg(not(feature = "simd"))]
Expand Down
39 changes: 1 addition & 38 deletions src/compute/aggregate/simd/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,7 @@ use crate::types::simd::*;

use super::super::min_max::SimdOrd;
use super::super::sum::Sum;
use super::simd_ord_int;

// Implements lane-wise `AddAssign` and `Add`, plus the horizontal `Sum` reduction, for
// an indexable SIMD-like type.
//
// * `$simd`: the vector type; it must be indexable by lane, implement `Default` and
//   expose a `LANES` constant.
// * `$type`: the element type of each lane.
// * `$lanes`: the number of lanes iterated by the lane-wise operations.
// * `$add`: the name of the `$type` function combining two values (e.g. `wrapping_add`);
//   it is called as `<$type>::$add(a, b)`.
macro_rules! simd_add {
    ($simd:tt, $type:ty, $lanes:expr, $add:tt) => {
        impl std::ops::AddAssign for $simd {
            #[inline]
            fn add_assign(&mut self, rhs: Self) {
                for i in 0..$lanes {
                    self[i] = <$type>::$add(self[i], rhs[i]);
                }
            }
        }

        impl std::ops::Add for $simd {
            type Output = Self;

            #[inline]
            fn add(self, rhs: Self) -> Self::Output {
                let mut result = Self::default();
                for i in 0..$lanes {
                    result[i] = <$type>::$add(self[i], rhs[i]);
                }
                result
            }
        }

        impl Sum<$type> for $simd {
            #[inline]
            fn simd_sum(self) -> $type {
                // Reduce with the same `$add` that combines the lanes, so that a
                // `wrapping_add` instantiation wraps here too instead of tripping the
                // overflow check of a plain `+=`.
                (0..<$simd>::LANES).fold(<$type>::default(), |reduced, i| {
                    <$type>::$add(reduced, self[i])
                })
            }
        }
    };
}
use super::{simd_add, simd_ord_int};

simd_add!(u8x64, u8, 64, wrapping_add);
simd_add!(u16x32, u16, 32, wrapping_add);
Expand Down
96 changes: 41 additions & 55 deletions src/compute/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::Add;
use multiversion::multiversion;

use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::datatypes::{DataType, PhysicalType, PrimitiveType};
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
Expand Down Expand Up @@ -104,68 +104,54 @@ where
}
}

macro_rules! dyn_sum {
($ty:ty, $array:expr) => {{
let array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$ty>>()
.unwrap();
Box::new(PrimitiveScalar::<$ty>::new(
$array.data_type().clone(),
sum_primitive::<$ty>(array),
))
}};
}

/// Whether [`sum`] is valid for `data_type`
/// Whether [`sum`] supports `data_type`
pub fn can_sum(data_type: &DataType) -> bool {
use DataType::*;
matches!(
data_type,
Int8 | Int16
| Date32
| Time32(_)
| Interval(IntervalUnit::YearMonth)
| Int64
| Date64
| Time64(_)
| Timestamp(_, _)
| Duration(_)
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
| Float64
)
// Decide on the physical type: only primitive-backed arrays can be summed.
if let PhysicalType::Primitive(primitive) = data_type.to_physical_type() {
    use PrimitiveType::*;
    // Mirrors the primitives dispatched by `with_match_primitive_type!`; `Int32` is
    // included so that 32-bit integer backed types are not reported as unsupported.
    matches!(
        primitive,
        Int8 | Int16
            | Int32
            | Int64
            | Int128
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float32
            | Float64
    )
} else {
    false
}
}

// Runs `$body` with `$T` bound to the native Rust type matching the `PrimitiveType`
// value `$key_type`.
//
// Invoked as `with_match_primitive_type!(primitive, |$T| { ... })`: the `$_:tt` fragment
// captures the literal `$` token so that the helper macro defined below can declare its
// own metavariable named after `$T`.
//
// For any primitive that is not listed, the expansion performs an early
// `return Err(ArrowError::InvalidArgumentError(..))`, so it can only be used inside a
// function whose return type accepts that value.
macro_rules! with_match_primitive_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
// Helper that substitutes the concrete type for `$T` inside `$body`.
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
// Every remaining primitive is rejected by returning from the enclosing function.
_ => return Err(ArrowError::InvalidArgumentError(format!(
"`sum` operator do not support primitive `{:?}`",
$key_type,
))),
}
})}

/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical
/// and logical types as `array`.
/// # Error
/// Errors iff the operation is not supported.
pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
Ok(match array.data_type() {
DataType::Int8 => dyn_sum!(i8, array),
DataType::Int16 => dyn_sum!(i16, array),
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_sum!(i32, array)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => dyn_sum!(i64, array),
DataType::UInt8 => dyn_sum!(u8, array),
DataType::UInt16 => dyn_sum!(u16, array),
DataType::UInt32 => dyn_sum!(u32, array),
DataType::UInt64 => dyn_sum!(u64, array),
DataType::Float16 => unreachable!(),
DataType::Float32 => dyn_sum!(f32, array),
DataType::Float64 => dyn_sum!(f64, array),
Ok(match array.data_type().to_physical_type() {
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let data_type = array.data_type().clone();
let array = array.as_any().downcast_ref().unwrap();
Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array)))
}),
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"The `sum` operator does not support type `{:?}`",
Expand Down