Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Add *_scalar for some arithmetics kernels (#649)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Dec 2, 2021
1 parent f086986 commit 3316a4d
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 4 deletions.
46 changes: 45 additions & 1 deletion src/compute/arithmetics/decimal/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use crate::{
buffer::Buffer,
compute::{
arithmetics::{ArrayCheckedDiv, ArrayDiv},
arity::{binary, binary_checked},
arity::{binary, binary_checked, unary},
utils::{check_same_len, combine_validities},
},
datatypes::DataType,
error::{ArrowError, Result},
scalar::{PrimitiveScalar, Scalar},
};

use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
Expand Down Expand Up @@ -67,6 +68,49 @@ pub fn div(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveA
binary(lhs, rhs, lhs.data_type().clone(), op)
}

/// Divides every value of a decimal [`PrimitiveArray`] by a decimal [`PrimitiveScalar`].
///
/// The returned array has the same [`DataType`] as `lhs`. When `rhs` holds no
/// value, an all-null array with the length of `lhs` is returned instead.
///
/// # Panics
/// This function panics if
/// * `get_parameters` returns an error for the two data types (its result is
///   unwrapped; presumably when precision and scale differ),
/// * scaling a dividend by `10^scale` overflows `i128`,
/// * the division itself fails (`checked_div` returns `None`, e.g. for a zero divisor),
/// * a quotient is larger, in absolute value, than `max_value(precision)`.
pub fn div_scalar(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveScalar<i128>) -> PrimitiveArray<i128> {
    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

    let rhs = if let Some(rhs) = rhs.value() {
        rhs
    } else {
        return PrimitiveArray::<i128>::new_null(lhs.data_type().clone(), lhs.len());
    };

    let scale = 10i128.pow(scale as u32);
    let max = max_value(precision);

    let op = move |a: i128| {
        // The division is done using the numbers without scale.
        // The dividend is scaled up to maintain precision after the
        // division

        //   222.222 -->  222222000
        //   123.456 -->     123456
        //  --------       ---------
        //     1.800 <--       1800

        // A plain `a * scale` wraps silently in release builds and could yield
        // a wrong quotient that still passes the precision check below, so the
        // scaling is checked explicitly.
        let numeral: i128 = a
            .checked_mul(scale)
            .expect("Overflow when scaling the dividend");

        // `checked_div` returns `None` when the divisor is zero.
        let res: i128 = numeral.checked_div(rhs).expect("Found division by zero");

        // The wording of this message is kept unchanged on purpose, in case
        // callers or tests match on it.
        assert!(
            res.abs() <= max,
            "Overflow in multiplication presented for precision {}",
            precision
        );

        res
    };

    unary(lhs, op, lhs.data_type().clone())
}

/// Saturated division of two decimal primitive arrays with the same
/// precision and scale. If the precision and scale is different, then an
/// InvalidArgumentError is returned. If the result from the division is
Expand Down
49 changes: 48 additions & 1 deletion src/compute/arithmetics/decimal/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use crate::{
buffer::Buffer,
compute::{
arithmetics::{ArrayCheckedMul, ArrayMul, ArraySaturatingMul},
arity::{binary, binary_checked},
arity::{binary, binary_checked, unary},
utils::{check_same_len, combine_validities},
},
datatypes::DataType,
error::{ArrowError, Result},
scalar::{PrimitiveScalar, Scalar},
};

use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
Expand Down Expand Up @@ -68,6 +69,52 @@ pub fn mul(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveA
binary(lhs, rhs, lhs.data_type().clone(), op)
}

/// Multiplies every value of a decimal [`PrimitiveArray`] by a decimal [`PrimitiveScalar`].
///
/// The output keeps the [`DataType`] of `lhs`. A scalar without a value yields
/// an all-null array with the same length as `lhs`.
///
/// # Panics
/// This function panics if
/// * `get_parameters` returns an error for the two data types (its result is
///   unwrapped; presumably when precision and scale differ),
/// * the raw `i128` product overflows,
/// * the rescaled product is larger, in absolute value, than `max_value(precision)`.
pub fn mul_scalar(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveScalar<i128>) -> PrimitiveArray<i128> {
    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

    let factor = match rhs.value() {
        Some(value) => value,
        None => return PrimitiveArray::<i128>::new_null(lhs.data_type().clone(), lhs.len()),
    };

    let divisor = 10i128.pow(scale as u32);
    let limit = max_value(precision);

    unary(
        lhs,
        move |value: i128| {
            // Two large i128 operands can overflow, hence the checked product.
            let product: i128 = value
                .checked_mul(factor)
                .expect("Mayor overflow for multiplication");

            // Both operands are stored without their decimal point, so the raw
            // product carries the scale twice; dividing by 10^scale removes one.
            //
            //   111.111 * 222.222  ->  111111 * 222222 = 24691308642
            //   24691308642 / 10^3 ->  24691308  (i.e. 24691.308)
            let rescaled = product / divisor;

            assert!(
                rescaled.abs() <= limit,
                "Overflow in multiplication presented for precision {}",
                precision
            );

            rescaled
        },
        lhs.data_type().clone(),
    )
}

/// Saturated multiplication of two decimal primitive arrays with the same
/// precision and scale. If the precision and scale is different, then an
/// InvalidArgumentError is returned. If the result from the multiplication is
Expand Down
49 changes: 49 additions & 0 deletions src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{
array::Array,
bitmap::Bitmap,
datatypes::{DataType, IntervalUnit, TimeUnit},
scalar::Scalar,
types::NativeType,
};

Expand Down Expand Up @@ -117,6 +118,22 @@ pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
)
}

/// Adds a [`Scalar`] to an [`Array`].
///
/// The work is delegated to the `arith!` macro, which receives `add_scalar` as the
/// default kernel, `add_duration_scalar` for `duration` and `add_interval_scalar`
/// for `interval`.
/// # Panic
/// NOTE(review): the conditions below mirror the documentation of [`add`]; the body of
/// `arith!` is not visible from here, so confirm them against the macro.
/// This function presumably panics iff
/// * the operation is not supported for the logical types (use [`can_add`] to check)
/// * the array is a timestamp with timezone and the timezone is not valid.
pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(
lhs,
rhs,
add_scalar,
duration = add_duration_scalar,
interval = add_interval_scalar
)
}

/// Returns whether two [`DataType`]s can be added by [`add`].
pub fn can_add(lhs: &DataType, rhs: &DataType) -> bool {
use DataType::*;
Expand Down Expand Up @@ -162,6 +179,22 @@ pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
)
}

/// Subtracts a [`Scalar`] from an [`Array`].
///
/// The work is delegated to the `arith!` macro, which receives `sub_scalar` as the
/// default kernel, `sub_duration_scalar` for `duration` and `sub_timestamps_scalar`
/// for `timestamp`.
/// # Panic
/// NOTE(review): the conditions below mirror the documentation of [`sub`]; the body of
/// `arith!` is not visible from here, so confirm them against the macro.
/// This function presumably panics iff
/// * the operation is not supported for the logical types (use [`can_sub`] to check)
/// * the array is a timestamp with timezone and the timezone is not valid.
pub fn sub_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(
lhs,
rhs,
sub_scalar,
duration = sub_duration_scalar,
timestamp = sub_timestamps_scalar
)
}

/// Returns whether two [`DataType`]s can be subtracted by [`sub`].
pub fn can_sub(lhs: &DataType, rhs: &DataType) -> bool {
use DataType::*;
Expand Down Expand Up @@ -199,6 +232,14 @@ pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
arith!(lhs, rhs, mul, decimal = mul)
}

/// Multiplies an [`Array`] by a [`Scalar`].
///
/// The work is delegated to the `arith!` macro, which receives `mul_scalar` both as
/// the default kernel and as the `decimal` kernel.
/// # Panic
/// NOTE(review): the body of `arith!` is not visible from here; confirm the condition
/// below against the macro.
/// This function presumably panics iff
/// * the operation is not supported for the logical types (use [`can_mul`] to check)
pub fn mul_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(lhs, rhs, mul_scalar, decimal = mul_scalar)
}

/// Returns whether two [`DataType`]s can be multiplied by [`mul`].
pub fn can_mul(lhs: &DataType, rhs: &DataType) -> bool {
use DataType::*;
Expand Down Expand Up @@ -227,6 +268,14 @@ pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
arith!(lhs, rhs, div, decimal = div)
}

/// Divides an [`Array`] by a [`Scalar`].
///
/// The work is delegated to the `arith!` macro, which receives `div_scalar` both as
/// the default kernel and as the `decimal` kernel.
/// # Panic
/// NOTE(review): the body of `arith!` is not visible from here; confirm the condition
/// below against the macro.
/// This function presumably panics iff
/// * the operation is not supported for the logical types (use [`can_div`] to check)
pub fn div_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(lhs, rhs, div_scalar, decimal = div_scalar)
}

/// Returns whether two [`DataType`]s can be divided by [`div`].
pub fn can_div(lhs: &DataType, rhs: &DataType) -> bool {
can_mul(lhs, rhs)
Expand Down
149 changes: 148 additions & 1 deletion src/compute/arithmetics/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ use num_traits::AsPrimitive;

use crate::{
array::{Array, PrimitiveArray},
compute::arity::binary,
compute::arity::{binary, unary},
datatypes::{DataType, TimeUnit},
error::{ArrowError, Result},
scalar::{PrimitiveScalar, Scalar},
temporal_conversions,
types::{months_days_ns, NativeType},
};
Expand Down Expand Up @@ -117,6 +118,31 @@ where
binary(time, duration, time.data_type().clone(), op)
}

/// Adds one duration value to every element of a time array (Timestamp, Time and Date).
///
/// The duration is multiplied by the factor that `create_scale` returns for the
/// two data types before it is added, and the result keeps the data type of
/// `time`. A duration scalar without a value yields an all-null array with the
/// length of `time`.
///
/// # Panics
/// Panics if `create_scale` returns an error for the two data types (its
/// result is unwrapped).
pub fn add_duration_scalar<T>(
    time: &PrimitiveArray<T>,
    duration: &PrimitiveScalar<i64>,
) -> PrimitiveArray<T>
where
    f64: AsPrimitive<T>,
    T: NativeType + Add<T, Output = T>,
{
    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();

    // The scaled duration is the same for every element, so it is computed once
    // and only converted to `T` inside the per-element closure.
    let scaled = match duration.value() {
        Some(value) => value as f64 * scale,
        None => return PrimitiveArray::<T>::new_null(time.data_type().clone(), time.len()),
    };

    unary(
        time,
        move |value: T| value + scaled.as_(),
        time.data_type().clone(),
    )
}

/// Subtract a duration to a time array (Timestamp, Time and Date). The timeunit
/// enum is used to scale correctly both arrays; adding seconds with seconds,
/// or milliseconds with milliseconds.
Expand Down Expand Up @@ -173,6 +199,29 @@ where
binary(time, duration, time.data_type().clone(), op)
}

/// Subtracts one duration value from every element of a time array (Timestamp, Time and Date).
///
/// The duration is multiplied by the factor that `create_scale` returns for the
/// two data types before it is subtracted, and the result keeps the data type
/// of `time`. A duration scalar without a value yields an all-null array with
/// the length of `time`.
///
/// # Panics
/// Panics if `create_scale` returns an error for the two data types (its
/// result is unwrapped).
pub fn sub_duration_scalar<T>(
    time: &PrimitiveArray<T>,
    duration: &PrimitiveScalar<i64>,
) -> PrimitiveArray<T>
where
    f64: AsPrimitive<T>,
    T: NativeType + Sub<T, Output = T>,
{
    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();

    // The scaled duration is the same for every element, so it is computed once
    // and only converted to `T` inside the per-element closure.
    let scaled = match duration.value() {
        Some(value) => value as f64 * scale,
        None => return PrimitiveArray::<T>::new_null(time.data_type().clone(), time.len()),
    };

    unary(
        time,
        move |value: T| value - scaled.as_(),
        time.data_type().clone(),
    )
}

/// Calculates the difference between two timestamps returning an array of type
/// Duration. The timeunit enum is used to scale correctly both arrays;
/// subtracting seconds with seconds, or milliseconds with milliseconds.
Expand Down Expand Up @@ -228,6 +277,40 @@ pub fn subtract_timestamps(
}
}

/// Calculates the difference between each timestamp of `lhs` and the timestamp `rhs`
/// as [`DataType::Duration`] with the time unit of `lhs`.
///
/// `rhs` is multiplied by the factor returned by
/// `temporal_conversions::timeunit_scale` for the two time units before it is
/// subtracted. When `rhs` holds no value, an all-null array with the length of
/// `lhs` is returned; it carries the same `Duration` data type as the non-null
/// result, so the output type does not depend on the validity of the scalar.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] unless both `lhs` and `rhs` are
/// [`DataType::Timestamp`] without a timezone.
pub fn sub_timestamps_scalar(
    lhs: &PrimitiveArray<i64>,
    rhs: &PrimitiveScalar<i64>,
) -> Result<PrimitiveArray<i64>> {
    let (scale, timeunit_a) =
        if let (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) =
            (lhs.data_type(), rhs.data_type())
        {
            (
                temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b),
                timeunit_a,
            )
        } else {
            return Err(ArrowError::InvalidArgumentError(
                "sub_timestamps_scalar requires both arguments to be timestamps without timezone"
                    .to_string(),
            ));
        };

    // Both the null and the non-null result are durations in the unit of `lhs`.
    let data_type = DataType::Duration(*timeunit_a);

    // The scaled right-hand side is identical for every element: compute it once.
    let rhs = if let Some(value) = rhs.value() {
        (value as f64 * scale) as i64
    } else {
        return Ok(PrimitiveArray::<i64>::new_null(data_type, lhs.len()));
    };

    let op = move |a: i64| a - rhs;

    Ok(unary(lhs, op, data_type))
}

/// Adds an interval to a [`DataType::Timestamp`].
pub fn add_interval(
timestamp: &PrimitiveArray<i64>,
Expand Down Expand Up @@ -285,3 +368,67 @@ pub fn add_interval(
)),
}
}

/// Adds one interval value to every element of a [`DataType::Timestamp`] array.
///
/// A null interval yields an all-null array with the data type and length of
/// `timestamp`; this check happens before the data type of `timestamp` is
/// inspected. Timestamps with a timezone go through
/// `temporal_conversions::add_interval`, timestamps without one through
/// `temporal_conversions::add_naive_interval`.
///
/// # Errors
/// Returns an error when
/// * the timezone cannot be parsed by `temporal_conversions::parse_offset` and,
///   with the `chrono-tz` feature enabled, `parse_offset_tz` fails as well,
/// * `timestamp` is not a [`DataType::Timestamp`].
pub fn add_interval_scalar(
    timestamp: &PrimitiveArray<i64>,
    interval: &PrimitiveScalar<months_days_ns>,
) -> Result<PrimitiveArray<i64>> {
    let interval = match interval.value() {
        Some(value) => value,
        None => {
            return Ok(PrimitiveArray::<i64>::new_null(
                timestamp.data_type().clone(),
                timestamp.len(),
            ))
        }
    };

    match timestamp.data_type().to_logical_type() {
        DataType::Timestamp(unit, Some(timezone_str)) => {
            let unit = *unit;
            // Try a fixed offset first; named timezones need `chrono-tz`.
            match temporal_conversions::parse_offset(timezone_str) {
                Ok(offset) => Ok(unary(
                    timestamp,
                    |value| temporal_conversions::add_interval(value, unit, interval, &offset),
                    timestamp.data_type().clone(),
                )),
                #[cfg(feature = "chrono-tz")]
                Err(_) => {
                    let tz = temporal_conversions::parse_offset_tz(timezone_str)?;
                    Ok(unary(
                        timestamp,
                        |value| temporal_conversions::add_interval(value, unit, interval, &tz),
                        timestamp.data_type().clone(),
                    ))
                }
                #[cfg(not(feature = "chrono-tz"))]
                Err(_) => Err(ArrowError::InvalidArgumentError(format!(
                    "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)",
                    timezone_str
                ))),
            }
        }
        DataType::Timestamp(unit, None) => {
            let unit = *unit;
            Ok(unary(
                timestamp,
                |value| temporal_conversions::add_naive_interval(value, unit, interval),
                timestamp.data_type().clone(),
            ))
        }
        _ => Err(ArrowError::InvalidArgumentError(
            "Adding an interval is only supported for `DataType::Timestamp`".to_string(),
        )),
    }
}
Loading

0 comments on commit 3316a4d

Please sign in to comment.