Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improve arity_assign: ~2x improvement if we share data. #1076

Merged
merged 8 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benches/assign_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function(&format!("apply_mul 2^{}", log2_size), |b| {
b.iter(|| {
criterion::black_box(&mut arr_a)
.apply_values(|x| x.iter_mut().for_each(|x| *x *= 1.01));
.apply_values_mut(|x| x.iter_mut().for_each(|x| *x *= 1.01));
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
assert!(!arr_a.value(10).is_nan());
})
});
Expand All @@ -30,7 +30,7 @@ fn add_benchmark(c: &mut Criterion) {
let mut arr_a = create_primitive_array::<f32>(size, 0.2);
let mut arr_b = create_primitive_array_with_seed::<f32>(size, 0.2, 10);
// convert to be close to 1.01
arr_b.apply_values(|x| x.iter_mut().for_each(|x| *x = 1.01 + *x / 20.0));
arr_b.apply_values_mut(|x| x.iter_mut().for_each(|x| *x = 1.01 + *x / 20.0));

c.bench_function(&format!("apply_mul null 2^{}", log2_size), |b| {
b.iter(|| {
Expand Down
2 changes: 1 addition & 1 deletion examples/cow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {
.unwrap();

// 2. call `apply_values_mut` with the function to apply over the values
array.apply_values(|x| x.iter_mut().for_each(|x| *x *= 10));
array.apply_values_mut(|x| x.iter_mut().for_each(|x| *x *= 10));

// confirm that it gives the right result :)
assert_eq!(array, &PrimitiveArray::from_vec(vec![10i32, 20]));
Expand Down
4 changes: 2 additions & 2 deletions src/array/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl BooleanArray {
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_values<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
pub fn apply_values_mut<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
let values = std::mem::take(&mut self.values);
let mut values = values.make_mut();
f(&mut values);
Expand All @@ -121,7 +121,7 @@ impl BooleanArray {
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_validity<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
pub fn apply_validity_mut<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
if let Some(validity) = self.validity.as_mut() {
let values = std::mem::take(validity);
let mut bitmap = values.make_mut();
Expand Down
75 changes: 65 additions & 10 deletions src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,41 @@ impl<T: NativeType> PrimitiveArray<T> {
/// This function panics iff `validity.len() != self.len()`.
#[must_use]
pub fn with_validity(&self, validity: Option<Bitmap>) -> Self {
let mut out = self.clone();
out.set_validity(validity);
out
}

/// Update the validity buffer of this [`PrimitiveArray`].
/// # Panics
/// This function panics iff `validity.len() != self.len()`.
pub fn set_validity(&mut self, validity: Option<Bitmap>) {
if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) {
panic!("validity should be as least as large as the array")
}
let mut arr = self.clone();
arr.validity = validity;
arr
self.validity = validity;
}

/// Returns a clone of this [`PrimitiveArray`] whose values buffer is replaced by `values`.
/// # Panics
/// This function panics iff `values.len() != self.len()`.
#[must_use]
pub fn with_values(&self, values: Buffer<T>) -> Self {
    // Clone first, then swap the buffer on the clone so `self` stays untouched.
    let mut cloned = self.clone();
    cloned.set_values(values);
    cloned
}

/// Replaces the values buffer of this [`PrimitiveArray`] in place.
/// # Panics
/// This function panics iff `values.len() != self.len()`.
pub fn set_values(&mut self, values: Buffer<T>) {
    let new_len = values.len();
    let current_len = self.len();
    // The replacement buffer must hold exactly as many items as the array.
    assert_eq!(
        new_len, current_len,
        "values length should be equal to this arrays length"
    );
    self.values = values;
}

/// Applies a function `f` to the values of this array, cloning the values
Expand All @@ -260,10 +289,14 @@ impl<T: NativeType> PrimitiveArray<T> {
/// # Implementation
/// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)`
/// if it is being shared (since it results in a `O(N)` memcopy).
pub fn apply_values<F: Fn(&mut [T])>(&mut self, f: F) {
/// # Panics
/// This function panics, if `f` modifies the length of `&mut [T]`
pub fn apply_values_mut<F: Fn(&mut [T])>(&mut self, f: F) {
let values = std::mem::take(&mut self.values);
let mut values = values.make_mut();
let len = values.len();
f(&mut values);
assert_eq!(values.len(), len, "values length must remain the same");
self.values = values.into();
}

Expand All @@ -276,13 +309,29 @@ impl<T: NativeType> PrimitiveArray<T> {
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_validity<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
pub fn apply_validity_mut<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
if let Some(validity) = self.validity.as_mut() {
let owned_validity = std::mem::take(validity);
let mut mut_bitmap = owned_validity.make_mut();
f(&mut mut_bitmap);
assert_eq!(mut_bitmap.len(), self.values.len());
*validity = mut_bitmap.into();
}
}

/// Applies a function `f` to the validity of this array, the caller can decide to make
/// it mutable or not.
///
/// This is an API to leverage clone-on-write
/// # Implementation
/// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)`
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`Bitmap`].
pub fn apply_validity<F: Fn(&mut Bitmap)>(&mut self, f: F) {
if let Some(validity) = self.validity.as_mut() {
let values = std::mem::take(validity);
let mut bitmap = values.make_mut();
f(&mut bitmap);
assert_eq!(bitmap.len(), self.values.len());
*validity = bitmap.into();
f(validity);
assert_eq!(validity.len(), self.values.len());
}
}

Expand Down Expand Up @@ -511,3 +560,9 @@ pub type UInt16Vec = MutablePrimitiveArray<u16>;
pub type UInt32Vec = MutablePrimitiveArray<u32>;
/// A type definition [`MutablePrimitiveArray`] for `u64`
pub type UInt64Vec = MutablePrimitiveArray<u64>;

/// Default [`PrimitiveArray`]: built from default-constructed values and no validity.
impl<T: NativeType> Default for PrimitiveArray<T> {
    /// Calls `new` with `T::PRIMITIVE` converted via `into()`, a default values
    /// argument and `None` as the validity.
    fn default() -> Self {
        Self::new(T::PRIMITIVE.into(), Default::default(), None)
    }
}
31 changes: 31 additions & 0 deletions src/array/primitive/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,37 @@ impl<T: NativeType> MutablePrimitiveArray<T> {
pub fn into_data(self) -> (DataType, Vec<T>, Option<MutableBitmap>) {
(self.data_type, self.values, self.validity)
}

/// Applies a function `f` to the values of this array, mutating them in place.
///
/// The values are owned by this [`MutablePrimitiveArray`], so `f` receives them
/// directly as a mutable slice and no copy is made.
/// # Implementation
/// This function is `O(f)`.
/// # Panics
/// This function panics if the length of the values differs after `f` returns.
/// Since `f` only receives a `&mut [T]` it cannot resize the values, so this
/// check is purely defensive.
pub fn apply_values<F: Fn(&mut [T])>(&mut self, f: F) {
    let expected_len = self.values.len();
    f(self.values.as_mut_slice());
    assert_eq!(
        expected_len,
        self.values.len(),
        "values length must remain the same"
    );
}

/// Applies a function `f` to the validity of this array, if it has one.
///
/// The validity is a [`MutableBitmap`] held by this array, so `f` mutates it in place;
/// nothing happens when the array has no validity.
/// # Implementation
/// This function is `O(f)`.
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_validity<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
if let Some(validity) = &mut self.validity {
f(validity);
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(validity.len(), self.values.len());
}
}
}

impl<T: NativeType> Default for MutablePrimitiveArray<T> {
Expand Down
63 changes: 56 additions & 7 deletions src/compute/arity_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use super::utils::check_same_len;
use crate::{array::PrimitiveArray, types::NativeType};
use either::Either;

/// Applies an unary function to a [`PrimitiveArray`] in-place via cow semantics.
///
Expand All @@ -17,7 +18,7 @@ where
I: NativeType,
F: Fn(I) -> I,
{
array.apply_values(|values| values.iter_mut().for_each(|v| *v = op(*v)));
array.apply_values_mut(|values| values.iter_mut().for_each(|v| *v = op(*v)));
}

/// Applies a binary operations to two [`PrimitiveArray`], applying the operation
Expand All @@ -38,20 +39,68 @@ where
{
check_same_len(lhs, rhs).unwrap();

// both for the validity and for the values
// we branch to check if we can mutate in place.
// if we can, great, that is fastest.
// if we cannot, we allocate a new buffer and assign values to that
// new buffer; that is benchmarked to be ~2x faster than first doing a memcpy and then assigning in place.
// for the validity bits it can be much faster, as we might need to iterate all bits if the
// bitmap has an offset.
match rhs.validity() {
None => {}
Some(rhs) => {
if lhs.validity().is_none() {
*lhs = lhs.with_validity(Some(rhs.clone()))
} else {
lhs.apply_validity(|mut lhs| lhs &= rhs)
lhs.apply_validity(|bitmap| {
// we need to take ownership for the `into_mut` call, but leave the `&mut` lhs intact
// so that we can later assign the result to our `&mut bitmap`
let owned_lhs = std::mem::take(bitmap);

match owned_lhs.into_mut() {
// we cannot mutate in place: allocate and write to a new buffer
Either::Left(immutable) => {
// we allocate a new bitmap because that is a lot faster
// than doing the memcpy or the potential iteration of bits if
// we are dealing with an offset
let new = &immutable & rhs;
*bitmap = new;
}
// we can mutate in place, happy days.
Either::Right(mut mutable) => {
let mut mutable_ref = &mut mutable;
mutable_ref &= rhs;
*bitmap = mutable.into()
}
}
});
}
}
}
// we need to take ownership for the `into_mut` call, but leave the `&mut` lhs intact
// so that we can later assign the result to our `&mut lhs`
let owned_lhs = std::mem::take(lhs);

lhs.apply_values(|x| {
x.iter_mut()
.zip(rhs.values().iter())
.for_each(|(l, r)| *l = op(*l, *r))
});
match owned_lhs.into_mut() {
// we cannot mutate in place: allocate and write to a new buffer
Either::Left(mut immutable) => {
let values = immutable
.values()
.iter()
.zip(rhs.values().iter())
.map(|(l, r)| op(*l, *r))
.collect::<Vec<_>>();
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
immutable.set_values(values.into());
*lhs = immutable;
}
// we can mutate in place
Either::Right(mut mutable) => {
mutable.apply_values(|x| {
x.iter_mut()
.zip(rhs.values().iter())
.for_each(|(l, r)| *l = op(*l, *r))
});
*lhs = mutable.into()
}
}
}
4 changes: 2 additions & 2 deletions tests/it/array/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ fn from_iter() {
#[test]
fn apply_values() {
let mut a = BooleanArray::from([Some(true), Some(false), None]);
a.apply_values(|x| {
a.apply_values_mut(|x| {
let mut a = std::mem::take(x);
a = !a;
*x = a;
Expand All @@ -147,7 +147,7 @@ fn apply_values() {
#[test]
fn apply_validity() {
let mut a = BooleanArray::from([Some(true), Some(false), None]);
a.apply_validity(|x| {
a.apply_validity_mut(|x| {
let mut a = std::mem::take(x);
a = !a;
*x = a;
Expand Down
4 changes: 2 additions & 2 deletions tests/it/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn into_mut_3() {
#[test]
fn apply_values() {
let mut a = PrimitiveArray::from([Some(1), Some(2), None]);
a.apply_values(|x| {
a.apply_values_mut(|x| {
x[0] = 10;
});
let expected = PrimitiveArray::from([Some(10), Some(2), None]);
Expand All @@ -138,7 +138,7 @@ fn apply_values() {
#[test]
fn apply_validity() {
let mut a = PrimitiveArray::from([Some(1), Some(2), None]);
a.apply_validity(|x| {
a.apply_validity_mut(|x| {
let mut a = std::mem::take(x);
a = !a;
*x = a;
Expand Down