diff --git a/src/compute/aggregate/min_max.rs b/src/compute/aggregate/min_max.rs
index 0e501703c24..e841f0717cb 100644
--- a/src/compute/aggregate/min_max.rs
+++ b/src/compute/aggregate/min_max.rs
@@ -1,3 +1,6 @@
+use crate::datatypes::{DataType, IntervalUnit};
+use crate::error::{ArrowError, Result};
+use crate::scalar::*;
 use crate::types::simd::*;
 use crate::types::NativeType;
 use crate::{
@@ -248,6 +251,122 @@ pub fn max_boolean(array: &BooleanArray) -> Option<bool> {
         .or(Some(false))
 }
 
+// Downcasts `$array` to `PrimitiveArray<$ty>`, applies the aggregation `$f` and boxes
+// the result as a `PrimitiveScalar<$ty>` that keeps the array's own `DataType`.
+// Panics (via `unwrap`) if `$array` is not a `PrimitiveArray<$ty>`.
+macro_rules! dyn_primitive {
+    ($ty:ty, $array:expr, $f:ident) => {{
+        let array = $array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$ty>>()
+            .unwrap();
+        Box::new(PrimitiveScalar::<$ty>::new(
+            $array.data_type().clone(),
+            $f::<$ty>(array),
+        ))
+    }};
+}
+
+// Same as `dyn_primitive!` for arrays whose scalar is built from the value alone:
+// downcasts `$array` to `$array_ty`, applies `$f` and boxes a `$scalar_ty`.
+// Panics (via `unwrap`) if `$array` is not a `$array_ty`.
+macro_rules! dyn_generic {
+    ($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{
+        let array = $array.as_any().downcast_ref::<$array_ty>().unwrap();
+        Box::new(<$scalar_ty>::new($f(array)))
+    }};
+}
+
+/// Returns the maximum of `array` as a boxed scalar, dispatching on its [`DataType`].
+///
+/// Supported: booleans, signed/unsigned integers, `Float32`/`Float64`, the date, time,
+/// timestamp, duration and year-month interval types, `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+/// Returns [`ArrowError::InvalidArgumentError`] for any other data type.
+///
+/// # Panics
+/// Panics on `DataType::Float16` (`unreachable!`) and if the array's concrete type
+/// does not match its data type.
+pub fn max(array: &dyn Array) -> Result<Box<dyn Scalar>> {
+    Ok(match array.data_type() {
+        DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean),
+        DataType::Int8 => dyn_primitive!(i8, array, max_primitive),
+        DataType::Int16 => dyn_primitive!(i16, array, max_primitive),
+        DataType::Int32
+        | DataType::Date32
+        | DataType::Time32(_)
+        | DataType::Interval(IntervalUnit::YearMonth) => {
+            dyn_primitive!(i32, array, max_primitive)
+        }
+        DataType::Int64
+        | DataType::Date64
+        | DataType::Time64(_)
+        | DataType::Timestamp(_, _)
+        | DataType::Duration(_) => dyn_primitive!(i64, array, max_primitive),
+        DataType::UInt8 => dyn_primitive!(u8, array, max_primitive),
+        DataType::UInt16 => dyn_primitive!(u16, array, max_primitive),
+        DataType::UInt32 => dyn_primitive!(u32, array, max_primitive),
+        DataType::UInt64 => dyn_primitive!(u64, array, max_primitive),
+        DataType::Float16 => unreachable!(),
+        DataType::Float32 => dyn_primitive!(f32, array, max_primitive),
+        DataType::Float64 => dyn_primitive!(f64, array, max_primitive),
+        DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, max_string),
+        DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, max_string),
+        _ => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "The `max` operator does not support type `{}`",
+                array.data_type(),
+            )))
+        }
+    })
+}
+
+/// Returns the minimum of `array` as a boxed scalar, dispatching on its [`DataType`].
+///
+/// Supported: booleans, signed/unsigned integers, `Float32`/`Float64`, the date, time,
+/// timestamp, duration and year-month interval types, `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+/// Returns [`ArrowError::InvalidArgumentError`] for any other data type.
+///
+/// # Panics
+/// Panics on `DataType::Float16` (`unreachable!`) and if the array's concrete type
+/// does not match its data type.
+pub fn min(array: &dyn Array) -> Result<Box<dyn Scalar>> {
+    Ok(match array.data_type() {
+        DataType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean),
+        DataType::Int8 => dyn_primitive!(i8, array, min_primitive),
+        DataType::Int16 => dyn_primitive!(i16, array, min_primitive),
+        DataType::Int32
+        | DataType::Date32
+        | DataType::Time32(_)
+        | DataType::Interval(IntervalUnit::YearMonth) => {
+            dyn_primitive!(i32, array, min_primitive)
+        }
+        DataType::Int64
+        | DataType::Date64
+        | DataType::Time64(_)
+        | DataType::Timestamp(_, _)
+        | DataType::Duration(_) => dyn_primitive!(i64, array, min_primitive),
+        DataType::UInt8 => dyn_primitive!(u8, array, min_primitive),
+        DataType::UInt16 => dyn_primitive!(u16, array, min_primitive),
+        DataType::UInt32 => dyn_primitive!(u32, array, min_primitive),
+        DataType::UInt64 => dyn_primitive!(u64, array, min_primitive),
+        DataType::Float16 => unreachable!(),
+        DataType::Float32 => dyn_primitive!(f32, array, min_primitive),
+        DataType::Float64 => dyn_primitive!(f64, array, min_primitive),
+        DataType::Utf8 => dyn_generic!(Utf8Array<i32>, Utf8Scalar<i32>, array, min_string),
+        DataType::LargeUtf8 => dyn_generic!(Utf8Array<i64>, Utf8Scalar<i64>, array, min_string),
+        _ => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "The `min` operator does not support type `{}`",
+                array.data_type(),
+            )))
+        }
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/compute/aggregate/sum.rs b/src/compute/aggregate/sum.rs
index 8e280b8af31..2227b00e626 100644
--- a/src/compute/aggregate/sum.rs
+++ b/src/compute/aggregate/sum.rs
@@ -2,6 +2,9 @@ use std::ops::Add;
 
 use multiversion::multiversion;
 
+use crate::datatypes::{DataType, IntervalUnit};
+use crate::error::{ArrowError, Result};
+use crate::scalar::*;
 use crate::types::simd::*;
 use crate::types::NativeType;
 use crate::{
@@ -68,7 +71,7 @@ where
 /// Returns the sum of values in the array.
 ///
 /// Returns `None` if the array is empty or only contains null values.
-pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T>
+pub fn sum_primitive<T>(array: &PrimitiveArray<T>) -> Option<T>
 where
     T: NativeType + Simd,
     T::Simd: Add<Output = T::Simd> + Sum<T>,
@@ -85,6 +88,64 @@ where
     }
 }
 
+// Downcasts `$array` to `PrimitiveArray<$ty>`, sums it with `sum_primitive` and boxes
+// the result as a `PrimitiveScalar<$ty>` that keeps the array's own `DataType`.
+// Panics (via `unwrap`) if `$array` is not a `PrimitiveArray<$ty>`.
+macro_rules! dyn_sum {
+    ($ty:ty, $array:expr) => {{
+        let array = $array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$ty>>()
+            .unwrap();
+        Box::new(PrimitiveScalar::<$ty>::new(
+            $array.data_type().clone(),
+            sum_primitive::<$ty>(array),
+        ))
+    }};
+}
+
+/// Returns the sum of `array` as a boxed scalar, dispatching on its [`DataType`].
+///
+/// Supported: signed/unsigned integers, `Float32`/`Float64` and the date, time,
+/// timestamp, duration and year-month interval types.
+///
+/// # Errors
+/// Returns [`ArrowError::InvalidArgumentError`] for any other data type.
+///
+/// # Panics
+/// Panics on `DataType::Float16` (`unreachable!`) and if the array's concrete type
+/// does not match its data type.
+pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
+    Ok(match array.data_type() {
+        DataType::Int8 => dyn_sum!(i8, array),
+        DataType::Int16 => dyn_sum!(i16, array),
+        DataType::Int32
+        | DataType::Date32
+        | DataType::Time32(_)
+        | DataType::Interval(IntervalUnit::YearMonth) => {
+            dyn_sum!(i32, array)
+        }
+        DataType::Int64
+        | DataType::Date64
+        | DataType::Time64(_)
+        | DataType::Timestamp(_, _)
+        | DataType::Duration(_) => dyn_sum!(i64, array),
+        DataType::UInt8 => dyn_sum!(u8, array),
+        DataType::UInt16 => dyn_sum!(u16, array),
+        DataType::UInt32 => dyn_sum!(u32, array),
+        DataType::UInt64 => dyn_sum!(u64, array),
+        DataType::Float16 => unreachable!(),
+        DataType::Float32 => dyn_sum!(f32, array),
+        DataType::Float64 => dyn_sum!(f64, array),
+        _ => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "The `sum` operator does not support type `{}`",
+                array.data_type(),
+            )))
+        }
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use super::super::super::arithmetics;
@@ -95,25 +156,25 @@ mod tests {
     #[test]
     fn test_primitive_array_sum() {
         let a = Primitive::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Int32);
-        assert_eq!(15, sum(&a).unwrap());
+        assert_eq!(15, sum_primitive(&a).unwrap());
     }
 
     #[test]
     fn test_primitive_array_float_sum() {
         let a = Primitive::from_slice(&[1.1f64, 2.2, 3.3, 4.4, 5.5]).to(DataType::Float64);
-        assert!((16.5 - sum(&a).unwrap()).abs() < f64::EPSILON);
+        assert!((16.5 - sum_primitive(&a).unwrap()).abs() < f64::EPSILON);
     }
 
     #[test]
     fn test_primitive_array_sum_with_nulls() {
         let a = Int32Array::from(&[None, Some(2), Some(3), None, Some(5)]);
-        assert_eq!(10, sum(&a).unwrap());
+        assert_eq!(10, sum_primitive(&a).unwrap());
     }
 
     #[test]
     fn test_primitive_array_sum_all_nulls() {
         let a = Int32Array::from(&[None, None, None]);
-        assert_eq!(None, sum(&a));
+        assert_eq!(None, sum_primitive(&a));
     }
 
     #[test]
@@ -126,6 +187,9 @@ mod tests {
             .collect();
         // create an array that actually has non-zero values at the invalid indices
         let c = arithmetics::basic::add::add(&a, &b).unwrap();
-        assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c));
+        assert_eq!(
+            Some((1..=100).filter(|i| i % 3 == 0).sum()),
+            sum_primitive(&c)
+        );
     }
 }