Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
chore: reducing type check and len check code
Browse files Browse the repository at this point in the history
  • Loading branch information
yjhmelody committed Sep 30, 2021
1 parent b78eeb7 commit f4624da
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 140 deletions.
27 changes: 6 additions & 21 deletions src/compute/arithmetics/basic/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Add;

use num_traits::{ops::overflowing::OverflowingAdd, CheckedAdd, SaturatingAdd, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
Expand All @@ -14,7 +15,7 @@ use crate::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};

Expand All @@ -36,11 +37,7 @@ pub fn add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Add<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b)
}
Expand All @@ -63,11 +60,7 @@ pub fn checked_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedAdd<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_add(&b);

Expand Down Expand Up @@ -96,11 +89,7 @@ pub fn saturating_add<T>(
where
T: NativeType + SaturatingAdd<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_add(&b);

Expand Down Expand Up @@ -130,11 +119,7 @@ pub fn overflowing_add<T>(
where
T: NativeType + OverflowingAdd<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_add(&b);

Expand Down
31 changes: 31 additions & 0 deletions src/compute/arithmetics/basic/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::array::{Array, PrimitiveArray};
use crate::error::{ArrowError, Result};
use crate::types::NativeType;

// Checking if both arrays have the same type
#[inline]
pub fn check_same_type<L: NativeType, R: NativeType>(
lhs: &PrimitiveArray<L>,
rhs: &PrimitiveArray<R>,
) -> Result<()> {
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
Ok(())
}

// Checking if both arrays have the same length
#[inline]
pub fn check_same_len<L: NativeType, R: NativeType>(
lhs: &PrimitiveArray<L>,
rhs: &PrimitiveArray<R>,
) -> Result<()> {
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
Ok(())
}
21 changes: 5 additions & 16 deletions src/compute/arithmetics/basic/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use std::ops::Div;

use num_traits::{CheckedDiv, NumCast, Zero};

use crate::compute::arithmetics::basic::{check_same_len, check_same_type};
use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
arithmetics::{ArrayCheckedDiv, ArrayDiv, NotI128},
arity::{binary, binary_checked, unary, unary_checked},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};
use strength_reduce::{
Expand All @@ -35,20 +36,12 @@ pub fn div<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Div<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

if rhs.null_count() == 0 {
binary(lhs, rhs, lhs.data_type().clone(), |a, b| a / b)
} else {
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
check_same_len(lhs, rhs)?;
let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) {
(Some(l), Some(r)) => Some(*l / *r),
_ => None,
Expand Down Expand Up @@ -77,11 +70,7 @@ pub fn checked_div<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedDiv<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_div(&b);

Expand Down
3 changes: 3 additions & 0 deletions src/compute/arithmetics/basic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ mod rem;
pub use rem::*;
mod sub;
pub use sub::*;

mod common;
pub(crate) use common::*;
27 changes: 6 additions & 21 deletions src/compute/arithmetics/basic/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Mul;

use num_traits::{ops::overflowing::OverflowingMul, CheckedMul, SaturatingMul, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
Expand All @@ -14,7 +15,7 @@ use crate::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};

Expand All @@ -36,11 +37,7 @@ pub fn mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Mul<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b)
}
Expand All @@ -64,11 +61,7 @@ pub fn checked_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedMul<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_mul(&b);

Expand Down Expand Up @@ -97,11 +90,7 @@ pub fn saturating_mul<T>(
where
T: NativeType + SaturatingMul<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_mul(&b);

Expand Down Expand Up @@ -131,11 +120,7 @@ pub fn overflowing_mul<T>(
where
T: NativeType + OverflowingMul<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_mul(&b);

Expand Down
15 changes: 4 additions & 11 deletions src/compute/arithmetics/basic/rem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use std::ops::Rem;

use num_traits::{CheckedRem, NumCast, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
arithmetics::{ArrayCheckedRem, ArrayRem, NotI128},
arity::{binary, binary_checked, unary, unary_checked},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};
use strength_reduce::{
Expand All @@ -34,11 +35,7 @@ pub fn rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Rem<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b)
}
Expand All @@ -62,11 +59,7 @@ pub fn checked_rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedRem<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_rem(&b);

Expand Down
27 changes: 6 additions & 21 deletions src/compute/arithmetics/basic/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Sub;

use num_traits::{ops::overflowing::OverflowingSub, CheckedSub, SaturatingSub, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
Expand All @@ -14,7 +15,7 @@ use crate::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};

Expand All @@ -36,11 +37,7 @@ pub fn sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Sub<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b)
}
Expand All @@ -63,11 +60,7 @@ pub fn checked_sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedSub<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_sub(&b);

Expand Down Expand Up @@ -96,11 +89,7 @@ pub fn saturating_sub<T>(
where
T: NativeType + SaturatingSub<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_sub(&b);

Expand Down Expand Up @@ -130,11 +119,7 @@ pub fn overflowing_sub<T>(
where
T: NativeType + OverflowingSub<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_sub(&b);

Expand Down
8 changes: 2 additions & 6 deletions src/compute/arithmetics/decimal/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

//! Defines the addition arithmetic kernels for Decimal `PrimitiveArrays`.
use crate::compute::arithmetics::basic::check_same_len;
use crate::{
array::{Array, PrimitiveArray},
buffer::Buffer,
Expand Down Expand Up @@ -253,12 +254,7 @@ pub fn adaptive_add(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> Result<PrimitiveArray<i128>> {
// Checking if both arrays have the same length
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
check_same_len(lhs, rhs)?;

if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
(lhs.data_type(), rhs.data_type())
Expand Down
8 changes: 2 additions & 6 deletions src/compute/arithmetics/decimal/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Defines the division arithmetic kernels for Decimal
//! `PrimitiveArrays`.
use crate::compute::arithmetics::basic::check_same_len;
use crate::{
array::{Array, PrimitiveArray},
buffer::Buffer,
Expand Down Expand Up @@ -272,12 +273,7 @@ pub fn adaptive_div(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> Result<PrimitiveArray<i128>> {
// Checking if both arrays have the same length
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
check_same_len(lhs, rhs)?;

if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
(lhs.data_type(), rhs.data_type())
Expand Down
Loading

0 comments on commit f4624da

Please sign in to comment.