Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Simpler internals
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Mar 12, 2022
1 parent 1431b96 commit df60795
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 142 deletions.
117 changes: 50 additions & 67 deletions src/compute/aggregate/min_max.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::datatypes::PhysicalType;
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
Expand Down Expand Up @@ -348,58 +348,55 @@ pub fn max_boolean(array: &BooleanArray) -> Option<bool> {
.or(Some(false))
}

macro_rules! dyn_primitive {
($ty:ty, $array:expr, $f:ident) => {{
let array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$ty>>()
.unwrap();
Box::new(PrimitiveScalar::<$ty>::new(
$array.data_type().clone(),
$f::<$ty>(array),
))
}};
}

macro_rules! dyn_generic {
($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{
let array = $array.as_any().downcast_ref::<$array_ty>().unwrap();
Box::new(<$scalar_ty>::new($f(array)))
}};
}

macro_rules! with_match_primitive_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
_ => return Err(ArrowError::InvalidArgumentError(format!(
"`min` and `max` operator do not support primitive `{:?}`",
$key_type,
))),
}
})}

/// Returns the maximum of [`Array`]. The scalar is null when all elements are null.
/// # Error
/// Errors iff the type does not support this operation.
pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> {
Ok(match array.data_type() {
DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean),
DataType::Int8 => dyn_primitive!(i8, array, max_primitive),
DataType::Int16 => dyn_primitive!(i16, array, max_primitive),
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_primitive!(i32, array, max_primitive)
Ok(match array.data_type().to_physical_type() {
PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean),
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let data_type = array.data_type().clone();
let array = array.as_any().downcast_ref().unwrap();
Box::new(PrimitiveScalar::<$T>::new(data_type, max_primitive::<$T>(array)))
}),
PhysicalType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string),
PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string),
PhysicalType::Binary => {
dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, max_binary)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => dyn_primitive!(i64, array, max_primitive),
DataType::UInt8 => dyn_primitive!(u8, array, max_primitive),
DataType::UInt16 => dyn_primitive!(u16, array, max_primitive),
DataType::UInt32 => dyn_primitive!(u32, array, max_primitive),
DataType::UInt64 => dyn_primitive!(u64, array, max_primitive),
DataType::Float16 => unreachable!(),
DataType::Float32 => dyn_primitive!(f32, array, max_primitive),
DataType::Float64 => dyn_primitive!(f64, array, max_primitive),
DataType::Decimal(_, _) => dyn_primitive!(i128, array, max_primitive),
DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string),
DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string),
DataType::Binary => dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, max_binary),
DataType::LargeBinary => {
dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, max_binary)
PhysicalType::LargeBinary => {
dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, min_binary)
}
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -414,33 +411,19 @@ pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> {
/// # Error
/// Errors iff the type does not support this operation.
pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> {
Ok(match array.data_type() {
DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean),
DataType::Int8 => dyn_primitive!(i8, array, min_primitive),
DataType::Int16 => dyn_primitive!(i16, array, min_primitive),
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_primitive!(i32, array, min_primitive)
Ok(match array.data_type().to_physical_type() {
PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean),
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let data_type = array.data_type().clone();
let array = array.as_any().downcast_ref().unwrap();
Box::new(PrimitiveScalar::<$T>::new(data_type, min_primitive::<$T>(array)))
}),
PhysicalType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string),
PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string),
PhysicalType::Binary => {
dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, min_binary)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => dyn_primitive!(i64, array, min_primitive),
DataType::UInt8 => dyn_primitive!(u8, array, min_primitive),
DataType::UInt16 => dyn_primitive!(u16, array, min_primitive),
DataType::UInt32 => dyn_primitive!(u32, array, min_primitive),
DataType::UInt64 => dyn_primitive!(u64, array, min_primitive),
DataType::Float16 => unreachable!(),
DataType::Float32 => dyn_primitive!(f32, array, min_primitive),
DataType::Float64 => dyn_primitive!(f64, array, min_primitive),
DataType::Decimal(_, _) => dyn_primitive!(i128, array, min_primitive),
DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string),
DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string),
DataType::Binary => dyn_generic!(BinaryArray<i32>, BinaryScalar<i32>, array, min_binary),
DataType::LargeBinary => {
PhysicalType::LargeBinary => {
dyn_generic!(BinaryArray<i64>, BinaryScalar<i64>, array, min_binary)
}
_ => {
Expand Down
44 changes: 43 additions & 1 deletion src/compute/aggregate/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
use super::SimdOrd;
use std::ops::Add;

use crate::types::simd::{i128x8, NativeSimd};

use super::{SimdOrd, Sum};

macro_rules! simd_add {
($simd:tt, $type:ty, $lanes:expr, $add:tt) => {
impl std::ops::AddAssign for $simd {
#[inline]
fn add_assign(&mut self, rhs: Self) {
for i in 0..$lanes {
self[i] = <$type>::$add(self[i], rhs[i]);
}
}
}

impl std::ops::Add for $simd {
type Output = Self;

#[inline]
fn add(self, rhs: Self) -> Self::Output {
let mut result = Self::default();
for i in 0..$lanes {
result[i] = <$type>::$add(self[i], rhs[i]);
}
result
}
}

impl Sum<$type> for $simd {
#[inline]
fn simd_sum(self) -> $type {
let mut reduced = <$type>::default();
(0..<$simd>::LANES).for_each(|i| {
reduced += self[i];
});
reduced
}
}
};
}

macro_rules! simd_ord_int {
($simd:tt, $type:ty) => {
impl SimdOrd<$type> for $simd {
Expand Down Expand Up @@ -54,8 +94,10 @@ macro_rules! simd_ord_int {
};
}

pub(super) use simd_add;
pub(super) use simd_ord_int;

simd_add!(i128x8, i128, 8, add);
simd_ord_int!(i128x8, i128);

#[cfg(not(feature = "simd"))]
Expand Down
39 changes: 1 addition & 38 deletions src/compute/aggregate/simd/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,7 @@ use crate::types::simd::*;

use super::super::min_max::SimdOrd;
use super::super::sum::Sum;
use super::simd_ord_int;

macro_rules! simd_add {
($simd:tt, $type:ty, $lanes:expr, $add:tt) => {
impl std::ops::AddAssign for $simd {
#[inline]
fn add_assign(&mut self, rhs: Self) {
for i in 0..$lanes {
self[i] = <$type>::$add(self[i], rhs[i]);
}
}
}

impl std::ops::Add for $simd {
type Output = Self;

#[inline]
fn add(self, rhs: Self) -> Self::Output {
let mut result = Self::default();
for i in 0..$lanes {
result[i] = <$type>::$add(self[i], rhs[i]);
}
result
}
}

impl Sum<$type> for $simd {
#[inline]
fn simd_sum(self) -> $type {
let mut reduced = <$type>::default();
(0..<$simd>::LANES).for_each(|i| {
reduced += self[i];
});
reduced
}
}
};
}
use super::{simd_add, simd_ord_int};

simd_add!(u8x64, u8, 64, wrapping_add);
simd_add!(u16x32, u16, 32, wrapping_add);
Expand Down
68 changes: 32 additions & 36 deletions src/compute/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::Add;
use multiversion::multiversion;

use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
use crate::datatypes::{DataType, IntervalUnit};
use crate::datatypes::{DataType, IntervalUnit, PhysicalType};
use crate::error::{ArrowError, Result};
use crate::scalar::*;
use crate::types::simd::*;
Expand Down Expand Up @@ -104,19 +104,6 @@ where
}
}

macro_rules! dyn_sum {
($ty:ty, $array:expr) => {{
let array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$ty>>()
.unwrap();
Box::new(PrimitiveScalar::<$ty>::new(
$array.data_type().clone(),
sum_primitive::<$ty>(array),
))
}};
}

/// Whether [`sum`] is valid for `data_type`
pub fn can_sum(data_type: &DataType) -> bool {
use DataType::*;
Expand All @@ -140,35 +127,44 @@ pub fn can_sum(data_type: &DataType) -> bool {
)
}

macro_rules! with_match_primitive_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
_ => return Err(ArrowError::InvalidArgumentError(format!(
"`sum` operator do not support primitive `{:?}`",
$key_type,
))),
}
})}

/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical
/// and logical types as `array`.
/// # Error
/// Errors iff the operation is not supported.
pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
Ok(match array.data_type() {
DataType::Int8 => dyn_sum!(i8, array),
DataType::Int16 => dyn_sum!(i16, array),
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_sum!(i32, array)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => dyn_sum!(i64, array),
DataType::UInt8 => dyn_sum!(u8, array),
DataType::UInt16 => dyn_sum!(u16, array),
DataType::UInt32 => dyn_sum!(u32, array),
DataType::UInt64 => dyn_sum!(u64, array),
DataType::Float16 => unreachable!(),
DataType::Float32 => dyn_sum!(f32, array),
DataType::Float64 => dyn_sum!(f64, array),
Ok(match array.data_type().to_physical_type() {
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
let data_type = array.data_type().clone();
let array = array.as_any().downcast_ref().unwrap();
Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array)))
}),
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"The `sum` operator does not support type `{:?}`",
"The `max` operator does not support type `{:?}`",
array.data_type(),
)))
}
Expand Down

0 comments on commit df60795

Please sign in to comment.