diff --git a/src/compute/arithmetics/basic/mod.rs b/src/compute/arithmetics/basic/mod.rs index 6522338da6c..492a29b37a9 100644 --- a/src/compute/arithmetics/basic/mod.rs +++ b/src/compute/arithmetics/basic/mod.rs @@ -1,4 +1,4 @@ -//! Contains arithemtic functions for [`PrimitiveArray`](crate::array::PrimitiveArray)s. +//! Contains arithemtic functions for [`PrimitiveArray`]s. //! //! Each operation has four variants, like the rest of Rust's ecosystem: //! * usual, that [`panic!`]s on overflow @@ -20,3 +20,72 @@ pub use sub::*; mod common; pub(crate) use common::*; + +use std::ops::Neg; + +use num_traits::{CheckedNeg, WrappingNeg}; + +use crate::{ + array::{Array, PrimitiveArray}, + types::NativeType, +}; + +use super::super::arity::{unary, unary_checked}; + +/// Negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::negate; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(7)]); +/// let result = negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); +/// assert_eq!(result, expected) +/// ``` +pub fn negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Neg, +{ + unary(array, |a| -a, array.data_type().clone()) +} + +/// Checked negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::checked_negate; +/// use arrow2::array::{Array, PrimitiveArray}; +/// +/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); +/// let result = checked_negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); +/// assert_eq!(result, expected); +/// assert!(!result.is_valid(2)) +/// ``` +pub fn checked_negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + CheckedNeg, +{ + unary_checked(array, |a| a.checked_neg(), array.data_type().clone()) +} + +/// Wrapping negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::wrapping_negate; +/// use arrow2::array::{Array, PrimitiveArray}; +/// +/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); +/// let result = wrapping_negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), Some(i8::MIN), Some(-7)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + WrappingNeg, +{ + unary(array, |a| a.wrapping_neg(), array.data_type().clone()) +} diff --git a/src/compute/arithmetics/decimal/mod.rs b/src/compute/arithmetics/decimal/mod.rs index c3b69468756..ade2d0cec9c 100644 --- a/src/compute/arithmetics/decimal/mod.rs +++ b/src/compute/arithmetics/decimal/mod.rs @@ -3,10 +3,14 @@ //! precision and scale parameters. These affect the arithmetic operations and //! need to be considered while doing operations with Decimal numbers. -pub mod add; -pub mod div; -pub mod mul; -pub mod sub; +mod add; +pub use add::*; +mod div; +pub use div::*; +mod mul; +pub use mul::*; +mod sub; +pub use sub::*; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index e75c604d20d..9a47889ee91 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -1,350 +1,262 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines basic arithmetic kernels for [`PrimitiveArray`]s. -//! -//! # Description +//! Defines basic arithmetic kernels for [`PrimitiveArray`](crate::array::PrimitiveArray)s. //! //! The Arithmetics module is composed by basic arithmetics operations that can -//! be performed on PrimitiveArray Arrays. These operations can be the building for -//! any implementation using Arrow. -//! -//! Whenever possible, each of the operations in these modules has variations -//! of the basic operation that offers different guarantees. These options are: -//! -//! * plain: The plain type (add, sub, mul, and div) don't offer any protection -//! when performing the operations. This means that if overflow is found, -//! then the operations will panic. -//! -//! * checked: A checked operation will change the validity Bitmap for the -//! offending operation. For example, if one of the operations overflows, the -//! validity will be changed to None, indicating a Null value. -//! -//! * saturating: If overflowing is presented in one operation, the resulting -//! value for that index will be saturated to the MAX or MIN value possible -//! for that type. For [`Decimal`](crate::datatypes::DataType::Decimal) -//! arrays, the saturated value is calculated considering the precision and -//! scale of the array. +//! be performed on [`PrimitiveArray`](crate::array::PrimitiveArray). //! -//! * overflowing: When an operation overflows, the resulting will be the -//! overflowed value for the operation. The result from the array operation -//! includes a Binary bitmap indicating which values overflowed. -//! -//! * adaptive: For [`Decimal`](crate::datatypes::DataType::Decimal) arrays, -//! the adaptive variation adjusts the precision and scale to avoid -//! saturation or overflowing. -//! -//! # New kernels -//! -//! When adding a new operation to this module, it is strongly suggested to -//! follow the design description presented in the README.md file located in -//! the [`compute`](crate::compute) module and the function descriptions -//! presented in this document. +//! Whenever possible, each operation declares variations +//! of the basic operation that offers different guarantees: +//! * plain: panics on overflowing and underflowing. +//! * checked: turns an overflowing to a null. +//! * saturating: turns the overflowing to the MAX or MIN value respectively. +//! * overflowing: returns an extra [`Bitmap`] denoting whether the operation overflowed. +//! * adaptive: for [`Decimal`](crate::datatypes::DataType::Decimal) only, +//! adjusts the precision and scale to make the resulting value fit. pub mod basic; pub mod decimal; pub mod time; -use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; - -use num_traits::{CheckedNeg, NumCast, WrappingNeg, Zero}; - -use crate::datatypes::{DataType, IntervalUnit, TimeUnit}; -use crate::error::{ArrowError, Result}; -use crate::types::NativeType; -use crate::{array::*, bitmap::Bitmap}; - -use super::arity::{unary, unary_checked}; +use crate::{ + array::Array, + bitmap::Bitmap, + datatypes::{DataType, IntervalUnit, TimeUnit}, + types::NativeType, +}; // Macro to evaluate match branch in arithmetic function. -// The macro is used to downcast both arrays to a primitive_array_type. If there -// is an error then an ArrowError is return with the data_type that cause it. -// It returns the result from the arithmetic_primitive function evaluated with -// the Operator selected macro_rules! primitive { - ($lhs: expr, $rhs: expr, $op: expr, $array_type: ty) => {{ - let res_lhs = $lhs.as_any().downcast_ref().unwrap(); - let res_rhs = $rhs.as_any().downcast_ref().unwrap(); + ($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{ + let lhs = $lhs.as_any().downcast_ref().unwrap(); + let rhs = $rhs.as_any().downcast_ref().unwrap(); - let res = arithmetic_primitive::<$array_type>(res_lhs, $op, res_rhs); - Ok(Box::new(res) as Box) + let result = basic::$op::<$type>(lhs, rhs); + Box::new(result) as Box }}; } -/// Execute an arithmetic operation with two arrays. It uses the enum Operator -/// to select the type of operation that is going to be performed with the two -/// arrays -pub fn arithmetic(lhs: &dyn Array, op: Operator, rhs: &dyn Array) -> Result> { - use DataType::*; - use Operator::*; - match (lhs.data_type(), op, rhs.data_type()) { - (Int8, _, Int8) => primitive!(lhs, rhs, op, i8), - (Int16, _, Int16) => primitive!(lhs, rhs, op, i16), - (Int32, _, Int32) => primitive!(lhs, rhs, op, i32), - (Int64, _, Int64) | (Duration(_), _, Duration(_)) => { - primitive!(lhs, rhs, op, i64) - } - (UInt8, _, UInt8) => primitive!(lhs, rhs, op, u8), - (UInt16, _, UInt16) => primitive!(lhs, rhs, op, u16), - (UInt32, _, UInt32) => primitive!(lhs, rhs, op, u32), - (UInt64, _, UInt64) => primitive!(lhs, rhs, op, u64), - (Float32, _, Float32) => primitive!(lhs, rhs, op, f32), - (Float64, _, Float64) => primitive!(lhs, rhs, op, f64), - (Decimal(_, _), _, Decimal(_, _)) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - - let res = match op { - Add => decimal::add::add(lhs, rhs), - Subtract => decimal::sub::sub(lhs, rhs), - Multiply => decimal::mul::mul(lhs, rhs), - Divide => decimal::div::div(lhs, rhs), - Remainder => { - return Err(ArrowError::NotYetImplemented(format!( - "Arithmetics of ({:?}, {:?}, {:?}) is not supported", - lhs, op, rhs - ))) - } - }; - - Ok(Box::new(res) as Box) - } - (Time32(TimeUnit::Second), Add, Duration(_)) - | (Time32(TimeUnit::Millisecond), Add, Duration(_)) - | (Date32, Add, Duration(_)) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - time::add_duration::(lhs, rhs).map(|x| Box::new(x) as Box) - } - (Time32(TimeUnit::Second), Subtract, Duration(_)) - | (Time32(TimeUnit::Millisecond), Subtract, Duration(_)) - | (Date32, Subtract, Duration(_)) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - time::subtract_duration::(lhs, rhs).map(|x| Box::new(x) as Box) - } - (Time64(TimeUnit::Microsecond), Add, Duration(_)) - | (Time64(TimeUnit::Nanosecond), Add, Duration(_)) - | (Date64, Add, Duration(_)) - | (Timestamp(_, _), Add, Duration(_)) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - time::add_duration::(lhs, rhs).map(|x| Box::new(x) as Box) - } - (Timestamp(_, _), Add, Interval(IntervalUnit::MonthDayNano)) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - time::add_interval(lhs, rhs).map(|x| Box::new(x) as Box) - } - (Time64(TimeUnit::Microsecond), Subtract, Duration(_)) - | (Time64(TimeUnit::Nanosecond), Subtract, Duration(_)) - | (Date64, Subtract, Duration(_)) - | (Timestamp(_, _), Subtract, Duration(_)) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - time::subtract_duration::(lhs, rhs).map(|x| Box::new(x) as Box) +// Macro to create a `match` statement with dynamic dispatch to functions based on +// the array's logical types +macro_rules! arith { + ($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{ + let lhs = $lhs; + let rhs = $rhs; + use DataType::*; + match (lhs.data_type(), rhs.data_type()) { + (Int8, Int8) => primitive!(lhs, rhs, $op, i8), + (Int16, Int16) => primitive!(lhs, rhs, $op, i16), + (Int32, Int32) => primitive!(lhs, rhs, $op, i32), + (Int64, Int64) | (Duration(_), Duration(_)) => { + primitive!(lhs, rhs, $op, i64) + } + (UInt8, UInt8) => primitive!(lhs, rhs, $op, u8), + (UInt16, UInt16) => primitive!(lhs, rhs, $op, u16), + (UInt32, UInt32) => primitive!(lhs, rhs, $op, u32), + (UInt64, UInt64) => primitive!(lhs, rhs, $op, u64), + (Float32, Float32) => primitive!(lhs, rhs, $op, f32), + (Float64, Float64) => primitive!(lhs, rhs, $op, f64), + $ ( + (Decimal(_, _), Decimal(_, _)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(decimal::$op_decimal(lhs, rhs)) as Box + } + )? + $ ( + (Time32(TimeUnit::Second), Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Date32, Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Date64, Duration(_)) + | (Timestamp(_, _), Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + )? + $ ( + (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + $ ( + (Timestamp(_, None), Timestamp(_, None)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + _ => todo!( + "Addition of {:?} with {:?} is not supported", + lhs.data_type(), + rhs.data_type() + ), } - (Timestamp(_, None), Subtract, Timestamp(_, None)) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - time::subtract_timestamps(lhs, rhs).map(|x| Box::new(x) as Box) - } - (lhs, op, rhs) => Err(ArrowError::NotYetImplemented(format!( - "Arithmetics of ({:?}, {:?}, {:?}) is not supported", - lhs, op, rhs - ))), - } + }}; +} + +/// Adds two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the opertion is not supported for the logical types (use [`can_add`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!( + lhs, + rhs, + add, + duration = add_duration, + interval = add_interval + ) } -/// Checks if an array of type `datatype` can perform basic arithmetic -/// operations. These operations include add, subtract, multiply, divide. -/// -/// # Examples -/// ``` -/// use arrow2::compute::arithmetics::{can_arithmetic, Operator}; -/// use arrow2::datatypes::DataType; -/// -/// let data_type = DataType::Int8; -/// assert_eq!(can_arithmetic(&data_type, Operator::Add, &data_type), true); -/// -/// let data_type = DataType::LargeBinary; -/// assert_eq!(can_arithmetic(&data_type, Operator::Add, &data_type), false) -/// ``` -pub fn can_arithmetic(lhs: &DataType, op: Operator, rhs: &DataType) -> bool { +/// Returns whether two [`DataType`]s can be added by [`add`]. +pub fn can_add(lhs: &DataType, rhs: &DataType) -> bool { use DataType::*; - use Operator::*; - if let (Decimal(_, _), Remainder, Decimal(_, _)) = (lhs, op, rhs) { - return false; - }; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Duration(_), Duration(_)) + | (Decimal(_, _), Decimal(_, _)) + | (Date32, Duration(_)) + | (Date64, Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Time32(TimeUnit::Second), Duration(_)) + | (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Timestamp(_, _), Duration(_)) + | (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) + ) +} + +/// Subtracts two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the opertion is not supported for the logical types (use [`can_sub`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!( + lhs, + rhs, + sub, + decimal = sub, + duration = subtract_duration, + timestamp = subtract_timestamps + ) +} +/// Returns whether two [`DataType`]s can be subtracted by [`sub`]. +pub fn can_sub(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; matches!( - (lhs, op, rhs), - (Int8, _, Int8) - | (Int16, _, Int16) - | (Int32, _, Int32) - | (Int64, _, Int64) - | (UInt8, _, UInt8) - | (UInt16, _, UInt16) - | (UInt32, _, UInt32) - | (UInt64, _, UInt64) - | (Float64, _, Float64) - | (Float32, _, Float32) - | (Duration(_), _, Duration(_)) - | (Decimal(_, _), _, Decimal(_, _)) - | (Date32, Subtract, Duration(_)) - | (Date32, Add, Duration(_)) - | (Date64, Subtract, Duration(_)) - | (Date64, Add, Duration(_)) - | (Time32(TimeUnit::Millisecond), Subtract, Duration(_)) - | (Time32(TimeUnit::Second), Subtract, Duration(_)) - | (Time32(TimeUnit::Millisecond), Add, Duration(_)) - | (Time32(TimeUnit::Second), Add, Duration(_)) - | (Time64(TimeUnit::Microsecond), Subtract, Duration(_)) - | (Time64(TimeUnit::Nanosecond), Subtract, Duration(_)) - | (Time64(TimeUnit::Microsecond), Add, Duration(_)) - | (Time64(TimeUnit::Nanosecond), Add, Duration(_)) - | (Timestamp(_, _), Subtract, Duration(_)) - | (Timestamp(_, _), Add, Duration(_)) - | (Timestamp(_, _), Add, Interval(IntervalUnit::MonthDayNano)) - | (Timestamp(_, None), Subtract, Timestamp(_, None)) + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Duration(_), Duration(_)) + | (Decimal(_, _), Decimal(_, _)) + | (Date32, Duration(_)) + | (Date64, Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Time32(TimeUnit::Second), Duration(_)) + | (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Timestamp(_, _), Duration(_)) + | (Timestamp(_, None), Timestamp(_, None)) ) } -/// Arithmetic operator -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum Operator { - /// Add - Add, - /// Subtract - Subtract, - /// Multiply - Multiply, - /// Divide - Divide, - /// Remainder - Remainder, +/// Multiply two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the opertion is not supported for the logical types (use [`can_mul`] to check) +/// * the arrays have a different length +pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, mul, decimal = mul) } -/// Perform arithmetic operations on two primitive arrays based on the Operator enum -// -pub fn arithmetic_primitive( - lhs: &PrimitiveArray, - op: Operator, - rhs: &PrimitiveArray, -) -> PrimitiveArray -where - T: NativeType - + Div - + Zero - + Add - + Sub - + Mul - + Rem, -{ - match op { - Operator::Add => basic::add(lhs, rhs), - Operator::Subtract => basic::sub(lhs, rhs), - Operator::Multiply => basic::mul(lhs, rhs), - Operator::Divide => basic::div(lhs, rhs), - Operator::Remainder => basic::rem(lhs, rhs), - } +/// Returns whether two [`DataType`]s can be multiplied by [`mul`]. +pub fn can_mul(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Decimal(_, _), Decimal(_, _)) + ) } -/// Performs primitive operation on an array and and scalar -pub fn arithmetic_primitive_scalar( - lhs: &PrimitiveArray, - op: Operator, - rhs: &T, -) -> Result> -where - T: NativeType - + Div - + Zero - + Add - + Sub - + Mul - + Rem - + NumCast, -{ - match op { - Operator::Add => Ok(basic::add_scalar(lhs, rhs)), - Operator::Subtract => Ok(basic::sub_scalar(lhs, rhs)), - Operator::Multiply => Ok(basic::mul_scalar(lhs, rhs)), - Operator::Divide => Ok(basic::div_scalar(lhs, rhs)), - Operator::Remainder => Ok(basic::rem_scalar(lhs, rhs)), - } +/// Divide of two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the opertion is not supported for the logical types (use [`can_div`] to check) +/// * the arrays have a different length +pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, div, decimal = div) } -/// Negates values from array. -/// -/// # Examples -/// ``` -/// use arrow2::compute::arithmetics::negate; -/// use arrow2::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([None, Some(6), None, Some(7)]); -/// let result = negate(&a); -/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); -/// assert_eq!(result, expected) -/// ``` -pub fn negate(array: &PrimitiveArray) -> PrimitiveArray -where - T: NativeType + Neg, -{ - unary(array, |a| -a, array.data_type().clone()) +/// Returns whether two [`DataType`]s can be divided by [`div`]. +pub fn can_div(lhs: &DataType, rhs: &DataType) -> bool { + can_mul(lhs, rhs) } -/// Checked negates values from array. -/// -/// # Examples -/// ``` -/// use arrow2::compute::arithmetics::checked_negate; -/// use arrow2::array::{Array, PrimitiveArray}; -/// -/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); -/// let result = checked_negate(&a); -/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); -/// assert_eq!(result, expected); -/// assert!(!result.is_valid(2)) -/// ``` -pub fn checked_negate(array: &PrimitiveArray) -> PrimitiveArray -where - T: NativeType + CheckedNeg, -{ - unary_checked(array, |a| a.checked_neg(), array.data_type().clone()) +/// Remainder of two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the opertion is not supported for the logical types (use [`can_rem`] to check) +/// * the arrays have a different length +pub fn rem(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, rem) } -/// Wrapping negates values from array. -/// -/// # Examples -/// ``` -/// use arrow2::compute::arithmetics::wrapping_negate; -/// use arrow2::array::{Array, PrimitiveArray}; -/// -/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); -/// let result = wrapping_negate(&a); -/// let expected = PrimitiveArray::from([None, Some(-6), Some(i8::MIN), Some(-7)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_negate(array: &PrimitiveArray) -> PrimitiveArray -where - T: NativeType + WrappingNeg, -{ - unary(array, |a| a.wrapping_neg(), array.data_type().clone()) +/// Returns whether two [`DataType`]s "can be remainder" by [`rem`]. +pub fn can_rem(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + ) } /// Defines basic addition operation for primitive arrays diff --git a/src/compute/arithmetics/time.rs b/src/compute/arithmetics/time.rs index 1d468853427..d622652823e 100644 --- a/src/compute/arithmetics/time.rs +++ b/src/compute/arithmetics/time.rs @@ -103,18 +103,18 @@ fn create_scale(lhs: &DataType, rhs: &DataType) -> Result { pub fn add_duration( time: &PrimitiveArray, duration: &PrimitiveArray, -) -> Result> +) -> PrimitiveArray where f64: AsPrimitive, T: NativeType + Add, { - let scale = create_scale(time.data_type(), duration.data_type())?; + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); // Closure for the binary operation. The closure contains the scale // required to add a duration to the timestamp array. let op = move |a: T, b: i64| a + (b as f64 * scale).as_(); - Ok(binary(time, duration, time.data_type().clone(), op)) + binary(time, duration, time.data_type().clone(), op) } /// Subtract a duration to a time array (Timestamp, Time and Date). The timeunit @@ -159,18 +159,18 @@ where pub fn subtract_duration( time: &PrimitiveArray, duration: &PrimitiveArray, -) -> Result> +) -> PrimitiveArray where f64: AsPrimitive, T: NativeType + Sub, { - let scale = create_scale(time.data_type(), duration.data_type())?; + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); // Closure for the binary operation. The closure contains the scale // required to add a duration to the timestamp array. let op = move |a: T, b: i64| a - (b as f64 * scale).as_(); - Ok(binary(time, duration, time.data_type().clone(), op)) + binary(time, duration, time.data_type().clone(), op) } /// Calculates the difference between two timestamps returning an array of type