From bace8d4bf8079940c7c5b03f474c734afb0bb1f9 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 24 May 2021 14:57:58 +0200 Subject: [PATCH 1/5] fix ub in filter record_batch --- arrow/src/compute/kernels/filter.rs | 41 +++++++++++++++++++---------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 4da07b89edde..fc541ec6a00d 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -205,6 +205,21 @@ pub fn build_filter(filter: &BooleanArray) -> Result { })) } +fn prepare_filter(filter: &BooleanArray) -> BooleanArray { + let array_data = filter.data_ref(); + let null_bitmap = array_data.null_buffer().unwrap(); + let mask = filter.values(); + let offset = filter.offset(); + + let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len()); + + let array_data = ArrayData::builder(DataType::Boolean) + .len(filter.len()) + .add_buffer(new_mask) + .build(); + BooleanArray::from(array_data) +} + /// Filters an [Array], returning elements matching the filter (i.e. where the values are true). 
/// /// # Example @@ -225,18 +240,7 @@ pub fn filter(array: &Array, filter: &BooleanArray) -> Result { if filter.null_count() > 0 { // this greatly simplifies subsequent filtering code // now we only have a boolean mask to deal with - let array_data = filter.data_ref(); - let null_bitmap = array_data.null_buffer().unwrap(); - let mask = filter.values(); - let offset = filter.offset(); - - let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len()); - - let array_data = ArrayData::builder(DataType::Boolean) - .len(filter.len()) - .add_buffer(new_mask) - .build(); - let filter = BooleanArray::from(array_data); + let filter = prepare_filter(filter); // fully qualified syntax, because we have an argument with the same name return crate::compute::kernels::filter::filter(array, &filter); } @@ -251,12 +255,21 @@ pub fn filter(array: &Array, filter: &BooleanArray) -> Result { } /// Returns a new [RecordBatch] with arrays containing only values matching the filter. -/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered. -/// Therefore, it is considered undefined behavior to pass `filter` with null values. 
pub fn filter_record_batch( record_batch: &RecordBatch, filter: &BooleanArray, ) -> Result { + if filter.null_count() > 0 { + // this greatly simplifies subsequent filtering code + // now we only have a boolean mask to deal with + let filter = prepare_filter(filter); + // fully qualified syntax, because we have an argument with the same name + return crate::compute::kernels::filter::filter_record_batch( + record_batch, + &filter, + ); + } + let filter = build_filter(filter)?; let filtered_arrays = record_batch .columns() From a18608bd2a50b4f00eaca801d6c1374d2b02828f Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 24 May 2021 15:10:57 +0200 Subject: [PATCH 2/5] filter fast path --- arrow/src/compute/kernels/filter.rs | 41 ++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index fc541ec6a00d..f8e68b52db0e 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -197,6 +197,10 @@ pub fn build_filter(filter: &BooleanArray) -> Result { let chunks = iter.collect::>(); Ok(Box::new(move |array: &ArrayData| { + if filter_count == array.len() { + return array.clone(); + } + let mut mutable = MutableArrayData::new(vec![array], false, filter_count); chunks .iter() @@ -205,7 +209,8 @@ pub fn build_filter(filter: &BooleanArray) -> Result { })) } -fn prepare_filter(filter: &BooleanArray) -> BooleanArray { +/// Remove null values by doing a bitmask AND operation with null bits and the boolean bits. 
+fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { let array_data = filter.data_ref(); let null_bitmap = array_data.null_buffer().unwrap(); let mask = filter.values(); @@ -240,18 +245,22 @@ pub fn filter(array: &Array, filter: &BooleanArray) -> Result { if filter.null_count() > 0 { // this greatly simplifies subsequent filtering code // now we only have a boolean mask to deal with - let filter = prepare_filter(filter); + let filter = prep_null_mask_filter(filter); // fully qualified syntax, because we have an argument with the same name return crate::compute::kernels::filter::filter(array, &filter); } let iter = SlicesIterator::new(filter); - - let mut mutable = - MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count); - iter.for_each(|(start, end)| mutable.extend(0, start, end)); - let data = mutable.freeze(); - Ok(make_array(data)) + if iter.filter_count == array.len() { + let data = array.data().clone(); + Ok(make_array(data)) + } else { + let mut mutable = + MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count); + iter.for_each(|(start, end)| mutable.extend(0, start, end)); + let data = mutable.freeze(); + Ok(make_array(data)) + } } /// Returns a new [RecordBatch] with arrays containing only values matching the filter. 
@@ -262,7 +271,7 @@ pub fn filter_record_batch( if filter.null_count() > 0 { // this greatly simplifies subsequent filtering code // now we only have a boolean mask to deal with - let filter = prepare_filter(filter); + let filter = prep_null_mask_filter(filter); // fully qualified syntax, because we have an argument with the same name return crate::compute::kernels::filter::filter_record_batch( record_batch, @@ -638,4 +647,18 @@ mod tests { assert_eq!(out_arr0, out_arr1); Ok(()) } + + #[test] + fn test_fast_path() -> Result<()> { + let a: PrimitiveArray = + PrimitiveArray::from(vec![Some(1), Some(2), None]); + let mask = BooleanArray::from(vec![true, true, true]); + let out = filter(&a, &mask)?; + let b = out + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(&a, b); + Ok(()) + } } From c011cff3cb277a736c7ac2c145afacc7d8be9923 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 25 May 2021 09:28:27 +0200 Subject: [PATCH 3/5] add all false fast path --- arrow/src/array/data.rs | 2 +- arrow/src/compute/kernels/filter.rs | 55 ++++++++++++++++++++--------- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 9d5b0ee023db..172bdaac9eb6 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -412,7 +412,7 @@ impl ArrayData { } /// Returns a new empty [ArrayData] valid for `data_type`. 
- pub(super) fn new_empty(data_type: &DataType) -> Self { + pub fn new_empty(data_type: &DataType) -> Self { let buffers = new_buffers(data_type, 0); let [buffer1, buffer2] = buffers; let buffers = into_buffers(data_type, buffer1, buffer2); diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index f8e68b52db0e..6701c4644367 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -197,15 +197,18 @@ pub fn build_filter(filter: &BooleanArray) -> Result { let chunks = iter.collect::>(); Ok(Box::new(move |array: &ArrayData| { - if filter_count == array.len() { - return array.clone(); + match filter_count { + // return all + len if len == array.len() => array.clone(), + 0 => ArrayData::new_empty(array.data_type()), + _ => { + let mut mutable = MutableArrayData::new(vec![array], false, filter_count); + chunks + .iter() + .for_each(|(start, end)| mutable.extend(0, *start, *end)); + mutable.freeze() + } } - - let mut mutable = MutableArrayData::new(vec![array], false, filter_count); - chunks - .iter() - .for_each(|(start, end)| mutable.extend(0, *start, *end)); - mutable.freeze() })) } @@ -251,15 +254,25 @@ pub fn filter(array: &Array, filter: &BooleanArray) -> Result { } let iter = SlicesIterator::new(filter); - if iter.filter_count == array.len() { - let data = array.data().clone(); - Ok(make_array(data)) - } else { - let mut mutable = - MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count); - iter.for_each(|(start, end)| mutable.extend(0, start, end)); - let data = mutable.freeze(); - Ok(make_array(data)) + match iter.filter_count { + 0 => { + // return empty + let data = ArrayData::new_empty(array.data_type()); + Ok(make_array(data)) + } + len if len == array.len() => { + // return all + let data = array.data().clone(); + Ok(make_array(data)) + } + _ => { + // actually filter + let mut mutable = + MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count); + 
iter.for_each(|(start, end)| mutable.extend(0, start, end)); + let data = mutable.freeze(); + Ok(make_array(data)) + } } } @@ -652,6 +665,8 @@ mod tests { fn test_fast_path() -> Result<()> { let a: PrimitiveArray = PrimitiveArray::from(vec![Some(1), Some(2), None]); + + // all true let mask = BooleanArray::from(vec![true, true, true]); let out = filter(&a, &mask)?; let b = out @@ -659,6 +674,12 @@ mod tests { .downcast_ref::>() .unwrap(); assert_eq!(&a, b); + + // all false + let mask = BooleanArray::from(vec![false, false, false]); + let out = filter(&a, &mask)?; + assert_eq!(out.len(), 0); + assert_eq!(out.data_type(), &DataType::Int64); Ok(()) } } From 78cb9668d43792ddd1cd2dcef25121024def7e29 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 26 May 2021 08:00:04 +0200 Subject: [PATCH 4/5] use new_empty_array --- arrow/src/compute/kernels/filter.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 6701c4644367..46eb05b68590 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -257,8 +257,7 @@ pub fn filter(array: &Array, filter: &BooleanArray) -> Result { match iter.filter_count { 0 => { // return empty - let data = ArrayData::new_empty(array.data_type()); - Ok(make_array(data)) + Ok(new_empty_array(array.data_type())) } len if len == array.len() => { // return all From e88a3b17f897e0aa883a25499f952f7346fd34ab Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 26 May 2021 08:34:30 +0200 Subject: [PATCH 5/5] rename filter kernel argument rename argument: 'filter' to 'predicate' to reduce name collissions. 
--- arrow/src/compute/kernels/filter.rs | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 46eb05b68590..b15692e90f2f 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -244,16 +244,15 @@ fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { /// # Ok(()) /// # } /// ``` -pub fn filter(array: &Array, filter: &BooleanArray) -> Result { - if filter.null_count() > 0 { +pub fn filter(array: &Array, predicate: &BooleanArray) -> Result { + if predicate.null_count() > 0 { // this greatly simplifies subsequent filtering code // now we only have a boolean mask to deal with - let filter = prep_null_mask_filter(filter); - // fully qualified syntax, because we have an argument with the same name - return crate::compute::kernels::filter::filter(array, &filter); + let predicate = prep_null_mask_filter(predicate); + return filter(array, &predicate); } - let iter = SlicesIterator::new(filter); + let iter = SlicesIterator::new(predicate); match iter.filter_count { 0 => { // return empty @@ -278,20 +277,16 @@ pub fn filter(array: &Array, filter: &BooleanArray) -> Result { /// Returns a new [RecordBatch] with arrays containing only values matching the filter. 
pub fn filter_record_batch( record_batch: &RecordBatch, - filter: &BooleanArray, + predicate: &BooleanArray, ) -> Result { - if filter.null_count() > 0 { + if predicate.null_count() > 0 { // this greatly simplifies subsequent filtering code // now we only have a boolean mask to deal with - let filter = prep_null_mask_filter(filter); - // fully qualified syntax, because we have an argument with the same name - return crate::compute::kernels::filter::filter_record_batch( - record_batch, - &filter, - ); + let predicate = prep_null_mask_filter(predicate); + return filter_record_batch(record_batch, &predicate); } - let filter = build_filter(filter)?; + let filter = build_filter(predicate)?; let filtered_arrays = record_batch .columns() .iter()