diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs new file mode 100644 index 000000000000..b9c2c6dc7813 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs @@ -0,0 +1,16 @@ +use super::*; + +pub fn add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + commutative(lhs, rhs, |a, b| a + b) +} + +pub fn add_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, +) -> PolarsResult> { + commutative_scalar(lhs, rhs, rhs_dtype, |a, b| a + b) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs new file mode 100644 index 000000000000..304f9770a0f6 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs @@ -0,0 +1,89 @@ +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; +use polars_error::*; + +use super::{get_parameters, max_value}; +use crate::compute::{binary_mut, unary_mut}; + +pub fn commutative( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + op: F, +) -> PolarsResult> +where + F: Fn(i128, i128) -> i128, +{ + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let mut overflow = false; + let op = |a, b| { + let res = op(a, b); + overflow |= res.abs() > max; + res + }; + let out = binary_mut(lhs, rhs, lhs.data_type().clone(), op); + polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); + Ok(out) +} + +pub fn commutative_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, + op: F, +) -> PolarsResult> +where + F: Fn(i128, i128) -> i128, +{ + let (precision, _) = get_parameters(lhs.data_type(), rhs_dtype).unwrap(); + + let max = max_value(precision); + let mut overflow = false; + let op = |a| { + let res = op(a, rhs); + overflow |= res.abs() > max; + res + }; + let out = unary_mut(lhs, op, lhs.data_type().clone()); + polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); + + Ok(out) +} + +pub fn non_commutative( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + op: F, +) -> PolarsResult> +where + F: Fn(i128, i128) -> i128, +{ + Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) +} + +pub fn non_commutative_scalar( + lhs: &PrimitiveArray, + rhs: i128, + op: F, +) -> PolarsResult> +where + F: Fn(i128, i128) -> i128, +{ + let op = move |a| op(a, rhs); + + Ok(unary_mut(lhs, op, lhs.data_type().clone())) +} + +pub fn non_commutative_scalar_swapped( + lhs: i128, + rhs: &PrimitiveArray, + op: F, +) -> PolarsResult> +where + F: Fn(i128, i128) -> i128, +{ + let op = move |a| op(lhs, a); + + Ok(unary_mut(rhs, op, rhs.data_type().clone())) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/div.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/div.rs new file mode 100644 index 000000000000..94c817eea735 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/div.rs @@ -0,0 +1,43 @@ +use super::*; + +#[inline] +fn decimal_div(a: i128, b: i128, scale: i128) -> i128 { + // The division is done using the numbers without scale. + // The dividend is scaled up to maintain precision after the + // division + + // 222.222 --> 222222000 + // 123.456 --> 123456 + // -------- --------- + // 1.800 <-- 1800 + a * scale / b +} + +pub fn div( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; + let scale = 10i128.pow(scale as u32); + non_commutative(lhs, rhs, |a, b| decimal_div(a, b, scale)) +} + +pub fn div_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?; + let scale = 10i128.pow(scale as u32); + non_commutative_scalar(lhs, rhs, |a, b| decimal_div(a, b, scale)) +} + +pub fn div_scalar_swapped( + lhs: i128, + lhs_dtype: &DataType, + rhs: &PrimitiveArray, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs_dtype, rhs.data_type())?; + let scale = 10i128.pow(scale as u32); + non_commutative_scalar_swapped(lhs, rhs, |a, b| decimal_div(a, b, scale)) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs new file mode 100644 index 000000000000..d74f4ddb8e78 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs @@ -0,0 +1,40 @@ +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; +use commutative::{ + commutative, commutative_scalar, non_commutative, non_commutative_scalar, + non_commutative_scalar_swapped, +}; +use polars_error::{PolarsError, PolarsResult}; + +mod add; +mod commutative; +mod div; +mod mul; +mod sub; + +pub use add::*; +pub use div::*; +pub use mul::*; +pub use sub::*; + +/// Maximum value that can exist with a selected precision +#[inline] +fn max_value(precision: usize) -> i128 { + 10i128.pow(precision as u32) - 1 +} + +fn get_parameters(lhs: &DataType, rhs: &DataType) -> PolarsResult<(usize, usize)> { + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.to_logical_type(), rhs.to_logical_type()) + { + if lhs_p == rhs_p && lhs_s == rhs_s { + Ok((*lhs_p, *lhs_s)) + } else { + Err(PolarsError::InvalidOperation( + "Arrays must have the same precision and scale".into(), + )) + } + } else { + unreachable!() + } +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs new file mode 100644 index 000000000000..e8e22e73e2ac --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs @@ -0,0 +1,33 @@ +use super::*; + +#[inline] +fn decimal_mul(a: i128, b: i128, scale: i128) -> i128 { + // The multiplication is done using the numbers without scale. + // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + a * b / scale +} + +pub fn mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; + let scale = 10i128.pow(scale as u32); + commutative(lhs, rhs, |a, b| decimal_mul(a, b, scale)) +} + +pub fn mul_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?; + let scale = 10i128.pow(scale as u32); + commutative_scalar(lhs, rhs, rhs_dtype, |a, b| decimal_mul(a, b, scale)) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs new file mode 100644 index 000000000000..da67a8593bde --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs @@ -0,0 +1,19 @@ +use super::*; + +pub fn sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + non_commutative(lhs, rhs, |a, b| a - b) +} + +pub fn sub_scalar(lhs: &PrimitiveArray, rhs: i128) -> PolarsResult> { + non_commutative_scalar(lhs, rhs, |a, b| a - b) +} + +pub fn sub_scalar_swapped( + lhs: i128, + rhs: &PrimitiveArray, +) -> PolarsResult> { + non_commutative_scalar_swapped(lhs, rhs, |a, b| a - b) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/mod.rs b/polars/polars-arrow/src/compute/arithmetics/mod.rs new file mode 100644 index 000000000000..0abcbaba757a --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "dtype-decimal")] +pub mod decimal; diff --git a/polars/polars-arrow/src/compute/arity.rs b/polars/polars-arrow/src/compute/arity.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/polars/polars-arrow/src/compute/arity.rs @@ -0,0 +1 @@ + diff --git a/polars/polars-arrow/src/compute/mod.rs b/polars/polars-arrow/src/compute/mod.rs index ab57198868ec..ead240d11755 100644 --- a/polars/polars-arrow/src/compute/mod.rs +++ b/polars/polars-arrow/src/compute/mod.rs @@ -1,3 +1,11 @@ +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; +use arrow::types::NativeType; + +use crate::utils::combine_validities_and; + +pub mod arithmetics; +pub mod arity; pub mod bitwise; #[cfg(feature = "compute")] pub mod cast; @@ -5,3 +13,45 @@ pub mod cast; pub mod decimal; pub mod take; pub mod tile; + +#[inline] +pub fn binary_mut( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + mut op: F, +) -> PrimitiveArray +where + T: NativeType, + D: NativeType, + F: FnMut(T, D) -> T, +{ + assert_eq!(lhs.len(), rhs.len()); + let validity = combine_validities_and(lhs.validity(), rhs.validity()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>() + .into(); + + PrimitiveArray::::new(data_type, values, validity) +} + +#[inline] +pub fn unary_mut( + array: &PrimitiveArray, + mut op: F, + data_type: DataType, +) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: FnMut(I) -> O, +{ + let values = array.values().iter().map(|v| op(*v)).collect::>(); + + PrimitiveArray::::new(data_type, values.into(), array.validity().cloned()) +} diff --git a/polars/polars-core/src/chunked_array/arithmetic.rs b/polars/polars-core/src/chunked_array/arithmetic.rs deleted file mode 100644 index cddf5882c3d2..000000000000 --- a/polars/polars-core/src/chunked_array/arithmetic.rs +++ /dev/null @@ -1,678 +0,0 @@ -//! Implementations of arithmetic operations on ChunkedArray's. -use std::ops::{Add, Div, Mul, Rem, Sub}; - -use arrow::array::PrimitiveArray; -use arrow::compute::arithmetics::basic; -#[cfg(feature = "dtype-decimal")] -use arrow::compute::arithmetics::decimal; -use arrow::compute::arity_assign; -use arrow::types::NativeType; -use num_traits::{Num, NumCast, ToPrimitive, Zero}; -use polars_arrow::utils::combine_validities_and; - -use crate::prelude::*; -use crate::series::IsSorted; -use crate::utils::{align_chunks_binary, align_chunks_binary_owned}; - -pub trait ArrayArithmetics -where - Self: NativeType, -{ - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; - fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; -} - -macro_rules! native_array_arithmetics { - ($ty: ty) => { - impl ArrayArithmetics for $ty - { - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::add(lhs, rhs) - } - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::sub(lhs, rhs) - } - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::mul(lhs, rhs) - } - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::div(lhs, rhs) - } - fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { - basic::div_scalar(lhs, rhs) - } - fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::rem(lhs, rhs) - } - fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { - basic::rem_scalar(lhs, rhs) - } - } - }; - ($($ty:ty),*) => { - $(native_array_arithmetics!($ty);)* - } -} - -native_array_arithmetics!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64); - -#[cfg(feature = "dtype-decimal")] -impl ArrayArithmetics for i128 { - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::add(lhs, rhs) - } - - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::sub(lhs, rhs) - } - - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::mul(lhs, rhs) - } - - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::div(lhs, rhs) - } - - fn div_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { - // decimal::div_scalar(lhs, rhs) - todo!("decimal::div_scalar exists, but takes &PrimitiveScalar, not &i128"); - } - - fn rem(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { - unimplemented!("requires support in arrow2 crate") - } - - fn rem_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { - unimplemented!("requires support in arrow2 crate") - } -} - -pub(super) fn arithmetic_helper( - lhs: &ChunkedArray, - rhs: &ChunkedArray, - kernel: Kernel, - operation: F, -) -> ChunkedArray -where - T: PolarsNumericType, - Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, - F: Fn(T::Native, T::Native) -> T::Native, -{ - let mut ca = match (lhs.len(), rhs.len()) { - (a, b) if a == b => { - let (lhs, rhs) = align_chunks_binary(lhs, rhs); - let chunks = lhs - .downcast_iter() - .zip(rhs.downcast_iter()) - .map(|(lhs, rhs)| Box::new(kernel(lhs, rhs)) as ArrayRef) - .collect(); - lhs.copy_with_chunks(chunks, false, false) - } - // broadcast right path - (_, 1) => { - let opt_rhs = rhs.get(0); - match opt_rhs { - None => ChunkedArray::full_null(lhs.name(), lhs.len()), - Some(rhs) => lhs.apply(|lhs| operation(lhs, rhs)), - } - } - (1, _) => { - let opt_lhs = lhs.get(0); - match opt_lhs { - None => ChunkedArray::full_null(lhs.name(), rhs.len()), - Some(lhs) => rhs.apply(|rhs| operation(lhs, rhs)), - } - } - _ => panic!("Cannot apply operation on arrays of different lengths"), - }; - ca.rename(lhs.name()); - ca -} - -/// This assigns to the owned buffer if the ref count is 1 -fn arithmetic_helper_owned( - mut lhs: ChunkedArray, - mut rhs: ChunkedArray, - kernel: Kernel, - operation: F, -) -> ChunkedArray -where - T: PolarsNumericType, - Kernel: Fn(&mut PrimitiveArray, &mut PrimitiveArray), - F: Fn(T::Native, T::Native) -> T::Native, -{ - let ca = match (lhs.len(), rhs.len()) { - (a, b) if a == b => { - let (mut lhs, mut rhs) = align_chunks_binary_owned(lhs, rhs); - // safety, we do no t change the lengths - unsafe { - lhs.downcast_iter_mut() - .zip(rhs.downcast_iter_mut()) - .for_each(|(lhs, rhs)| kernel(lhs, rhs)); - } - lhs.set_sorted_flag(IsSorted::Not); - lhs - } - // broadcast right path - (_, 1) => { - let opt_rhs = rhs.get(0); - match opt_rhs { - None => ChunkedArray::full_null(lhs.name(), lhs.len()), - Some(rhs) => { - lhs.apply_mut(|lhs| operation(lhs, rhs)); - lhs - } - } - } - (1, _) => { - let opt_lhs = lhs.get(0); - match opt_lhs { - None => ChunkedArray::full_null(lhs.name(), rhs.len()), - Some(lhs_val) => { - rhs.apply_mut(|rhs| operation(lhs_val, rhs)); - rhs.rename(lhs.name()); - rhs - } - } - } - _ => panic!("Cannot apply operation on arrays of different lengths"), - }; - ca -} - -// Operands on ChunkedArray & ChunkedArray - -impl Add for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn add(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::add, - |lhs, rhs| lhs + rhs, - ) - } -} - -impl Div for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::div, - |lhs, rhs| lhs / rhs, - ) - } -} - -impl Mul for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn mul(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::mul, - |lhs, rhs| lhs * rhs, - ) - } -} - -impl Rem for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::rem, - |lhs, rhs| lhs % rhs, - ) - } -} - -impl Sub for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::sub, - |lhs, rhs| lhs - rhs, - ) - } -} - -impl Add for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a + b), - |lhs, rhs| lhs + rhs, - ) - } -} - -impl Div for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a / b), - |lhs, rhs| lhs / rhs, - ) - } -} - -impl Mul for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a * b), - |lhs, rhs| lhs * rhs, - ) - } -} - -impl Sub for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a - b), - |lhs, rhs| lhs - rhs, - ) - } -} - -impl Rem for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: Self) -> Self::Output { - (&self).rem(&rhs) - } -} - -// Operands on ChunkedArray & Num - -impl Add for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn add(self, rhs: N) -> Self::Output { - let adder: T::Native = NumCast::from(rhs).unwrap(); - let mut out = self.apply(|val| val + adder); - out.set_sorted_flag(self.is_sorted_flag()); - out - } -} - -impl Sub for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: N) -> Self::Output { - let subber: T::Native = NumCast::from(rhs).unwrap(); - let mut out = self.apply(|val| val - subber); - out.set_sorted_flag(self.is_sorted_flag()); - out - } -} - -impl Div for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn div(self, rhs: N) -> Self::Output { - let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); - let mut out = self - .apply_kernel(&|arr| Box::new(::div_scalar(arr, &rhs))); - - if rhs < T::Native::zero() { - out.set_sorted_flag(self.is_sorted_flag().reverse()); - } else { - out.set_sorted_flag(self.is_sorted_flag()); - } - out - } -} - -impl Mul for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn mul(self, rhs: N) -> Self::Output { - // don't set sorted flag as probability of overflow is higher - let multiplier: T::Native = NumCast::from(rhs).unwrap(); - let rhs = ChunkedArray::from_vec("", vec![multiplier]); - self.mul(&rhs) - } -} - -impl Rem for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: N) -> Self::Output { - let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); - let rhs = ChunkedArray::from_vec("", vec![rhs]); - self.rem(&rhs) - } -} - -impl Add for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn add(self, rhs: N) -> Self::Output { - (&self).add(rhs) - } -} - -impl Sub for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: N) -> Self::Output { - (&self).sub(rhs) - } -} - -impl Div for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn div(self, rhs: N) -> Self::Output { - (&self).div(rhs) - } -} - -impl Mul for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn mul(mut self, rhs: N) -> Self::Output { - let multiplier: T::Native = NumCast::from(rhs).unwrap(); - self.apply_mut(|val| val * multiplier); - self - } -} - -impl Rem for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: N) -> Self::Output { - (&self).rem(rhs) - } -} - -fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec) { - buf.clear(); - - buf.extend_from_slice(l); - buf.extend_from_slice(r); -} - -impl Add for &Utf8Chunked { - type Output = Utf8Chunked; - - fn add(self, rhs: Self) -> Self::Output { - unsafe { (self.as_binary() + rhs.as_binary()).to_utf8() } - } -} - -impl Add for Utf8Chunked { - type Output = Utf8Chunked; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add(&rhs) - } -} - -impl Add<&str> for &Utf8Chunked { - type Output = Utf8Chunked; - - fn add(self, rhs: &str) -> Self::Output { - unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_utf8() } - } -} - -fn concat_binary(a: &BinaryArray, b: &BinaryArray) -> BinaryArray { - let validity = combine_validities_and(a.validity(), b.validity()); - let mut values = Vec::with_capacity(a.get_values_size() + b.get_values_size()); - let mut offsets = Vec::with_capacity(a.len() + 1); - let mut offset_so_far = 0i64; - offsets.push(offset_so_far); - - for (a, b) in a.values_iter().zip(b.values_iter()) { - values.extend_from_slice(a); - values.extend_from_slice(b); - offset_so_far = values.len() as i64; - offsets.push(offset_so_far) - } - unsafe { BinaryArray::from_data_unchecked_default(offsets.into(), values.into(), validity) } -} - -impl Add for &BinaryChunked { - type Output = BinaryChunked; - - fn add(self, rhs: Self) -> Self::Output { - // broadcasting path rhs - if rhs.len() == 1 { - let rhs = rhs.get(0); - let mut buf = vec![]; - return match rhs { - Some(rhs) => { - self.apply_mut(|s| { - concat_binary_arrs(s, rhs, &mut buf); - let out = buf.as_slice(); - // safety: lifetime is bound to the outer scope and the - // ref is valid for the lifetime of this closure - unsafe { std::mem::transmute::<_, &'static [u8]>(out) } - }) - } - None => BinaryChunked::full_null(self.name(), self.len()), - }; - } - // broadcasting path lhs - if self.len() == 1 { - let lhs = self.get(0); - let mut buf = vec![]; - return match lhs { - Some(lhs) => rhs.apply_mut(|s| { - concat_binary_arrs(lhs, s, &mut buf); - - let out = buf.as_slice(); - // safety: lifetime is bound to the outer scope and the - // ref is valid for the lifetime of this closure - unsafe { std::mem::transmute::<_, &'static [u8]>(out) } - }), - None => BinaryChunked::full_null(self.name(), rhs.len()), - }; - } - - let (lhs, rhs) = align_chunks_binary(self, rhs); - let chunks = lhs - .downcast_iter() - .zip(rhs.downcast_iter()) - .map(|(a, b)| Box::new(concat_binary(a, b)) as ArrayRef) - .collect(); - - unsafe { BinaryChunked::from_chunks(self.name(), chunks) } - } -} - -impl Add for BinaryChunked { - type Output = BinaryChunked; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add(&rhs) - } -} - -impl Add<&[u8]> for &BinaryChunked { - type Output = BinaryChunked; - - fn add(self, rhs: &[u8]) -> Self::Output { - let arr = BinaryArray::::from_slice([rhs]); - let rhs = unsafe { BinaryChunked::from_chunks("", vec![Box::new(arr) as ArrayRef]) }; - self.add(&rhs) - } -} - -fn add_boolean(a: &BooleanArray, b: &BooleanArray) -> PrimitiveArray { - let validity = combine_validities_and(a.validity(), b.validity()); - - let values = a - .values_iter() - .zip(b.values_iter()) - .map(|(a, b)| a as IdxSize + b as IdxSize) - .collect::>(); - PrimitiveArray::from_data_default(values.into(), validity) -} - -impl Add for &BooleanChunked { - type Output = IdxCa; - - fn add(self, rhs: Self) -> Self::Output { - // broadcasting path rhs - if rhs.len() == 1 { - let rhs = rhs.get(0); - return match rhs { - Some(rhs) => self.apply_cast_numeric(|v| v as IdxSize + rhs as IdxSize), - None => IdxCa::full_null(self.name(), self.len()), - }; - } - // broadcasting path lhs - if self.len() == 1 { - return rhs.add(self); - } - let (lhs, rhs) = align_chunks_binary(self, rhs); - let chunks = lhs - .downcast_iter() - .zip(rhs.downcast_iter()) - .map(|(a, b)| Box::new(add_boolean(a, b)) as ArrayRef) - .collect::>(); - - unsafe { IdxCa::from_chunks(self.name(), chunks) } - } -} - -impl Add for BooleanChunked { - type Output = IdxCa; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add(&rhs) - } -} - -#[cfg(test)] -pub(crate) mod test { - use crate::prelude::*; - - pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { - let mut a1 = Int32Chunked::new("a", &[1, 2, 3]); - let a2 = Int32Chunked::new("a", &[4, 5, 6]); - let a3 = Int32Chunked::new("a", &[1, 2, 3, 4, 5, 6]); - a1.append(&a2); - (a1, a3) - } - - #[test] - #[allow(clippy::eq_op)] - fn test_chunk_mismatch() { - let (a1, a2) = create_two_chunked(); - // with different chunks - let _ = &a1 + &a2; - let _ = &a1 - &a2; - let _ = &a1 / &a2; - let _ = &a1 * &a2; - - // with same chunks - let _ = &a1 + &a1; - let _ = &a1 - &a1; - let _ = &a1 / &a1; - let _ = &a1 * &a1; - } -} diff --git a/polars/polars-core/src/chunked_array/arithmetic/decimal.rs b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs new file mode 100644 index 000000000000..f9e2206ecb81 --- /dev/null +++ b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs @@ -0,0 +1,149 @@ +use polars_arrow::compute::arithmetics::decimal; + +use super::*; +use crate::prelude::DecimalChunked; + +// TODO: remove +impl ArrayArithmetics for i128 { + fn add(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() + } + + fn sub(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() + } + + fn mul(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() + } + + fn div(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() + } + + fn div_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { + unimplemented!() + } + + fn rem(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!("requires support in arrow2 crate") + } + + fn rem_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { + unimplemented!("requires support in arrow2 crate") + } +} + +impl DecimalChunked { + fn arithmetic_helper( + &self, + rhs: &DecimalChunked, + kernel: Kernel, + operation_lhs: ScalarKernelLhs, + operation_rhs: ScalarKernelRhs, + ) -> PolarsResult + where + Kernel: + Fn(&PrimitiveArray, &PrimitiveArray) -> PolarsResult>, + ScalarKernelLhs: Fn(&PrimitiveArray, i128) -> PolarsResult>, + ScalarKernelRhs: Fn(i128, &PrimitiveArray) -> PolarsResult>, + { + let lhs = self; + + let mut ca = match (lhs.len(), rhs.len()) { + (a, b) if a == b => { + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs, rhs)| kernel(lhs, rhs).map(|a| Box::new(a) as ArrayRef)) + .collect::>()?; + lhs.copy_with_chunks(chunks, false, false) + } + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => ChunkedArray::full_null(lhs.name(), lhs.len()), + Some(rhs_val) => { + let chunks = lhs + .downcast_iter() + .map(|lhs| operation_lhs(lhs, rhs_val).map(|a| Box::new(a) as ArrayRef)) + .collect::>()?; + lhs.copy_with_chunks(chunks, false, false) + } + } + } + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => ChunkedArray::full_null(lhs.name(), rhs.len()), + Some(lhs_val) => { + let chunks = rhs + .downcast_iter() + .map(|rhs| operation_rhs(lhs_val, rhs).map(|a| Box::new(a) as ArrayRef)) + .collect::>()?; + lhs.copy_with_chunks(chunks, false, false) + } + } + } + _ => { + polars_bail!(ComputeError: "Cannot apply operation on arrays of different lengths") + } + }; + ca.rename(lhs.name()); + Ok(ca.into_decimal_unchecked(self.precision(), self.scale())) + } +} + +impl Add for &DecimalChunked { + type Output = PolarsResult; + + fn add(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::add, + |lhs, rhs_val| decimal::add_scalar(lhs, rhs_val, &rhs.dtype().to_arrow()), + |lhs_val, rhs| decimal::add_scalar(rhs, lhs_val, &self.dtype().to_arrow()), + ) + } +} + +impl Sub for &DecimalChunked { + type Output = PolarsResult; + + fn sub(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::sub, + decimal::sub_scalar, + decimal::sub_scalar_swapped, + ) + } +} + +impl Mul for &DecimalChunked { + type Output = PolarsResult; + + fn mul(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::mul, + |lhs, rhs_val| decimal::mul_scalar(lhs, rhs_val, &rhs.dtype().to_arrow()), + |lhs_val, rhs| decimal::mul_scalar(rhs, lhs_val, &self.dtype().to_arrow()), + ) + } +} + +impl Div for &DecimalChunked { + type Output = PolarsResult; + + fn div(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::div, + |lhs, rhs_val| decimal::div_scalar(lhs, rhs_val, &rhs.dtype().to_arrow()), + |lhs_val, rhs| decimal::div_scalar_swapped(lhs_val, &self.dtype().to_arrow(), rhs), + ) + } +} diff --git a/polars/polars-core/src/chunked_array/arithmetic/mod.rs b/polars/polars-core/src/chunked_array/arithmetic/mod.rs new file mode 100644 index 000000000000..6aa0a37cb310 --- /dev/null +++ b/polars/polars-core/src/chunked_array/arithmetic/mod.rs @@ -0,0 +1,255 @@ +//! Implementations of arithmetic operations on ChunkedArray's. +#[cfg(feature = "dtype-decimal")] +mod decimal; +mod numeric; + +use std::ops::{Add, Div, Mul, Rem, Sub}; + +use arrow::array::PrimitiveArray; +use arrow::compute::arithmetics::basic; +use arrow::compute::arity_assign; +use arrow::types::NativeType; +use num_traits::{Num, NumCast, ToPrimitive, Zero}; +pub(super) use numeric::arithmetic_helper; +use polars_arrow::utils::combine_validities_and; + +use crate::prelude::*; +use crate::series::IsSorted; +use crate::utils::{align_chunks_binary, align_chunks_binary_owned}; + +pub trait ArrayArithmetics +where + Self: NativeType, +{ + fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; + fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; +} + +macro_rules! native_array_arithmetics { + ($ty: ty) => { + impl ArrayArithmetics for $ty + { + fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::add(lhs, rhs) + } + fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::sub(lhs, rhs) + } + fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::mul(lhs, rhs) + } + fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::div(lhs, rhs) + } + fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { + basic::div_scalar(lhs, rhs) + } + fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::rem(lhs, rhs) + } + fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { + basic::rem_scalar(lhs, rhs) + } + } + }; + ($($ty:ty),*) => { + $(native_array_arithmetics!($ty);)* + } +} + +native_array_arithmetics!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64); + +fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec) { + buf.clear(); + + buf.extend_from_slice(l); + buf.extend_from_slice(r); +} + +impl Add for &Utf8Chunked { + type Output = Utf8Chunked; + + fn add(self, rhs: Self) -> Self::Output { + unsafe { (self.as_binary() + rhs.as_binary()).to_utf8() } + } +} + +impl Add for Utf8Chunked { + type Output = Utf8Chunked; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +impl Add<&str> for &Utf8Chunked { + type Output = Utf8Chunked; + + fn add(self, rhs: &str) -> Self::Output { + unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_utf8() } + } +} + +fn concat_binary(a: &BinaryArray, b: &BinaryArray) -> BinaryArray { + let validity = combine_validities_and(a.validity(), b.validity()); + let mut values = Vec::with_capacity(a.get_values_size() + b.get_values_size()); + let mut offsets = Vec::with_capacity(a.len() + 1); + let mut offset_so_far = 0i64; + offsets.push(offset_so_far); + + for (a, b) in a.values_iter().zip(b.values_iter()) { + values.extend_from_slice(a); + values.extend_from_slice(b); + offset_so_far = values.len() as i64; + offsets.push(offset_so_far) + } + unsafe { BinaryArray::from_data_unchecked_default(offsets.into(), values.into(), validity) } +} + +impl Add for &BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: Self) -> Self::Output { + // broadcasting path rhs + if rhs.len() == 1 { + let rhs = rhs.get(0); + let mut buf = vec![]; + return match rhs { + Some(rhs) => { + self.apply_mut(|s| { + concat_binary_arrs(s, rhs, &mut buf); + let out = buf.as_slice(); + // safety: lifetime is bound to the outer scope and the + // ref is valid for the lifetime of this closure + unsafe { std::mem::transmute::<_, &'static [u8]>(out) } + }) + } + None => BinaryChunked::full_null(self.name(), self.len()), + }; + } + // broadcasting path lhs + if self.len() == 1 { + let lhs = self.get(0); + let mut buf = vec![]; + return match lhs { + Some(lhs) => rhs.apply_mut(|s| { + concat_binary_arrs(lhs, s, &mut buf); + + let out = buf.as_slice(); + // safety: lifetime is bound to the outer scope and the + // ref is valid for the lifetime of this closure + unsafe { std::mem::transmute::<_, &'static [u8]>(out) } + }), + None => BinaryChunked::full_null(self.name(), rhs.len()), + }; + } + + let (lhs, rhs) = align_chunks_binary(self, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(a, b)| Box::new(concat_binary(a, b)) as ArrayRef) + .collect(); + + unsafe { BinaryChunked::from_chunks(self.name(), chunks) } + } +} + +impl Add for BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +impl Add<&[u8]> for &BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: &[u8]) -> Self::Output { + let arr = BinaryArray::::from_slice([rhs]); + let rhs = unsafe { BinaryChunked::from_chunks("", vec![Box::new(arr) as ArrayRef]) }; + self.add(&rhs) + } +} + +fn add_boolean(a: &BooleanArray, b: &BooleanArray) -> PrimitiveArray { + let validity = combine_validities_and(a.validity(), b.validity()); + + let values = a + .values_iter() + .zip(b.values_iter()) + .map(|(a, b)| a as IdxSize + b as IdxSize) + .collect::>(); + PrimitiveArray::from_data_default(values.into(), validity) +} + +impl Add for &BooleanChunked { + type Output = IdxCa; + + fn add(self, rhs: Self) -> Self::Output { + // broadcasting path rhs + if rhs.len() == 1 { + let rhs = rhs.get(0); + return match rhs { + Some(rhs) => self.apply_cast_numeric(|v| v as IdxSize + rhs as IdxSize), + None => IdxCa::full_null(self.name(), self.len()), + }; + } + // broadcasting path lhs + if self.len() == 1 { + return rhs.add(self); + } + let (lhs, rhs) = align_chunks_binary(self, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(a, b)| Box::new(add_boolean(a, b)) as ArrayRef) + .collect::>(); + + unsafe { IdxCa::from_chunks(self.name(), chunks) } + } +} + +impl Add for BooleanChunked { + type Output = IdxCa; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +#[cfg(test)] +pub(crate) mod test { + use crate::prelude::*; + + pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { + let mut a1 = Int32Chunked::new("a", &[1, 2, 3]); + let a2 = Int32Chunked::new("a", &[4, 5, 6]); + let a3 = Int32Chunked::new("a", &[1, 2, 3, 4, 5, 6]); + a1.append(&a2); + (a1, a3) + } + + #[test] + #[allow(clippy::eq_op)] + fn test_chunk_mismatch() { + let (a1, a2) = create_two_chunked(); + // with different chunks + let _ = &a1 + &a2; + let _ = &a1 - &a2; + let _ = &a1 / &a2; + let _ = &a1 * &a2; + + // with same chunks + let _ = &a1 + &a1; + let _ = &a1 - &a1; + let _ = &a1 / &a1; + let _ = &a1 * &a1; + } +} diff --git a/polars/polars-core/src/chunked_array/arithmetic/numeric.rs b/polars/polars-core/src/chunked_array/arithmetic/numeric.rs new file mode 100644 index 000000000000..03cb849c2e59 --- /dev/null +++ b/polars/polars-core/src/chunked_array/arithmetic/numeric.rs @@ -0,0 +1,395 @@ +use super::*; + +pub(crate) fn arithmetic_helper( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + kernel: Kernel, + operation: F, +) -> ChunkedArray +where + T: PolarsNumericType, + Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, + F: Fn(T::Native, T::Native) -> T::Native, +{ + let mut ca = match (lhs.len(), rhs.len()) { + (a, b) if a == b => { + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs, rhs)| Box::new(kernel(lhs, rhs)) as ArrayRef) + .collect(); + lhs.copy_with_chunks(chunks, false, false) + } + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => ChunkedArray::full_null(lhs.name(), lhs.len()), + Some(rhs) => lhs.apply(|lhs| operation(lhs, rhs)), + } + } + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => ChunkedArray::full_null(lhs.name(), rhs.len()), + Some(lhs) => rhs.apply(|rhs| operation(lhs, rhs)), + } + } + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + ca.rename(lhs.name()); + ca +} + +/// This assigns to the owned buffer if the ref count is 1 +fn arithmetic_helper_owned( + mut lhs: ChunkedArray, + mut rhs: ChunkedArray, + kernel: Kernel, + operation: F, +) -> ChunkedArray +where + T: PolarsNumericType, + Kernel: Fn(&mut PrimitiveArray, &mut PrimitiveArray), + F: Fn(T::Native, T::Native) -> T::Native, +{ + let ca = match (lhs.len(), rhs.len()) { + (a, b) if a == b => { + let (mut lhs, mut rhs) = align_chunks_binary_owned(lhs, rhs); + // safety, we do no t change the lengths + unsafe { + lhs.downcast_iter_mut() + .zip(rhs.downcast_iter_mut()) + .for_each(|(lhs, rhs)| kernel(lhs, rhs)); + } + lhs.set_sorted_flag(IsSorted::Not); + lhs + } + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => ChunkedArray::full_null(lhs.name(), lhs.len()), + Some(rhs) => { + lhs.apply_mut(|lhs| operation(lhs, rhs)); + lhs + } + } + } + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => ChunkedArray::full_null(lhs.name(), rhs.len()), + Some(lhs_val) => { + rhs.apply_mut(|rhs| operation(lhs_val, rhs)); + rhs.rename(lhs.name()); + rhs + } + } + } + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + ca +} + +// Operands on ChunkedArray & ChunkedArray + +impl Add for &ChunkedArray +where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn add(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::add, + |lhs, rhs| lhs + rhs, + ) + } +} + +impl Div for &ChunkedArray +where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn div(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::div, + |lhs, rhs| lhs / rhs, + ) + } +} + +impl Mul for &ChunkedArray +where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn mul(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::mul, + |lhs, rhs| lhs * rhs, + ) + } +} + +impl Rem for &ChunkedArray +where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::rem, + |lhs, rhs| lhs % rhs, + ) + } +} + +impl Sub for &ChunkedArray +where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn sub(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::sub, + |lhs, rhs| lhs - rhs, + ) + } +} + +impl Add for ChunkedArray +where + T: PolarsNumericType, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a + b), + |lhs, rhs| lhs + rhs, + ) + } +} + +impl Div for ChunkedArray +where + T: PolarsNumericType, +{ + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a / b), + |lhs, rhs| lhs / rhs, + ) + } +} + +impl Mul for ChunkedArray +where + T: PolarsNumericType, +{ + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a * b), + |lhs, rhs| lhs * rhs, + ) + } +} + +impl Sub for ChunkedArray +where + T: PolarsNumericType, +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a - b), + |lhs, rhs| lhs - rhs, + ) + } +} + +impl Rem for ChunkedArray +where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: Self) -> Self::Output { + (&self).rem(&rhs) + } +} + +// Operands on ChunkedArray & Num + +impl Add for &ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn add(self, rhs: N) -> Self::Output { + let adder: T::Native = NumCast::from(rhs).unwrap(); + let mut out = self.apply(|val| val + adder); + out.set_sorted_flag(self.is_sorted_flag()); + out + } +} + +impl Sub for &ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn sub(self, rhs: N) -> Self::Output { + let subber: T::Native = NumCast::from(rhs).unwrap(); + let mut out = self.apply(|val| val - subber); + out.set_sorted_flag(self.is_sorted_flag()); + out + } +} + +impl Div for &ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn div(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); + let mut out = self + .apply_kernel(&|arr| Box::new(::div_scalar(arr, &rhs))); + + if rhs < T::Native::zero() { + out.set_sorted_flag(self.is_sorted_flag().reverse()); + } else { + out.set_sorted_flag(self.is_sorted_flag()); + } + out + } +} + +impl Mul for &ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn mul(self, rhs: N) -> Self::Output { + // don't set sorted flag as probability of overflow is higher + let multiplier: T::Native = NumCast::from(rhs).unwrap(); + let rhs = ChunkedArray::from_vec("", vec![multiplier]); + self.mul(&rhs) + } +} + +impl Rem for &ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); + let rhs = ChunkedArray::from_vec("", vec![rhs]); + self.rem(&rhs) + } +} + +impl Add for ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn add(self, rhs: N) -> Self::Output { + (&self).add(rhs) + } +} + +impl Sub for ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn sub(self, rhs: N) -> Self::Output { + (&self).sub(rhs) + } +} + +impl Div for ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn div(self, rhs: N) -> Self::Output { + (&self).div(rhs) + } +} + +impl Mul for ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn mul(mut self, rhs: N) -> Self::Output { + let multiplier: T::Native = NumCast::from(rhs).unwrap(); + self.apply_mut(|val| val * multiplier); + self + } +} + +impl Rem for ChunkedArray +where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: N) -> Self::Output { + (&self).rem(rhs) + } +} diff --git a/polars/polars-core/src/datatypes/dtype.rs b/polars/polars-core/src/datatypes/dtype.rs index 76b749c85b54..2e9538ddf797 100644 --- a/polars/polars-core/src/datatypes/dtype.rs +++ b/polars/polars-core/src/datatypes/dtype.rs @@ -160,7 +160,7 @@ impl DataType { self.is_numeric() | matches!(self, DataType::Boolean | DataType::Utf8 | DataType::Binary) } - /// Check if this [`DataType`] is a numeric type + /// Check if this [`DataType`] is a numeric type. pub fn is_numeric(&self) -> bool { // allow because it cannot be replaced when object feature is activated #[allow(clippy::match_like_matches_macro)] @@ -181,6 +181,8 @@ impl DataType { DataType::Categorical(_) => false, #[cfg(feature = "dtype-struct")] DataType::Struct(_) => false, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => false, _ => true, } } diff --git a/polars/polars-core/src/series/implementations/decimal.rs b/polars/polars-core/src/series/implementations/decimal.rs index 125bfee5a285..8f5f9bdde833 100644 --- a/polars/polars-core/src/series/implementations/decimal.rs +++ b/polars/polars-core/src/series/implementations/decimal.rs @@ -37,6 +37,22 @@ impl private::PrivateSeries for SeriesWrap { .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) } + fn subtract(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) - rhs).map(|ca| ca.into_series()) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) + rhs).map(|ca| ca.into_series()) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) * rhs).map(|ca| ca.into_series()) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) / rhs).map(|ca| ca.into_series()) + } } impl SeriesTrait for SeriesWrap { diff --git a/polars/polars-lazy/src/physical_plan/expressions/binary.rs b/polars/polars-lazy/src/physical_plan/expressions/binary.rs index e04140296c2a..06c39265cb30 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/binary.rs @@ -56,6 +56,8 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu Operator::Multiply => Ok(left * right), Operator::Divide => Ok(left / right), Operator::TrueDivide => match left.dtype() { + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Ok(left / right), Date | Datetime(_, _) | Float32 | Float64 => Ok(left / right), _ => Ok(&left.cast(&Float64)? / &right.cast(&Float64)?), }, diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 48bea4f6c436..0de98fe4be26 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -134,3 +134,32 @@ def test_read_csv_decimal(monkeypatch: Any) -> None: D("1.10000000000000000000"), D("0.01000000000000000000"), ] + + +def test_decimal_arithmetic() -> None: + df = pl.DataFrame( + { + "a": [D("0.1"), D("10.1"), D("100.01")], + "b": [D("20.1"), D("10.19"), D("39.21")], + } + ) + + out = df.select( + out1=pl.col("a") * pl.col("b"), + out2=pl.col("a") + pl.col("b"), + out3=pl.col("a") / pl.col("b"), + out4=pl.col("a") - pl.col("b"), + ) + assert out.dtypes == [ + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=2), + ] + + assert out.to_dict(False) == { + "out1": [D("2.01"), D("102.91"), D("3921.39")], + "out2": [D("20.20"), D("20.29"), D("139.22")], + "out3": [D("0.00"), D("0.99"), D("2.55")], + "out4": [D("-20.00"), D("-0.09"), D("60.80")], + }