diff --git a/benches/assign_ops.rs b/benches/assign_ops.rs index 35f29b643c2..dc4c9b182f3 100644 --- a/benches/assign_ops.rs +++ b/benches/assign_ops.rs @@ -14,7 +14,7 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function(&format!("apply_mul 2^{}", log2_size), |b| { b.iter(|| { criterion::black_box(&mut arr_a) - .apply_values(|x| x.iter_mut().for_each(|x| *x *= 1.01)); + .apply_values_mut(|x| x.iter_mut().for_each(|x| *x *= 1.01)); assert!(!arr_a.value(10).is_nan()); }) }); @@ -30,7 +30,7 @@ fn add_benchmark(c: &mut Criterion) { let mut arr_a = create_primitive_array::(size, 0.2); let mut arr_b = create_primitive_array_with_seed::(size, 0.2, 10); // convert to be close to 1.01 - arr_b.apply_values(|x| x.iter_mut().for_each(|x| *x = 1.01 + *x / 20.0)); + arr_b.apply_values_mut(|x| x.iter_mut().for_each(|x| *x = 1.01 + *x / 20.0)); c.bench_function(&format!("apply_mul null 2^{}", log2_size), |b| { b.iter(|| { diff --git a/examples/cow.rs b/examples/cow.rs index 4b76972a795..3255d365f4d 100644 --- a/examples/cow.rs +++ b/examples/cow.rs @@ -13,7 +13,7 @@ fn main() { .unwrap(); // 2. call `apply_values` with the function to apply over the values - array.apply_values(|x| x.iter_mut().for_each(|x| *x *= 10)); + array.apply_values_mut(|x| x.iter_mut().for_each(|x| *x *= 10)); // confirm that it gives the right result :) assert_eq!(array, &PrimitiveArray::from_vec(vec![10i32, 20])); diff --git a/src/array/boolean/mod.rs b/src/array/boolean/mod.rs index 6faef21186f..815ec6b1348 100644 --- a/src/array/boolean/mod.rs +++ b/src/array/boolean/mod.rs @@ -102,7 +102,7 @@ impl BooleanArray { /// if it is being shared (since it results in a `O(N)` memcopy). /// # Panics /// This function panics if the function modifies the length of the [`MutableBitmap`]. 
- pub fn apply_values(&mut self, f: F) { + pub fn apply_values_mut(&mut self, f: F) { let values = std::mem::take(&mut self.values); let mut values = values.make_mut(); f(&mut values); @@ -121,7 +121,7 @@ impl BooleanArray { /// if it is being shared (since it results in a `O(N)` memcopy). /// # Panics /// This function panics if the function modifies the length of the [`MutableBitmap`]. - pub fn apply_validity(&mut self, f: F) { + pub fn apply_validity_mut(&mut self, f: F) { if let Some(validity) = self.validity.as_mut() { let values = std::mem::take(validity); let mut bitmap = values.make_mut(); diff --git a/src/array/primitive/mod.rs b/src/array/primitive/mod.rs index 521dc971dc7..8443fca7d56 100644 --- a/src/array/primitive/mod.rs +++ b/src/array/primitive/mod.rs @@ -245,12 +245,41 @@ impl PrimitiveArray { /// This function panics iff `validity.len() != self.len()`. #[must_use] pub fn with_validity(&self, validity: Option) -> Self { + let mut out = self.clone(); + out.set_validity(validity); + out + } + + /// Update the validity buffer of this [`PrimitiveArray`]. + /// # Panics + /// This function panics iff `validity.len() != self.len()`. + pub fn set_validity(&mut self, validity: Option) { if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { panic!("validity should be as least as large as the array") } - let mut arr = self.clone(); - arr.validity = validity; - arr + self.validity = validity; + } + + /// Returns a clone of this [`PrimitiveArray`] with new values. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[must_use] + pub fn with_values(&self, values: Buffer) -> Self { + let mut out = self.clone(); + out.set_values(values); + out + } + + /// Update the values buffer of this [`PrimitiveArray`]. + /// # Panics + /// This function panics iff `values.len() != self.len()`. 
+ pub fn set_values(&mut self, values: Buffer) { + assert_eq!( + values.len(), + self.len(), + "values length should be equal to this arrays length" + ); + self.values = values; } /// Applies a function `f` to the values of this array, cloning the values @@ -260,10 +289,14 @@ impl PrimitiveArray { /// # Implementation /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` /// if it is being shared (since it results in a `O(N)` memcopy). - pub fn apply_values(&mut self, f: F) { + /// # Panics + /// This function panics, if `f` modifies the length of `&mut [T]` + pub fn apply_values_mut(&mut self, f: F) { let values = std::mem::take(&mut self.values); let mut values = values.make_mut(); + let len = values.len(); f(&mut values); + assert_eq!(values.len(), len, "values length must remain the same"); self.values = values.into(); } @@ -276,13 +309,29 @@ impl PrimitiveArray { /// if it is being shared (since it results in a `O(N)` memcopy). /// # Panics /// This function panics if the function modifies the length of the [`MutableBitmap`]. - pub fn apply_validity(&mut self, f: F) { + pub fn apply_validity_mut(&mut self, f: F) { + if let Some(validity) = self.validity.as_mut() { + let owned_validity = std::mem::take(validity); + let mut mut_bitmap = owned_validity.make_mut(); + f(&mut mut_bitmap); + assert_eq!(mut_bitmap.len(), self.values.len()); + *validity = mut_bitmap.into(); + } + } + + /// Applies a function `f` to the validity of this array, the caller can decide to make + /// it mutable or not. + /// + /// This is an API to leverage clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics if the function modifies the length of the [`MutableBitmap`]. 
+ pub fn apply_validity(&mut self, f: F) { if let Some(validity) = self.validity.as_mut() { - let values = std::mem::take(validity); - let mut bitmap = values.make_mut(); - f(&mut bitmap); - assert_eq!(bitmap.len(), self.values.len()); - *validity = bitmap.into(); + f(validity); + assert_eq!(validity.len(), self.values.len()); } } diff --git a/src/array/primitive/mutable.rs b/src/array/primitive/mutable.rs index d59df1b845a..ae983c10af0 100644 --- a/src/array/primitive/mutable.rs +++ b/src/array/primitive/mutable.rs @@ -81,6 +81,37 @@ impl MutablePrimitiveArray { pub fn into_data(self) -> (DataType, Vec, Option) { (self.data_type, self.values, self.validity) } + + /// Applies a function `f` to the values of this array, cloning the values + /// iff they are being shared with others + /// + /// This is an API to use clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics, if `f` modifies the length of `&mut [T]` + pub fn apply_values(&mut self, f: F) { + let len = self.values.len(); + f(&mut self.values); + assert_eq!(len, self.values.len(), "values length must remain the same") + } + + /// Applies a function `f` to the validity of this array, cloning it + /// iff it is being shared. + /// + /// This is an API to leverage clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics if the function modifies the length of the [`MutableBitmap`]. 
+ pub fn apply_validity(&mut self, f: F) { + if let Some(validity) = &mut self.validity { + f(validity); + assert_eq!(validity.len(), self.values.len()); + } + } } impl Default for MutablePrimitiveArray { diff --git a/src/compute/arity_assign.rs b/src/compute/arity_assign.rs index d10b100117d..b8746cdd81c 100644 --- a/src/compute/arity_assign.rs +++ b/src/compute/arity_assign.rs @@ -2,6 +2,7 @@ use super::utils::check_same_len; use crate::{array::PrimitiveArray, types::NativeType}; +use either::Either; /// Applies an unary function to a [`PrimitiveArray`] in-place via cow semantics. /// @@ -17,7 +18,7 @@ where I: NativeType, F: Fn(I) -> I, { - array.apply_values(|values| values.iter_mut().for_each(|v| *v = op(*v))); + array.apply_values_mut(|values| values.iter_mut().for_each(|v| *v = op(*v))); } /// Applies a binary operations to two [`PrimitiveArray`], applying the operation @@ -38,20 +39,68 @@ where { check_same_len(lhs, rhs).unwrap(); + // both for the validity and for the values + // we branch to check if we can mutate in place + // if we can, great that is fastest. + // if we cannot, we allocate a new buffer and assign values to that + // new buffer, that is benchmarked to be ~2x faster than first memcpy and assign in place + // for the validity bits it can be much faster as we might need to iterate all bits if the + // bitmap has an offset. 
match rhs.validity() { None => {} Some(rhs) => { if lhs.validity().is_none() { *lhs = lhs.with_validity(Some(rhs.clone())) } else { - lhs.apply_validity(|mut lhs| lhs &= rhs) + lhs.apply_validity(|bitmap| { + // we need to take ownership for the `into_mut` call, but leave the `&mut` lhs intact + // so that we can later assign the result to our `&mut bitmap` + let owned_lhs = std::mem::take(bitmap); + + match owned_lhs.into_mut() { + // we allocate and write to a new buffer + Either::Left(immutable) => { + // we allocate a new bitmap because that is a lot faster + // than doing the memcpy or the potential iteration of bits if + // we are dealing with an offset + let new = &immutable & rhs; + *bitmap = new; + } + // we can mutate in place, happy days. + Either::Right(mut mutable) => { + let mut mutable_ref = &mut mutable; + mutable_ref &= rhs; + *bitmap = mutable.into() + } + } + }); } } } + // we need to take ownership for the `into_mut` call, but leave the `&mut` lhs intact + // so that we can later assign the result to our `&mut lhs` + let owned_lhs = std::mem::take(lhs); - lhs.apply_values(|x| { - x.iter_mut() - .zip(rhs.values().iter()) - .for_each(|(l, r)| *l = op(*l, *r)) - }); + match owned_lhs.into_mut() { + // we allocate and write to a new buffer + Either::Left(mut immutable) => { + let values = immutable + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>(); + immutable.set_values(values.into()); + *lhs = immutable; + } + // we can mutate in place + Either::Right(mut mutable) => { + mutable.apply_values(|x| { + x.iter_mut() + .zip(rhs.values().iter()) + .for_each(|(l, r)| *l = op(*l, *r)) + }); + *lhs = mutable.into() + } + } } diff --git a/tests/it/array/boolean/mod.rs b/tests/it/array/boolean/mod.rs index 786fdb1dbc8..ffd5867bbb1 100644 --- a/tests/it/array/boolean/mod.rs +++ b/tests/it/array/boolean/mod.rs @@ -135,7 +135,7 @@ fn from_iter() { #[test] fn apply_values() { let mut a = BooleanArray::from([Some(true), 
Some(false), None]); - a.apply_values(|x| { + a.apply_values_mut(|x| { let mut a = std::mem::take(x); a = !a; *x = a; @@ -147,7 +147,7 @@ fn apply_values() { #[test] fn apply_validity() { let mut a = BooleanArray::from([Some(true), Some(false), None]); - a.apply_validity(|x| { + a.apply_validity_mut(|x| { let mut a = std::mem::take(x); a = !a; *x = a; diff --git a/tests/it/array/primitive/mod.rs b/tests/it/array/primitive/mod.rs index 2af71d504a5..bb0402f4006 100644 --- a/tests/it/array/primitive/mod.rs +++ b/tests/it/array/primitive/mod.rs @@ -128,7 +128,7 @@ fn into_mut_3() { #[test] fn apply_values() { let mut a = PrimitiveArray::from([Some(1), Some(2), None]); - a.apply_values(|x| { + a.apply_values_mut(|x| { x[0] = 10; }); let expected = PrimitiveArray::from([Some(10), Some(2), None]); @@ -138,7 +138,7 @@ fn apply_values() { #[test] fn apply_validity() { let mut a = PrimitiveArray::from([Some(1), Some(2), None]); - a.apply_validity(|x| { + a.apply_validity_mut(|x| { let mut a = std::mem::take(x); a = !a; *x = a;