diff --git a/src/compute/arithmetics/basic/add.rs b/src/compute/arithmetics/basic/add.rs index 48ce1d01eee..152b903ee9a 100644 --- a/src/compute/arithmetics/basic/add.rs +++ b/src/compute/arithmetics/basic/add.rs @@ -3,6 +3,7 @@ use std::ops::Add; use num_traits::{ops::overflowing::OverflowingAdd, CheckedAdd, SaturatingAdd, Zero}; +use crate::compute::arithmetics::basic::check_same_type; use crate::{ array::{Array, PrimitiveArray}, bitmap::Bitmap, @@ -14,7 +15,7 @@ use crate::{ binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, }, }, - error::{ArrowError, Result}, + error::Result, types::NativeType, }; @@ -36,11 +37,7 @@ pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b) } @@ -63,11 +60,7 @@ pub fn checked_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Resul where T: NativeType + CheckedAdd + Zero, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.checked_add(&b); @@ -96,11 +89,7 @@ pub fn saturating_add( where T: NativeType + SaturatingAdd, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.saturating_add(&b); @@ -130,11 +119,7 @@ pub fn overflowing_add( where T: NativeType + OverflowingAdd, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.overflowing_add(&b); diff --git a/src/compute/arithmetics/basic/common.rs b/src/compute/arithmetics/basic/common.rs new file mode 100644 index 00000000000..ed99efe7fbd --- /dev/null +++ b/src/compute/arithmetics/basic/common.rs @@ -0,0 +1,31 @@ +use crate::array::{Array, PrimitiveArray}; +use crate::error::{ArrowError, Result}; +use crate::types::NativeType; + +// Checking if both arrays have the same type +#[inline] +pub fn check_same_type( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result<()> { + if lhs.data_type() != rhs.data_type() { + return Err(ArrowError::InvalidArgumentError( + "Arrays must have the same logical type".to_string(), + )); + } + Ok(()) +} + +// Checking if both arrays have the same length +#[inline] +pub fn check_same_len( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result<()> { + if lhs.len() != rhs.len() { + return Err(ArrowError::InvalidArgumentError( + "Arrays must have the same length".to_string(), + )); + } + Ok(()) +} diff --git a/src/compute/arithmetics/basic/div.rs b/src/compute/arithmetics/basic/div.rs index 481053dd234..4b891d4d707 100644 --- a/src/compute/arithmetics/basic/div.rs +++ b/src/compute/arithmetics/basic/div.rs @@ -3,6 +3,7 @@ use std::ops::Div; use num_traits::{CheckedDiv, NumCast, Zero}; +use crate::compute::arithmetics::basic::{check_same_len, check_same_type}; use crate::datatypes::DataType; use crate::{ array::{Array, PrimitiveArray}, @@ -10,7 +11,7 @@ use crate::{ arithmetics::{ArrayCheckedDiv, ArrayDiv, NotI128}, arity::{binary, binary_checked, unary, unary_checked}, }, - error::{ArrowError, Result}, + error::Result, types::NativeType, }; use strength_reduce::{ @@ -35,20 +36,12 @@ pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; if rhs.null_count() == 0 { binary(lhs, rhs, lhs.data_type().clone(), |a, b| a / b) } else { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) { (Some(l), Some(r)) => Some(*l / *r), _ => None, @@ -77,11 +70,7 @@ pub fn checked_div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Resul where T: NativeType + CheckedDiv + Zero, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.checked_div(&b); diff --git a/src/compute/arithmetics/basic/mod.rs b/src/compute/arithmetics/basic/mod.rs index 95d4ff4e0e1..6522338da6c 100644 --- a/src/compute/arithmetics/basic/mod.rs +++ b/src/compute/arithmetics/basic/mod.rs @@ -17,3 +17,6 @@ mod rem; pub use rem::*; mod sub; pub use sub::*; + +mod common; +pub(crate) use common::*; diff --git a/src/compute/arithmetics/basic/mul.rs b/src/compute/arithmetics/basic/mul.rs index 2cb22edf60e..9b6e27ee1d4 100644 --- a/src/compute/arithmetics/basic/mul.rs +++ b/src/compute/arithmetics/basic/mul.rs @@ -3,6 +3,7 @@ use std::ops::Mul; use num_traits::{ops::overflowing::OverflowingMul, CheckedMul, SaturatingMul, Zero}; +use crate::compute::arithmetics::basic::check_same_type; use crate::{ array::{Array, PrimitiveArray}, bitmap::Bitmap, @@ -14,7 +15,7 @@ use crate::{ binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, }, }, - error::{ArrowError, Result}, + error::Result, types::NativeType, }; @@ -36,11 +37,7 @@ pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b) } @@ -64,11 +61,7 @@ pub fn checked_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Resul where T: NativeType + CheckedMul + Zero, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.checked_mul(&b); @@ -97,11 +90,7 @@ pub fn saturating_mul( where T: NativeType + SaturatingMul, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.saturating_mul(&b); @@ -131,11 +120,7 @@ pub fn overflowing_mul( where T: NativeType + OverflowingMul, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.overflowing_mul(&b); diff --git a/src/compute/arithmetics/basic/rem.rs b/src/compute/arithmetics/basic/rem.rs index d2601c44992..006fc5ab800 100644 --- a/src/compute/arithmetics/basic/rem.rs +++ b/src/compute/arithmetics/basic/rem.rs @@ -2,6 +2,7 @@ use std::ops::Rem; use num_traits::{CheckedRem, NumCast, Zero}; +use crate::compute::arithmetics::basic::check_same_type; use crate::datatypes::DataType; use crate::{ array::{Array, PrimitiveArray}, @@ -9,7 +10,7 @@ use crate::{ arithmetics::{ArrayCheckedRem, ArrayRem, NotI128}, arity::{binary, binary_checked, unary, unary_checked}, }, - error::{ArrowError, Result}, + error::Result, types::NativeType, }; use strength_reduce::{ @@ -34,11 +35,7 @@ pub fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b) } @@ -62,11 +59,7 @@ pub fn checked_rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Resul where T: NativeType + CheckedRem + Zero, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.checked_rem(&b); diff --git a/src/compute/arithmetics/basic/sub.rs b/src/compute/arithmetics/basic/sub.rs index 5f0cf7d8521..d69f2b04f4c 100644 --- a/src/compute/arithmetics/basic/sub.rs +++ b/src/compute/arithmetics/basic/sub.rs @@ -3,6 +3,7 @@ use std::ops::Sub; use num_traits::{ops::overflowing::OverflowingSub, CheckedSub, SaturatingSub, Zero}; +use crate::compute::arithmetics::basic::check_same_type; use crate::{ array::{Array, PrimitiveArray}, bitmap::Bitmap, @@ -14,7 +15,7 @@ use crate::{ binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, }, }, - error::{ArrowError, Result}, + error::Result, types::NativeType, }; @@ -36,11 +37,7 @@ pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b) } @@ -63,11 +60,7 @@ pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Resul where T: NativeType + CheckedSub + Zero, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.checked_sub(&b); @@ -96,11 +89,7 @@ pub fn saturating_sub( where T: NativeType + SaturatingSub, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.saturating_sub(&b); @@ -130,11 +119,7 @@ pub fn overflowing_sub( where T: NativeType + OverflowingSub, { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let op = move |a: T, b: T| a.overflowing_sub(&b); diff --git a/src/compute/arithmetics/decimal/add.rs b/src/compute/arithmetics/decimal/add.rs index 6d8d2e6e429..76a17a09afa 100644 --- a/src/compute/arithmetics/decimal/add.rs +++ b/src/compute/arithmetics/decimal/add.rs @@ -16,6 +16,7 @@ // under the License. //! Defines the addition arithmetic kernels for Decimal `PrimitiveArrays`. +use crate::compute::arithmetics::basic::check_same_len; use crate::{ array::{Array, PrimitiveArray}, buffer::Buffer, @@ -253,12 +254,7 @@ pub fn adaptive_add( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Checking if both arrays have the same length - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) diff --git a/src/compute/arithmetics/decimal/div.rs b/src/compute/arithmetics/decimal/div.rs index b83fba11cfa..1a88881bace 100644 --- a/src/compute/arithmetics/decimal/div.rs +++ b/src/compute/arithmetics/decimal/div.rs @@ -18,6 +18,7 @@ //! Defines the division arithmetic kernels for Decimal //! `PrimitiveArrays`. +use crate::compute::arithmetics::basic::check_same_len; use crate::{ array::{Array, PrimitiveArray}, buffer::Buffer, @@ -272,12 +273,7 @@ pub fn adaptive_div( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Checking if both arrays have the same length - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) diff --git a/src/compute/arithmetics/decimal/mul.rs b/src/compute/arithmetics/decimal/mul.rs index 008f608db02..8ea10b49f69 100644 --- a/src/compute/arithmetics/decimal/mul.rs +++ b/src/compute/arithmetics/decimal/mul.rs @@ -18,6 +18,7 @@ //! Defines the multiplication arithmetic kernels for Decimal //! `PrimitiveArrays`. +use crate::compute::arithmetics::basic::check_same_len; use crate::{ array::{Array, PrimitiveArray}, buffer::Buffer, @@ -277,12 +278,7 @@ pub fn adaptive_mul( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Checking if both arrays have the same length - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) diff --git a/src/compute/arithmetics/decimal/sub.rs b/src/compute/arithmetics/decimal/sub.rs index 130193bbb36..8a692c7339d 100644 --- a/src/compute/arithmetics/decimal/sub.rs +++ b/src/compute/arithmetics/decimal/sub.rs @@ -17,6 +17,7 @@ //! Defines the subtract arithmetic kernels for Decimal `PrimitiveArrays`. +use crate::compute::arithmetics::basic::check_same_len; use crate::{ array::{Array, PrimitiveArray}, buffer::Buffer, @@ -251,12 +252,7 @@ pub fn adaptive_sub( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Checking if both arrays have the same length - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) diff --git a/src/compute/arity.rs b/src/compute/arity.rs index fc73eb9143d..733f85aceb1 100644 --- a/src/compute/arity.rs +++ b/src/compute/arity.rs @@ -1,12 +1,13 @@ //! Defines kernels suitable to perform operations to primitive arrays. use super::utils::combine_validities; +use crate::compute::arithmetics::basic::check_same_len; use crate::{ array::{Array, PrimitiveArray}, bitmap::{Bitmap, MutableBitmap}, buffer::Buffer, datatypes::DataType, - error::{ArrowError, Result}, + error::Result, types::NativeType, }; @@ -145,11 +146,7 @@ where D: NativeType, F: Fn(T, D) -> T, { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; let validity = combine_validities(lhs.validity(), rhs.validity()); @@ -176,11 +173,7 @@ where D: NativeType, F: Fn(T, D) -> Result, { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; let validity = combine_validities(lhs.validity(), rhs.validity()); @@ -208,11 +201,7 @@ where D: NativeType, F: Fn(T, D) -> (T, bool), { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; let validity = combine_validities(lhs.validity(), rhs.validity()); @@ -246,11 +235,7 @@ where D: NativeType, F: Fn(T, D) -> Option, { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same length".to_string(), - )); - } + check_same_len(lhs, rhs)?; let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); diff --git a/src/compute/nullif.rs b/src/compute/nullif.rs index ec0ebfad896..20f0cc9eea1 100644 --- a/src/compute/nullif.rs +++ b/src/compute/nullif.rs @@ -1,5 +1,6 @@ use crate::array::PrimitiveArray; use crate::bitmap::Bitmap; +use crate::compute::arithmetics::basic::check_same_type; use crate::compute::comparison::{primitive_compare_values_op, Simd8, Simd8Lanes}; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; @@ -34,11 +35,7 @@ pub fn nullif_primitive( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - if lhs.data_type() != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Arrays must have the same logical type".to_string(), - )); - } + check_same_type(lhs, rhs)?; let equal = primitive_compare_values_op(lhs.values(), rhs.values(), |lhs, rhs| lhs.neq(rhs)); let equal: Option = equal.into();