diff --git a/src/compute/arithmetics/decimal/add.rs b/src/compute/arithmetics/decimal/add.rs index 958b737076a..078316f7cf9 100644 --- a/src/compute/arithmetics/decimal/add.rs +++ b/src/compute/arithmetics/decimal/add.rs @@ -1,21 +1,4 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines the addition arithmetic kernels for Decimal `PrimitiveArrays`. +//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals. use crate::compute::arithmetics::basic::check_same_len; use crate::{ array::{Array, PrimitiveArray}, @@ -31,12 +14,14 @@ use crate::{ error::{ArrowError, Result}, }; -use super::{adjusted_precision_scale, max_value, number_digits}; +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; -/// Adds two decimal primitive arrays with the same precision and scale. If the -/// precision and scale is different, then an InvalidArgumentError is returned. -/// This function panics if the added numbers result in a number larger than -/// the possible number for the selected precision. +/// Adds two decimal [`PrimitiveArray`] with the same precision and scale. +/// # Error +/// Errors if the precision and scale are different. +/// # Panic +/// This function panics iff the added numbers result in a number larger than +/// the possible number for the precision. /// /// # Examples /// ``` @@ -53,38 +38,22 @@ use super::{adjusted_precision_scale, max_value, number_digits}; /// assert_eq!(result, expected); /// ``` pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result> { - // Matching on both data types from both arrays - // This match will be true only when precision and scale from both - // arrays are the same, otherwise it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. This closure will panic if - // the sum of the values is larger than the max value possible - // for the decimal precision - let op = move |a, b| { - let res: i128 = a + b; + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type())?; - assert!( - !(res.abs() > max_value(*lhs_p)), - "Overflow in addition presented for precision {}", - lhs_p - ); + let max = max_value(precision); + let op = move |a, b| { + let res: i128 = a + b; - res - }; + assert!( + res.abs() <= max, + "Overflow in addition presented for precision {}", + precision + ); - binary(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } - } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } /// Saturated addition of two decimal primitive arrays with the same precision @@ -111,40 +80,24 @@ pub fn saturating_add( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a, b| { - let res: i128 = a + b; - let max = max_value(*lhs_p); + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type())?; - match res { - res if res.abs() > max => { - if res > 0 { - max - } else { - -max - } - } - _ => res, - } - }; + let max = max_value(precision); + let op = move |a, b| { + let res: i128 = a + b; - binary(lhs, rhs, lhs.data_type().clone(), op) + if res.abs() > max { + if res > 0 { + max } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) + -max } + } else { + res } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } /// Checked addition of two decimal primitive arrays with the same precision @@ -171,33 +124,20 @@ pub fn checked_add( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a, b| { - let res: i128 = a + b; + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type())?; - match res { - res if res.abs() > max_value(*lhs_p) => None, - _ => Some(res), - } - }; + let max = max_value(precision); + let op = move |a, b| { + let result: i128 = a + b; - binary_checked(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } + if result.abs() > max { + None + } else { + Some(result) } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) } // Implementation of ArrayAdd trait for PrimitiveArrays @@ -259,14 +199,16 @@ pub fn adaptive_add( // looping through the iterator let (mut res_p, res_s, diff) = adjusted_precision_scale(*lhs_p, *lhs_s, *rhs_p, *rhs_s); - let mut result = Vec::new(); - for (l, r) in lhs.values().iter().zip(rhs.values().iter()) { + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let iter = lhs.values().iter().zip(rhs.values().iter()).map(|(l, r)| { // Based on the array's scales one of the arguments in the sum has to be shifted // to the left to match the final scale let res = if lhs_s > rhs_s { - l + r * 10i128.pow(diff as u32) + l + r * shift } else { - l * 10i128.pow(diff as u32) + r + l * shift + r }; // The precision of the resulting array will change if one of the @@ -277,15 +219,15 @@ pub fn adaptive_add( // 00.0001 -> 6, 4 // ----------------- // 100.0000 -> 7, 4 - if res.abs() > max_value(res_p) { + if res.abs() > max { res_p = number_digits(res); + max = max_value(res_p); } - - result.push(res); - } + res + }); + let values = Buffer::from_trusted_len_iter(iter); let validity = combine_validities(lhs.validity(), rhs.validity()); - let values = Buffer::from(result); Ok(PrimitiveArray::::from_data( DataType::Decimal(res_p, res_s), diff --git a/src/compute/arithmetics/decimal/div.rs b/src/compute/arithmetics/decimal/div.rs index 961fb1988e0..515aced7977 100644 --- a/src/compute/arithmetics/decimal/div.rs +++ b/src/compute/arithmetics/decimal/div.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - //! Defines the division arithmetic kernels for Decimal //! `PrimitiveArrays`. @@ -27,13 +10,11 @@ use crate::{ arity::{binary, binary_checked}, utils::combine_validities, }, -}; -use crate::{ datatypes::DataType, error::{ArrowError, Result}, }; -use super::{adjusted_precision_scale, max_value, number_digits}; +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// Divide two decimal primitive arrays with the same precision and scale. If /// the precision and scale is different, then an InvalidArgumentError is @@ -56,50 +37,35 @@ use super::{adjusted_precision_scale, max_value, number_digits}; /// assert_eq!(result, expected); /// ``` pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result> { - // Matching on both data types from both arrays - // This match will be true only when precision and scale from both - // arrays are the same, otherwise it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. This closure will panic if - // the sum of the values is larger than the max value possible - // for the decimal precision - let op = move |a: i128, b: i128| { - // The division is done using the numbers without scale. - // The dividend is scaled up to maintain precision after the - // division + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; - // 222.222 --> 222222000 - // 123.456 --> 123456 - // -------- --------- - // 1.800 <-- 1800 - let numeral: i128 = a * 10i128.pow(*lhs_s as u32); + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + let op = move |a: i128, b: i128| { + // The division is done using the numbers without scale. + // The dividend is scaled up to maintain precision after the + // division - // The division can overflow if the dividend is divided - // by zero. - let res: i128 = numeral.checked_div(b).expect("Found division by zero"); + // 222.222 --> 222222000 + // 123.456 --> 123456 + // -------- --------- + // 1.800 <-- 1800 + let numeral: i128 = a * scale; - assert!( - !(res.abs() > max_value(*lhs_p)), - "Overflow in multiplication presented for precision {}", - lhs_p - ); + // The division can overflow if the dividend is divided + // by zero. + let res: i128 = numeral.checked_div(b).expect("Found division by zero"); - res - }; + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {}", + precision + ); - binary(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } - } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } /// Saturated division of two decimal primitive arrays with the same @@ -127,46 +93,30 @@ pub fn saturating_div( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a: i128, b: i128| { - let numeral: i128 = a * 10i128.pow(*lhs_s as u32); + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; - match numeral.checked_div(b) { - Some(res) => { - let max = max_value(*lhs_p); + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); - match res { - res if res.abs() > max => { - if res > 0 { - max - } else { - -max - } - } - _ => res, - } - } - None => 0, - } - }; + let op = move |a: i128, b: i128| { + let numeral: i128 = a * scale; - binary(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } + match numeral.checked_div(b) { + Some(res) => match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + } + _ => res, + }, + None => 0, } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } /// Checked division of two decimal primitive arrays with the same precision @@ -192,36 +142,24 @@ pub fn checked_div( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a: i128, b: i128| { - let numeral: i128 = a * 10i128.pow(*lhs_s as u32); + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; - match numeral.checked_div(b) { - Some(res) => match res { - res if res.abs() > max_value(*lhs_p) => None, - _ => Some(res), - }, - None => None, - } - }; + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); - binary_checked(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } + let op = move |a: i128, b: i128| { + let numeral: i128 = a * scale; + + match numeral.checked_div(b) { + Some(res) => match res { + res if res.abs() > max => None, + _ => Some(res), + }, + None => None, } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) } // Implementation of ArrayDiv trait for PrimitiveArrays @@ -277,21 +215,21 @@ pub fn adaptive_div( // looping through the iterator let (mut res_p, res_s, diff) = adjusted_precision_scale(*lhs_p, *lhs_s, *rhs_p, *rhs_s); - let mut result = Vec::new(); - for (l, r) in lhs.values().iter().zip(rhs.values().iter()) { - let numeral: i128 = l * 10i128.pow(res_s as u32); + let shift = 10i128.pow(diff as u32); + let shift_1 = 10i128.pow(res_s as u32); + let mut max = max_value(res_p); + + let iter = lhs.values().iter().zip(rhs.values().iter()).map(|(l, r)| { + let numeral: i128 = l * shift_1; // Based on the array's scales one of the arguments in the sum has to be shifted // to the left to match the final scale let res = if lhs_s > rhs_s { - numeral - .checked_div(r * 10i128.pow(diff as u32)) - .expect("Found division by zero") + numeral.checked_div(r * shift) } else { - (numeral * 10i128.pow(diff as u32)) - .checked_div(*r) - .expect("Found division by zero") - }; + (numeral * shift).checked_div(*r) + } + .expect("Found division by zero"); // The precision of the resulting array will change if one of the // multiplications during the iteration produces a value bigger @@ -301,15 +239,16 @@ pub fn adaptive_div( // 00.1000 -> 6, 4 // ----------------- // 100.0000 -> 7, 4 - if res.abs() > max_value(res_p) { + if res.abs() > max { res_p = number_digits(res); + max = max_value(res_p); } - result.push(res); - } + res + }); + let values = Buffer::from_trusted_len_iter(iter); let validity = combine_validities(lhs.validity(), rhs.validity()); - let values = Buffer::from(result); Ok(PrimitiveArray::::from_data( DataType::Decimal(res_p, res_s), diff --git a/src/compute/arithmetics/decimal/mod.rs b/src/compute/arithmetics/decimal/mod.rs index d124fc9a104..c3b69468756 100644 --- a/src/compute/arithmetics/decimal/mod.rs +++ b/src/compute/arithmetics/decimal/mod.rs @@ -8,6 +8,9 @@ pub mod div; pub mod mul; pub mod sub; +use crate::datatypes::DataType; +use crate::error::{ArrowError, Result}; + /// Maximum value that can exist with a selected precision #[inline] fn max_value(precision: usize) -> i128 { @@ -28,6 +31,22 @@ fn number_digits(num: i128) -> usize { digit as usize } +fn get_parameters(lhs: &DataType, rhs: &DataType) -> Result<(usize, usize)> { + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.to_logical_type(), rhs.to_logical_type()) + { + if lhs_p == rhs_p && lhs_s == rhs_s { + Ok((*lhs_p, *lhs_s)) + } else { + Err(ArrowError::InvalidArgumentError( + "Arrays must have the same precision and scale".to_string(), + )) + } + } else { + unreachable!() + } +} + /// Returns the adjusted precision and scale for the lhs and rhs precision and /// scale fn adjusted_precision_scale( diff --git a/src/compute/arithmetics/decimal/mul.rs b/src/compute/arithmetics/decimal/mul.rs index cf4f6ac0c8f..eab858099d2 100644 --- a/src/compute/arithmetics/decimal/mul.rs +++ b/src/compute/arithmetics/decimal/mul.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - //! Defines the multiplication arithmetic kernels for Decimal //! `PrimitiveArrays`. @@ -31,7 +14,7 @@ use crate::{ error::{ArrowError, Result}, }; -use super::{adjusted_precision_scale, max_value, number_digits}; +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// Multiply two decimal primitive arrays with the same precision and scale. If /// the precision and scale is different, then an InvalidArgumentError is @@ -53,51 +36,37 @@ use super::{adjusted_precision_scale, max_value, number_digits}; /// assert_eq!(result, expected); /// ``` pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result> { - // Matching on both data types from both arrays - // This match will be true only when precision and scale from both - // arrays are the same, otherwise it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. This closure will panic if - // the sum of the values is larger than the max value possible - // for the decimal precision - let op = move |a: i128, b: i128| { - // The multiplication between i128 can overflow if they are - // very large numbers. For that reason a checked - // multiplication is used. - let res: i128 = a.checked_mul(b).expect("Mayor overflow for multiplication"); - - // The multiplication is done using the numbers without scale. - // The resulting scale of the value has to be corrected by - // dividing by (10^scale) - - // 111.111 --> 111111 - // 222.222 --> 222222 - // -------- ------- - // 24691.308 <-- 24691308642 - let res = res / 10i128.pow(*lhs_s as u32); - - assert!( - !(res.abs() > max_value(*lhs_p)), - "Overflow in multiplication presented for precision {}", - lhs_p - ); - - res - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } - } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| { + // The multiplication between i128 can overflow if they are + // very large numbers. For that reason a checked + // multiplication is used. + let res: i128 = a.checked_mul(b).expect("Mayor overflow for multiplication"); + + // The multiplication is done using the numbers without scale. + // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + let res = res / scale; + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {}", + precision + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } /// Saturated multiplication of two decimal primitive arrays with the same @@ -125,43 +94,30 @@ pub fn saturating_mul( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a: i128, b: i128| match a.checked_mul(b) { - Some(res) => { - let res = res / 10i128.pow(*lhs_s as u32); - let max = max_value(*lhs_p); - - match res { - res if res.abs() > max => { - if res > 0 { - max - } else { - -max - } - } - _ => res, - } - } - None => max_value(*lhs_p), - }; + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; - binary(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + } + _ => res, } } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + None => max, + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } /// Checked multiplication of two decimal primitive arrays with the same @@ -188,36 +144,24 @@ pub fn checked_mul( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a: i128, b: i128| match a.checked_mul(b) { - Some(res) => { - let res = res / 10i128.pow(*lhs_s as u32); - - match res { - res if res.abs() > max_value(*lhs_p) => None, - _ => Some(res), - } - } - None => None, - }; + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; - binary_checked(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => None, + _ => Some(res), } } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + None => None, + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) } // Implementation of ArrayMul trait for PrimitiveArrays @@ -280,20 +224,23 @@ pub fn adaptive_mul( // looping through the iterator let (mut res_p, res_s, diff) = adjusted_precision_scale(*lhs_p, *lhs_s, *rhs_p, *rhs_s); - let mut result = Vec::new(); - for (l, r) in lhs.values().iter().zip(rhs.values().iter()) { + let shift = 10i128.pow(diff as u32); + let shift_1 = 10i128.pow(res_s as u32); + let mut max = max_value(res_p); + + let iter = lhs.values().iter().zip(rhs.values().iter()).map(|(l, r)| { // Based on the array's scales one of the arguments in the sum has to be shifted // to the left to match the final scale let res = if lhs_s > rhs_s { - l.checked_mul(r * 10i128.pow(diff as u32)) + l.checked_mul(r * shift) .expect("Mayor overflow for multiplication") } else { - (l * 10i128.pow(diff as u32)) + (l * shift) .checked_mul(*r) .expect("Mayor overflow for multiplication") }; - let res = res / 10i128.pow(res_s as u32); + let res = res / shift_1; // The precision of the resulting array will change if one of the // multiplications during the iteration produces a value bigger @@ -303,15 +250,16 @@ pub fn adaptive_mul( // 10.0000 -> 6, 4 // ----------------- // 100.0000 -> 7, 4 - if res.abs() > max_value(res_p) { + if res.abs() > max { res_p = number_digits(res); + max = max_value(res_p); } - result.push(res); - } + res + }); + let values = Buffer::from_trusted_len_iter(iter); let validity = combine_validities(lhs.validity(), rhs.validity()); - let values = Buffer::from(result); Ok(PrimitiveArray::::from_data( DataType::Decimal(res_p, res_s), diff --git a/src/compute/arithmetics/decimal/sub.rs b/src/compute/arithmetics/decimal/sub.rs index 6fb40a36ac7..a81add11653 100644 --- a/src/compute/arithmetics/decimal/sub.rs +++ b/src/compute/arithmetics/decimal/sub.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - //! Defines the subtract arithmetic kernels for Decimal `PrimitiveArrays`. use crate::compute::arithmetics::basic::check_same_len; @@ -30,7 +13,7 @@ use crate::{ error::{ArrowError, Result}, }; -use super::{adjusted_precision_scale, max_value, number_digits}; +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// Subtract two decimal primitive arrays with the same precision and scale. If /// the precision and scale is different, then an InvalidArgumentError is @@ -52,38 +35,23 @@ use super::{adjusted_precision_scale, max_value, number_digits}; /// assert_eq!(result, expected); /// ``` pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result> { - // Matching on both data types from both arrays This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. This closure will panic if - // the sum of the values is larger than the max value possible - // for the decimal precision - let op = move |a, b| { - let res: i128 = a - b; + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type())?; - assert!( - !(res.abs() > max_value(*lhs_p)), - "Overflow in subtract presented for precision {}", - lhs_p - ); + let max = max_value(precision); - res - }; + let op = move |a, b| { + let res: i128 = a - b; - binary(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } - } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + assert!( + res.abs() <= max, + "Overflow in subtract presented for precision {}", + precision + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } /// Saturated subtraction of two decimal primitive arrays with the same @@ -110,40 +78,26 @@ pub fn saturating_sub( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a, b| { - let res: i128 = a - b; - let max: i128 = max_value(*lhs_p); + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type())?; - match res { - res if res.abs() > max => { - if res > 0 { - max - } else { - -max - } - } - _ => res, - } - }; + let max = max_value(precision); - binary(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } } + _ => res, } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) } // Implementation of ArraySub trait for PrimitiveArrays @@ -190,33 +144,20 @@ pub fn checked_sub( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { - // Matching on both data types from both arrays. This match will be true - // only when precision and scale from both arrays are the same, otherwise - // it will return and ArrowError - match (lhs.data_type(), rhs.data_type()) { - (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) => { - if lhs_p == rhs_p && lhs_s == rhs_s { - // Closure for the binary operation. - let op = move |a, b| { - let res: i128 = a - b; + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type())?; - match res { - res if res.abs() > max_value(*lhs_p) => None, - _ => Some(res), - } - }; + let max = max_value(precision); - binary_checked(lhs, rhs, lhs.data_type().clone(), op) - } else { - Err(ArrowError::InvalidArgumentError( - "Arrays must have the same precision and scale".to_string(), - )) - } + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => None, + _ => Some(res), } - _ => Err(ArrowError::InvalidArgumentError( - "Incorrect data type for the array".to_string(), - )), - } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) } /// Adaptive subtract of two decimal primitive arrays with different precision @@ -257,14 +198,16 @@ pub fn adaptive_sub( // looping through the iterator let (mut res_p, res_s, diff) = adjusted_precision_scale(*lhs_p, *lhs_s, *rhs_p, *rhs_s); - let mut result = Vec::new(); - for (l, r) in lhs.values().iter().zip(rhs.values().iter()) { + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let iter = lhs.values().iter().zip(rhs.values().iter()).map(|(l, r)| { // Based on the array's scales one of the arguments in the sum has to be shifted // to the left to match the final scale let res: i128 = if lhs_s > rhs_s { - l - r * 10i128.pow(diff as u32) + l - r * shift } else { - l * 10i128.pow(diff as u32) - r + l * shift - r }; // The precision of the resulting array will change if one of the @@ -275,15 +218,16 @@ pub fn adaptive_sub( // 00.0001 -> 6, 4 // ----------------- // -100.0000 -> 7, 4 - if res.abs() > max_value(res_p) { + if res.abs() > max { res_p = number_digits(res); + max = max_value(res_p); } - result.push(res); - } + res + }); + let values = Buffer::from_trusted_len_iter(iter); let validity = combine_validities(lhs.validity(), rhs.validity()); - let values = Buffer::from(result); Ok(PrimitiveArray::::from_data( DataType::Decimal(res_p, res_s),