Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Simplified arithmetics compute #607

Merged
merged 3 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/arithmetic_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn bench_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>)
where
T: NativeType + Add<Output = T> + NumCast,
{
criterion::black_box(add(lhs, rhs)).unwrap();
criterion::black_box(add(lhs, rhs));
}

fn add_benchmark(c: &mut Criterion) {
Expand Down
2 changes: 1 addition & 1 deletion benches/comparison_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion};

use arrow2::scalar::*;
use arrow2::util::bench_util::*;
use arrow2::{compute::comparison::*, datatypes::DataType};
use arrow2::{compute::comparison::eq, datatypes::DataType};

fn add_benchmark(c: &mut Criterion) {
(10..=20).step_by(2).for_each(|log2_size| {
Expand Down
2 changes: 1 addition & 1 deletion benches/write_ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn write(array: &dyn Array) -> Result<()> {
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![clone(array).into()])?;

let writer = Cursor::new(vec![]);
let mut writer = FileWriter::try_new(writer, &schema)?;
let mut writer = FileWriter::try_new(writer, &schema, Default::default())?;

writer.write(&batch)
}
Expand Down
28 changes: 11 additions & 17 deletions examples/arithmetics.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,44 @@
use arrow2::array::{Array, PrimitiveArray};
use arrow2::compute::arithmetics::*;
use arrow2::compute::arithmetics::basic::*;
use arrow2::compute::arithmetics::{add as dyn_add, can_add};
use arrow2::compute::arity::{binary, unary};
use arrow2::datatypes::DataType;
use arrow2::error::Result;

fn main() -> Result<()> {
fn main() {
// say we have two arrays
let array0 = PrimitiveArray::<i64>::from(&[Some(1), Some(2), Some(3)]);
let array1 = PrimitiveArray::<i64>::from(&[Some(4), None, Some(6)]);

// we can add them as follows:
let added = arithmetic_primitive(&array0, Operator::Add, &array1)?;
let added = add(&array0, &array1);
assert_eq!(
added,
PrimitiveArray::<i64>::from(&[Some(5), None, Some(9)])
);

// subtract:
let subtracted = arithmetic_primitive(&array0, Operator::Subtract, &array1)?;
let subtracted = sub(&array0, &array1);
assert_eq!(
subtracted,
PrimitiveArray::<i64>::from(&[Some(-3), None, Some(-3)])
);

// add a scalar:
let plus10 = arithmetic_primitive_scalar(&array0, Operator::Add, &10)?;
let plus10 = add_scalar(&array0, &10);
assert_eq!(
plus10,
PrimitiveArray::<i64>::from(&[Some(11), Some(12), Some(13)])
);

// when the array is a trait object, there is a similar API
// a similar API for trait objects:
let array0 = &array0 as &dyn Array;
let array1 = &array1 as &dyn Array;

// check whether the logical types support addition (they could be any `Array`).
assert!(can_arithmetic(
array0.data_type(),
Operator::Add,
array1.data_type()
));
// check whether the logical types support addition.
assert!(can_add(array0.data_type(), array1.data_type()));

// add them
let added = arithmetic(array0, Operator::Add, array1).unwrap();
let added = dyn_add(array0, array1);
assert_eq!(
PrimitiveArray::<i64>::from(&[Some(5), None, Some(9)]),
added.as_ref(),
Expand All @@ -54,7 +50,7 @@ fn main() -> Result<()> {
let array1 = PrimitiveArray::<i64>::from(&[Some(4), None, Some(6)]);

let op = |x: i64, y: i64| x.pow(2) + y.pow(2);
let r = binary(&array0, &array1, DataType::Int64, op)?;
let r = binary(&array0, &array1, DataType::Int64, op);
assert_eq!(
r,
PrimitiveArray::<i64>::from(&[Some(1 + 16), None, Some(9 + 36)])
Expand All @@ -79,6 +75,4 @@ fn main() -> Result<()> {
rounded,
PrimitiveArray::<i64>::from(&[Some(4), None, Some(5)])
);

Ok(())
}
2 changes: 0 additions & 2 deletions examples/csv_read_async.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::sync::Arc;

use futures::io::Cursor;
use tokio::fs::File;
use tokio_util::compat::*;

use arrow2::array::*;
use arrow2::error::Result;
use arrow2::io::csv::read_async::*;

Expand Down
2 changes: 1 addition & 1 deletion examples/growable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow2::array::growable::{Growable, GrowablePrimitive};
use arrow2::array::{Array, PrimitiveArray};
use arrow2::array::PrimitiveArray;

fn main() {
// say we have two sorted arrays
Expand Down
68 changes: 25 additions & 43 deletions src/compute/arithmetics/basic/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@ use std::ops::Add;

use num_traits::{ops::overflowing::OverflowingAdd, CheckedAdd, SaturatingAdd, WrappingAdd, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::compute::arithmetics::ArrayWrappingAdd;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
compute::{
arithmetics::{
ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, NativeArithmetics,
ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, ArrayWrappingAdd,
NativeArithmetics,
},
arity::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::Result,
types::NativeType,
};

Expand All @@ -30,16 +28,14 @@ use crate::{
///
/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]);
/// let b = PrimitiveArray::from([Some(5), None, None, Some(6)]);
/// let result = add(&a, &b).unwrap();
/// let result = add(&a, &b);
/// let expected = PrimitiveArray::from([None, None, None, Some(12)]);
/// assert_eq!(result, expected)
/// ```
pub fn add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
pub fn add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + Add<Output = T>,
{
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b)
}

Expand All @@ -53,19 +49,14 @@ where
///
/// let a = PrimitiveArray::from([Some(-100i8), Some(100i8), Some(100i8)]);
/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]);
/// let result = wrapping_add(&a, &b).unwrap();
/// let result = wrapping_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(-100i8), Some(-56i8), Some(100i8)]);
/// assert_eq!(result, expected);
/// ```
pub fn wrapping_add<T>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
pub fn wrapping_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + WrappingAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.wrapping_add(&b);

binary(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -81,16 +72,14 @@ where
///
/// let a = PrimitiveArray::from([Some(100i8), Some(100i8), Some(100i8)]);
/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]);
/// let result = checked_add(&a, &b).unwrap();
/// let result = checked_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(100i8), None, Some(100i8)]);
/// assert_eq!(result, expected);
/// ```
pub fn checked_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
pub fn checked_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + CheckedAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_add(&b);

binary_checked(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -107,19 +96,14 @@ where
///
/// let a = PrimitiveArray::from([Some(100i8)]);
/// let b = PrimitiveArray::from([Some(100i8)]);
/// let result = saturating_add(&a, &b).unwrap();
/// let result = saturating_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(127)]);
/// assert_eq!(result, expected);
/// ```
pub fn saturating_add<T>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
pub fn saturating_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + SaturatingAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_add(&b);

binary(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -137,19 +121,17 @@ where
///
/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]);
/// let b = PrimitiveArray::from([Some(1i8), Some(100i8)]);
/// let (result, overflow) = overflowing_add(&a, &b).unwrap();
/// let (result, overflow) = overflowing_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(2i8), Some(-56i8)]);
/// assert_eq!(result, expected);
/// ```
pub fn overflowing_add<T>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
) -> Result<(PrimitiveArray<T>, Bitmap)>
) -> (PrimitiveArray<T>, Bitmap)
where
T: NativeType + OverflowingAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_add(&b);

binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -160,7 +142,7 @@ impl<T> ArrayAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + Add<Output = T>,
{
fn add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn add(&self, rhs: &PrimitiveArray<T>) -> Self {
add(self, rhs)
}
}
Expand All @@ -169,7 +151,7 @@ impl<T> ArrayWrappingAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + WrappingAdd<Output = T>,
{
fn wrapping_add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn wrapping_add(&self, rhs: &PrimitiveArray<T>) -> Self {
wrapping_add(self, rhs)
}
}
Expand All @@ -179,7 +161,7 @@ impl<T> ArrayCheckedAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + CheckedAdd<Output = T>,
{
fn checked_add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn checked_add(&self, rhs: &PrimitiveArray<T>) -> Self {
checked_add(self, rhs)
}
}
Expand All @@ -189,7 +171,7 @@ impl<T> ArraySaturatingAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + SaturatingAdd<Output = T>,
{
fn saturating_add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn saturating_add(&self, rhs: &PrimitiveArray<T>) -> Self {
saturating_add(self, rhs)
}
}
Expand All @@ -199,7 +181,7 @@ impl<T> ArrayOverflowingAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + OverflowingAdd<Output = T>,
{
fn overflowing_add(&self, rhs: &PrimitiveArray<T>) -> Result<(Self, Bitmap)> {
fn overflowing_add(&self, rhs: &PrimitiveArray<T>) -> (Self, Bitmap) {
overflowing_add(self, rhs)
}
}
Expand Down Expand Up @@ -323,8 +305,8 @@ impl<T> ArrayAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + Add<Output = T>,
{
fn add(&self, rhs: &T) -> Result<Self> {
Ok(add_scalar(self, rhs))
fn add(&self, rhs: &T) -> Self {
add_scalar(self, rhs)
}
}

Expand All @@ -333,8 +315,8 @@ impl<T> ArrayCheckedAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + CheckedAdd<Output = T> + Zero,
{
fn checked_add(&self, rhs: &T) -> Result<Self> {
Ok(checked_add_scalar(self, rhs))
fn checked_add(&self, rhs: &T) -> Self {
checked_add_scalar(self, rhs)
}
}

Expand All @@ -343,8 +325,8 @@ impl<T> ArraySaturatingAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + SaturatingAdd<Output = T>,
{
fn saturating_add(&self, rhs: &T) -> Result<Self> {
Ok(saturating_add_scalar(self, rhs))
fn saturating_add(&self, rhs: &T) -> Self {
saturating_add_scalar(self, rhs)
}
}

Expand All @@ -353,7 +335,7 @@ impl<T> ArrayOverflowingAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + OverflowingAdd<Output = T>,
{
fn overflowing_add(&self, rhs: &T) -> Result<(Self, Bitmap)> {
Ok(overflowing_add_scalar(self, rhs))
fn overflowing_add(&self, rhs: &T) -> (Self, Bitmap) {
overflowing_add_scalar(self, rhs)
}
}
Loading