Skip to content

Commit

Permalink
inplace arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 15, 2022
1 parent 55acdb6 commit 013fdbb
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 21 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ default = ["docs", "temporal", "private"]
lazy = ["sort_multiple"]

# ~40% faster collect, needed until trustedlength iter stabilizes
# more fast paths
# more fast paths, slower compilation
performant = []

# extra utilities for Utf8Chunked
Expand Down
111 changes: 99 additions & 12 deletions polars/polars-core/src/chunked_array/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
//! Implementations of arithmetic operations on ChunkedArray's.
use crate::prelude::*;
use crate::utils::align_chunks_binary;
use crate::utils::{align_chunks_binary, align_chunks_binary_owned};
use arrow::array::PrimitiveArray;
use arrow::{compute, compute::arithmetics::basic};
use arrow::{
compute,
compute::{arithmetics::basic, arity_assign},
};
use num::{Num, NumCast, ToPrimitive};
use std::borrow::Cow;
use std::ops::{Add, Div, Mul, Rem, Sub};
Expand Down Expand Up @@ -93,6 +96,52 @@ where
ca
}

/// Applies a binary arithmetic operation to two owned arrays, writing into their
/// existing buffers instead of allocating a new result.
///
/// This assigns to the owned buffer if the ref count is 1.
///
/// * Equal lengths: both sides are chunk-aligned and `kernel` is run on every pair of
///   chunks; `lhs` is returned.
/// * `rhs` has length 1: its single value is broadcast over `lhs` through `operation`.
/// * `lhs` has length 1: its single value is broadcast over `rhs` through `operation`.
///   The result lives in `rhs`, but is renamed to `lhs`' name so that every path
///   yields an array named after the left operand.
///
/// A null broadcast value produces an all-null array named after `lhs`.
///
/// # Panics
///
/// Panics when the lengths differ and neither side has length 1.
fn arithmetic_helper_owned<T, Kernel, F>(
    mut lhs: ChunkedArray<T>,
    mut rhs: ChunkedArray<T>,
    kernel: Kernel,
    operation: F,
) -> ChunkedArray<T>
where
    T: PolarsNumericType,
    Kernel: Fn(&mut PrimitiveArray<T::Native>, &mut PrimitiveArray<T::Native>),
    F: Fn(T::Native, T::Native) -> T::Native,
{
    match (lhs.len(), rhs.len()) {
        (a, b) if a == b => {
            let (mut lhs, mut rhs) = align_chunks_binary_owned(lhs, rhs);
            lhs.downcast_iter_mut()
                .zip(rhs.downcast_iter_mut())
                .for_each(|(lhs, rhs)| kernel(lhs, rhs));
            lhs
        }
        // broadcast right path
        (_, 1) => match rhs.get(0) {
            None => ChunkedArray::full_null(lhs.name(), lhs.len()),
            Some(scalar) => {
                lhs.apply_mut(|val| operation(val, scalar));
                lhs
            }
        },
        // broadcast left path
        (1, _) => match lhs.get(0) {
            None => ChunkedArray::full_null(lhs.name(), rhs.len()),
            Some(scalar) => {
                rhs.apply_mut(|val| operation(scalar, val));
                // Keep the result name consistent with the null branch above and with
                // the other arms, which all return an array named after `lhs`.
                rhs.rename(lhs.name());
                rhs
            }
        },
        _ => panic!("Cannot apply operation on arrays of different lengths"),
    }
}

// Operands on ChunkedArray & ChunkedArray

impl<T> Add for &ChunkedArray<T>
Expand Down Expand Up @@ -157,7 +206,12 @@ where
type Output = Self;

// Owned `ChunkedArray + ChunkedArray`.
// NOTE(review): this span is a rendered diff hunk. The first body line appears to be
// the removed implementation (delegating to the borrowed `Add`), and the
// `arithmetic_helper_owned` call that follows appears to be its replacement.
fn add(self, rhs: Self) -> Self::Output {
(&self).add(&rhs)
// First closure: chunk-wise kernel handed to `arithmetic_helper_owned`.
// Second closure: scalar operation used by the helper's broadcast arms.
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a + b),
|lhs, rhs| lhs + rhs,
)
}
}

Expand All @@ -168,7 +222,12 @@ where
type Output = Self;

// Owned `ChunkedArray / ChunkedArray`.
// NOTE(review): this span is a rendered diff hunk. The first body line appears to be
// the removed implementation (delegating to the borrowed `Div`), and the
// `arithmetic_helper_owned` call that follows appears to be its replacement.
fn div(self, rhs: Self) -> Self::Output {
(&self).div(&rhs)
// First closure: chunk-wise kernel handed to `arithmetic_helper_owned`.
// Second closure: scalar operation used by the helper's broadcast arms.
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a / b),
|lhs, rhs| lhs / rhs,
)
}
}

Expand All @@ -179,7 +238,12 @@ where
type Output = Self;

// Owned `ChunkedArray * ChunkedArray`.
// NOTE(review): this span is a rendered diff hunk. The first body line appears to be
// the removed implementation (delegating to the borrowed `Mul`), and the
// `arithmetic_helper_owned` call that follows appears to be its replacement.
fn mul(self, rhs: Self) -> Self::Output {
(&self).mul(&rhs)
// First closure: chunk-wise kernel handed to `arithmetic_helper_owned`.
// Second closure: scalar operation used by the helper's broadcast arms.
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a * b),
|lhs, rhs| lhs * rhs,
)
}
}

Expand All @@ -190,7 +254,12 @@ where
type Output = Self;

// Owned `ChunkedArray - ChunkedArray`.
// NOTE(review): this span is a rendered diff hunk. The first body line appears to be
// the removed implementation (delegating to the borrowed `Sub`), and the
// `arithmetic_helper_owned` call that follows appears to be its replacement.
fn sub(self, rhs: Self) -> Self::Output {
(&self).sub(&rhs)
// First closure: chunk-wise kernel handed to `arithmetic_helper_owned`.
// Second closure: scalar operation used by the helper's broadcast arms.
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a - b),
|lhs, rhs| lhs - rhs,
)
}
}

Expand Down Expand Up @@ -279,8 +348,14 @@ where
{
type Output = ChunkedArray<T>;

// Owned `ChunkedArray + scalar`.
// NOTE(review): this span is a rendered diff hunk. The first signature and body line
// appear to be the removed version, and the `mut self` signature that follows appears
// to be its replacement.
fn add(self, rhs: N) -> Self::Output {
(&self).add(rhs)
fn add(mut self, rhs: N) -> Self::Output {
// NOTE(review): the in-place path is taken only when the `ASSIGN` environment variable
// is set, and the variable is looked up on every call. This looks like a temporary
// debugging/benchmark toggle; confirm before relying on it.
if std::env::var("ASSIGN").is_ok() {
// Panics if `rhs` cannot be represented as `T::Native`.
let adder: T::Native = NumCast::from(rhs).unwrap();
self.apply_mut(|val| val + adder);
self
} else {
// Fallback: delegate to the borrowed implementation.
(&self).add(rhs)
}
}
}

Expand All @@ -291,8 +366,14 @@ where
{
type Output = ChunkedArray<T>;

// Owned `ChunkedArray - scalar`.
// NOTE(review): this span is a rendered diff hunk. The first signature and body line
// appear to be the removed version, and the `mut self` signature that follows appears
// to be its replacement.
fn sub(self, rhs: N) -> Self::Output {
(&self).sub(rhs)
fn sub(mut self, rhs: N) -> Self::Output {
// NOTE(review): the in-place path is taken only when the `ASSIGN` environment variable
// is set, and the variable is looked up on every call. This looks like a temporary
// debugging/benchmark toggle; confirm before relying on it.
if std::env::var("ASSIGN").is_ok() {
// Panics if `rhs` cannot be represented as `T::Native`.
let subber: T::Native = NumCast::from(rhs).unwrap();
self.apply_mut(|val| val - subber);
self
} else {
// Fallback: delegate to the borrowed implementation.
(&self).sub(rhs)
}
}
}

Expand All @@ -315,8 +396,14 @@ where
{
type Output = ChunkedArray<T>;

// Owned `ChunkedArray * scalar`.
// NOTE(review): this span is a rendered diff hunk. The first signature and body line
// appear to be the removed version, and the `mut self` signature that follows appears
// to be its replacement.
fn mul(self, rhs: N) -> Self::Output {
(&self).mul(rhs)
fn mul(mut self, rhs: N) -> Self::Output {
// NOTE(review): the in-place path is taken only when the `ASSIGN` environment variable
// is set, and the variable is looked up on every call. This looks like a temporary
// debugging/benchmark toggle; confirm before relying on it.
if std::env::var("ASSIGN").is_ok() {
// Panics if `rhs` cannot be represented as `T::Native`.
let multiplier: T::Native = NumCast::from(rhs).unwrap();
self.apply_mut(|val| val * multiplier);
self
} else {
// Fallback: delegate to the borrowed implementation.
(&self).mul(rhs)
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions polars/polars-core/src/chunked_array/ops/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ impl<T: PolarsNumericType> ChunkedArray<T> {
}
}

impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
{
    /// Runs `f` over every chunk of this array through arrow's
    /// `arity_assign::unary`, storing the results back into `self`.
    pub(crate) fn apply_mut<F>(&mut self, f: F)
    where
        F: Fn(T::Native) -> T::Native + Copy,
    {
        // `f` is `Copy`, so each chunk receives its own copy of the function.
        for chunk in self.downcast_iter_mut() {
            arrow::compute::arity_assign::unary(chunk, f);
        }
    }
}

impl<'a, T> ChunkApply<'a, T::Native, T::Native> for ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down
12 changes: 12 additions & 0 deletions polars/polars-core/src/chunked_array/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ where
unsafe { &*(arr as *const dyn Array as *const PrimitiveArray<T::Native>) }
})
}

/// Mutable counterpart of the chunk downcast iterator: yields every chunk as a
/// `&mut PrimitiveArray<T::Native>`.
pub(crate) fn downcast_iter_mut(
    &mut self,
) -> impl Iterator<Item = &mut PrimitiveArray<T::Native>> + DoubleEndedIterator {
    self.chunks.iter_mut().map(|chunk| {
        // Reborrow the chunk's contents as a raw trait-object pointer first.
        let erased = &mut **chunk as *mut dyn Array;
        // Safety:
        // This should be the array type in PolarsNumericType
        unsafe { &mut *(erased as *mut PrimitiveArray<T::Native>) }
    })
}

pub fn downcast_chunks(&self) -> Chunks<'_, PrimitiveArray<T::Native>> {
Chunks::new(&self.chunks)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
use crate::prelude::*;
use crate::utils::{get_supertype, get_time_units};
use num::{Num, NumCast};
use std::borrow::Cow;
use std::fmt::Debug;
use std::ops;
use super::*;

pub trait NumOpsDispatch: Debug {
fn subtract(&self, rhs: &Series) -> Result<Series> {
Expand Down
11 changes: 11 additions & 0 deletions polars/polars-core/src/series/arithmetic/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
mod borrowed;
mod owned;

use crate::prelude::*;
use crate::utils::{get_supertype, get_time_units};
use num::{Num, NumCast};
use std::borrow::Cow;
use std::fmt::Debug;
use std::ops::{self, Add, Mul, Sub};

pub use borrowed::*;
82 changes: 82 additions & 0 deletions polars/polars-core/src/series/arithmetic/owned.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use super::*;
#[cfg(feature = "performant")]
use crate::utils::align_chunks_binary_owned_series;

/// Brings two owned `Series` to their common supertype.
///
/// A side whose dtype already equals the supertype is passed through as-is; the
/// other side is cast. Returns the error from `get_supertype` or from `cast`
/// when either fails.
#[cfg(feature = "performant")]
pub fn coerce_lhs_rhs_owned(lhs: Series, rhs: Series) -> Result<(Series, Series)> {
    let dtype = get_supertype(lhs.dtype(), rhs.dtype())?;

    // Cast only when the dtype actually differs from the supertype.
    let to_supertype = |s: Series| -> Result<Series> {
        if s.dtype() == &dtype {
            Ok(s)
        } else {
            s.cast(&dtype)
        }
    };

    let left = to_supertype(lhs)?;
    let right = to_supertype(rhs)?;
    Ok((left, right))
}

/// Moves the typed `ChunkedArray<T>` out of both series, combines them with `op`
/// and wraps the result in a new `Series`.
///
/// The arrays are taken with `std::mem::take`, which leaves a default array behind
/// in `lhs` and `rhs`; both series are dropped when this function returns.
#[cfg(feature = "performant")]
fn apply_operation_mut<T, F>(mut lhs: Series, mut rhs: Series, op: F) -> Series
where
    T: PolarsNumericType,
    F: Fn(ChunkedArray<T>, ChunkedArray<T>) -> ChunkedArray<T> + Copy,
    ChunkedArray<T>: IntoSeries,
{
    let left = std::mem::take::<ChunkedArray<T>>(lhs._get_inner_mut().as_mut());
    let right = std::mem::take::<ChunkedArray<T>>(rhs._get_inner_mut().as_mut());

    op(left, right).into_series()
}

/// Implements an owned arithmetic operator (`Series <op> Series`) for `Series`.
///
/// * `$operation` - the operator trait (`Add`, `Sub`, `Mul`).
/// * `$method` - the trait method (`add`, `sub`, `mul`).
/// * `$function` - closure applied to the two typed `ChunkedArray`s on the
///   in-place path.
///
/// With the `performant` feature enabled, operands that are both non-logical,
/// physically numeric series are coerced to their supertype, chunk-aligned and
/// dispatched on dtype to `apply_operation_mut`. Every other combination, and
/// every build without `performant`, delegates to the borrowed operator.
macro_rules! impl_operation {
    ($operation:ident, $method:ident, $function:expr) => {
        impl $operation for Series {
            type Output = Series;

            fn $method(self, rhs: Self) -> Self::Output {
                #[cfg(feature = "performant")]
                {
                    // Only physical numeric values take the mutable path. Both operands
                    // must qualify: the dtype dispatch below runs on the coerced
                    // supertype, so a logical or non-numeric `rhs` could otherwise end
                    // up in the `unreachable!()` arm.
                    if !self.is_logical()
                        && self.is_numeric_physical()
                        && !rhs.is_logical()
                        && rhs.is_numeric_physical()
                    {
                        let (lhs, rhs) = coerce_lhs_rhs_owned(self, rhs).unwrap();
                        let (lhs, rhs) = align_chunks_binary_owned_series(lhs, rhs);
                        use DataType::*;
                        match lhs.dtype() {
                            #[cfg(feature = "dtype-i8")]
                            Int8 => apply_operation_mut::<Int8Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-i16")]
                            Int16 => apply_operation_mut::<Int16Type, _>(lhs, rhs, $function),
                            Int32 => apply_operation_mut::<Int32Type, _>(lhs, rhs, $function),
                            Int64 => apply_operation_mut::<Int64Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-u8")]
                            UInt8 => apply_operation_mut::<UInt8Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-u16")]
                            UInt16 => apply_operation_mut::<UInt16Type, _>(lhs, rhs, $function),
                            UInt32 => apply_operation_mut::<UInt32Type, _>(lhs, rhs, $function),
                            UInt64 => apply_operation_mut::<UInt64Type, _>(lhs, rhs, $function),
                            Float32 => apply_operation_mut::<Float32Type, _>(lhs, rhs, $function),
                            Float64 => apply_operation_mut::<Float64Type, _>(lhs, rhs, $function),
                            _ => unreachable!(),
                        }
                    } else {
                        (&self).$method(&rhs)
                    }
                }
                #[cfg(not(feature = "performant"))]
                {
                    (&self).$method(&rhs)
                }
            }
        }
    };
}

impl_operation!(Add, add, |a, b| a.add(b));
impl_operation!(Sub, sub, |a, b| a.sub(b));
impl_operation!(Mul, mul, |a, b| a.mul(b));
28 changes: 28 additions & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,34 @@ where
}
}

/// Prepares two owned `Series` for chunk-wise binary kernels.
///
/// A side that consists of exactly one chunk is returned untouched; any other
/// side is replaced by the result of `rechunk()`.
#[cfg(feature = "performant")]
pub(crate) fn align_chunks_binary_owned_series(left: Series, right: Series) -> (Series, Series) {
    let left = if left.chunks().len() == 1 {
        left
    } else {
        left.rechunk()
    };
    let right = if right.chunks().len() == 1 {
        right
    } else {
        right.rechunk()
    };
    (left, right)
}

/// Prepares two owned `ChunkedArray`s for chunk-wise binary kernels.
///
/// A side that consists of exactly one chunk is returned untouched; any other
/// side is replaced by the result of `rechunk()`.
pub(crate) fn align_chunks_binary_owned<T, B>(
    left: ChunkedArray<T>,
    right: ChunkedArray<B>,
) -> (ChunkedArray<T>, ChunkedArray<B>)
where
    ChunkedArray<B>: ChunkOps,
    ChunkedArray<T>: ChunkOps,
    B: PolarsDataType,
    T: PolarsDataType,
{
    let left = if left.chunks.len() == 1 {
        left
    } else {
        left.rechunk()
    };
    let right = if right.chunks.len() == 1 {
        right
    } else {
        right.rechunk()
    };
    (left, right)
}

#[allow(clippy::type_complexity)]
pub(crate) fn align_chunks_ternary<'a, A, B, C>(
a: &'a ChunkedArray<A>,
Expand Down
1 change: 1 addition & 0 deletions polars/polars-io/src/ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ impl<R: Read + Seek> IpcReader<R> {
}

// todo! hoist to lazy crate
#[cfg(feature = "lazy")]
pub fn finish_with_scan_ops(
mut self,
predicate: Option<Arc<dyn PhysicalIoExpr>>,
Expand Down
Loading

0 comments on commit 013fdbb

Please sign in to comment.