Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Fixed error in dispatching scalar arithmetics #682

Merged
merged 1 commit into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 99 additions & 6 deletions src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ pub mod decimal;
pub mod time;

use crate::{
array::Array,
array::{Array, PrimitiveArray},
bitmap::Bitmap,
datatypes::{DataType, IntervalUnit, TimeUnit},
scalar::Scalar,
scalar::{PrimitiveScalar, Scalar},
};

// Macro to evaluate match branch in arithmetic function.
Expand Down Expand Up @@ -101,6 +101,99 @@ macro_rules! arith {
}};
}

// Macro to evaluate match branch in arithmetic function.
macro_rules! primitive_scalar {
($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{
let lhs = $lhs
.as_any()
.downcast_ref::<PrimitiveArray<$type>>()
.unwrap();
let rhs = $rhs
.as_any()
.downcast_ref::<PrimitiveScalar<$type>>()
.unwrap();

let rhs = if let Some(rhs) = rhs.value() {
rhs
} else {
return Box::new(PrimitiveArray::<$type>::new_null(
lhs.data_type().clone(),
lhs.len(),
)) as Box<dyn Array>;
};

let result = basic::$op::<$type>(lhs, &rhs);
Box::new(result) as Box<dyn Array>
}};
}

// Macro to create a `match` statement with dynamic dispatch to functions based on
// the array's logical types
macro_rules! arith_scalar {
($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{
let lhs = $lhs;
let rhs = $rhs;
use DataType::*;
match (lhs.data_type(), rhs.data_type()) {
(Int8, Int8) => primitive_scalar!(lhs, rhs, $op, i8),
(Int16, Int16) => primitive_scalar!(lhs, rhs, $op, i16),
(Int32, Int32) => primitive_scalar!(lhs, rhs, $op, i32),
(Int64, Int64) | (Duration(_), Duration(_)) => {
primitive_scalar!(lhs, rhs, $op, i64)
}
(UInt8, UInt8) => primitive_scalar!(lhs, rhs, $op, u8),
(UInt16, UInt16) => primitive_scalar!(lhs, rhs, $op, u16),
(UInt32, UInt32) => primitive_scalar!(lhs, rhs, $op, u32),
(UInt64, UInt64) => primitive_scalar!(lhs, rhs, $op, u64),
(Float32, Float32) => primitive_scalar!(lhs, rhs, $op, f32),
(Float64, Float64) => primitive_scalar!(lhs, rhs, $op, f64),
$ (
(Decimal(_, _), Decimal(_, _)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(decimal::$op_decimal(lhs, rhs)) as Box<dyn Array>
}
)?
$ (
(Time32(TimeUnit::Second), Duration(_))
| (Time32(TimeUnit::Millisecond), Duration(_))
| (Date32, Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i32>(lhs, rhs)) as Box<dyn Array>
}
(Time64(TimeUnit::Microsecond), Duration(_))
| (Time64(TimeUnit::Nanosecond), Duration(_))
| (Date64, Duration(_))
| (Timestamp(_, _), Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i64>(lhs, rhs)) as Box<dyn Array>
}
)?
$ (
(Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
}
)?
$ (
(Timestamp(_, None), Timestamp(_, None)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
}
)?
_ => todo!(
"Addition of {:?} with {:?} is not supported",
lhs.data_type(),
rhs.data_type()
),
}
}};
}

/// Adds two [`Array`]s.
/// # Panic
/// This function panics iff
Expand All @@ -124,7 +217,7 @@ pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// * the arrays have a different length
/// * one of the arrays is a timestamp with timezone and the timezone is not valid.
pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(
arith_scalar!(
lhs,
rhs,
add_scalar,
Expand Down Expand Up @@ -185,7 +278,7 @@ pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// * the arrays have a different length
/// * one of the arrays is a timestamp with timezone and the timezone is not valid.
pub fn sub_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(
arith_scalar!(
lhs,
rhs,
sub_scalar,
Expand Down Expand Up @@ -236,7 +329,7 @@ pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// This function panics iff
/// * the opertion is not supported for the logical types (use [`can_mul`] to check)
pub fn mul_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(lhs, rhs, mul_scalar, decimal = mul_scalar)
arith_scalar!(lhs, rhs, mul_scalar, decimal = mul_scalar)
}

/// Returns whether two [`DataType`]s can be multiplied by [`mul`].
Expand Down Expand Up @@ -272,7 +365,7 @@ pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// This function panics iff
/// * the opertion is not supported for the logical types (use [`can_div`] to check)
pub fn div_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(lhs, rhs, div_scalar, decimal = div_scalar)
arith_scalar!(lhs, rhs, div_scalar, decimal = div_scalar)
}

/// Returns whether two [`DataType`]s can be divided by [`div`].
Expand Down
21 changes: 20 additions & 1 deletion tests/it/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,29 @@ mod basic;
mod decimal;
mod time;

use arrow2::array::new_empty_array;
use arrow2::array::{new_empty_array, Int32Array};
use arrow2::compute::arithmetics::*;
use arrow2::datatypes::DataType::*;
use arrow2::datatypes::{IntervalUnit, TimeUnit};
use arrow2::scalar::PrimitiveScalar;

#[test]
fn test_add() {
let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
let b = Int32Array::from(&[Some(5), None, None, Some(6)]);
let result = add(&a, &b);
let expected = Int32Array::from(&[None, None, None, Some(12)]);
assert_eq!(expected, result.as_ref());
}

#[test]
fn test_add_scalar() {
let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
let b: PrimitiveScalar<i32> = Some(1i32).into();
let result = add_scalar(&a, &b);
let expected = Int32Array::from(&[None, Some(7), None, Some(7)]);
assert_eq!(expected, result.as_ref());
}

#[test]
fn consistency() {
Expand Down