From f9e4cf5777e16d9b1329017b1bb1d1e6c44334de Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 3 Mar 2022 03:34:00 -0800 Subject: [PATCH] Improve performance if dictionary kernels, add benchmark and add `take_iter_unchecked` (#1372) * Add benchmark and take_iter_unchecked. * Add Safety section for clippy * Update arrow/src/compute/kernels/comparison.rs Co-authored-by: Andrew Lamb Co-authored-by: Andrew Lamb --- arrow/benches/comparison_kernels.rs | 21 +++++++++++++++- arrow/src/array/array_binary.rs | 11 +++++++++ arrow/src/array/array_boolean.rs | 11 +++++++++ arrow/src/array/array_primitive.rs | 11 +++++++++ arrow/src/array/array_string.rs | 11 +++++++++ arrow/src/compute/kernels/comparison.rs | 33 +++++++++++++++---------- 6 files changed, 84 insertions(+), 14 deletions(-) diff --git a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs index cf9ccdd977d9..4dced67ad87f 100644 --- a/arrow/benches/comparison_kernels.rs +++ b/arrow/benches/comparison_kernels.rs @@ -24,7 +24,7 @@ extern crate arrow; use arrow::compute::*; use arrow::datatypes::{ArrowNumericType, IntervalMonthDayNanoType}; use arrow::util::bench_util::*; -use arrow::{array::*, datatypes::Float32Type}; +use arrow::{array::*, datatypes::Float32Type, datatypes::Int32Type}; fn bench_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) where @@ -133,6 +133,18 @@ fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) { .unwrap(); } +fn bench_dict_eq(arr_a: &DictionaryArray, arr_b: &DictionaryArray) +where + T: ArrowNumericType, +{ + cmp_dict_utf8::( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a == b, + ) + .unwrap(); +} + fn add_benchmark(c: &mut Criterion) { let size = 65536; let arr_a = create_primitive_array_with_seed::(size, 0.0, 42); @@ -249,6 +261,13 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("egexp_matches_utf8 scalar ends with", |b| { b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "xx$")) }); + + let dict_arr_a = create_string_dict_array::(size, 0.0); + let dict_arr_b = create_string_dict_array::(size, 0.0); + + c.bench_function("dict eq string", |b| { + b.iter(|| bench_dict_eq(&dict_arr_a, &dict_arr_b)) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/src/array/array_binary.rs b/arrow/src/array/array_binary.rs index 40a5ee690cdf..d9118ddf7107 100644 --- a/arrow/src/array/array_binary.rs +++ b/arrow/src/array/array_binary.rs @@ -209,6 +209,17 @@ impl GenericBinaryArray { ) -> impl Iterator> + 'a { indexes.map(|opt_index| opt_index.map(|index| self.value(index))) } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } impl<'a, T: BinaryOffsetSizeTrait> GenericBinaryArray { diff --git a/arrow/src/array/array_boolean.rs b/arrow/src/array/array_boolean.rs index ca3bb2db5f85..12cecd41167e 100644 --- a/arrow/src/array/array_boolean.rs +++ b/arrow/src/array/array_boolean.rs @@ -130,6 +130,17 @@ impl BooleanArray { ) -> impl Iterator> + 'a { indexes.map(|opt_index| opt_index.map(|index| self.value(index))) } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } impl Array for BooleanArray { diff --git a/arrow/src/array/array_primitive.rs b/arrow/src/array/array_primitive.rs index 0d18032824ba..79aa8e6a0594 100644 --- a/arrow/src/array/array_primitive.rs +++ b/arrow/src/array/array_primitive.rs @@ -154,6 +154,17 @@ impl PrimitiveArray { ) -> impl Iterator> + 'a { indexes.map(|opt_index| opt_index.map(|index| self.value(index))) } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } impl Array for PrimitiveArray { diff --git a/arrow/src/array/array_string.rs b/arrow/src/array/array_string.rs index 95c7cd68abcf..b17534ac53f7 100644 --- a/arrow/src/array/array_string.rs +++ b/arrow/src/array/array_string.rs @@ -180,6 +180,17 @@ impl GenericStringArray { ) -> impl Iterator> + 'a { indexes.map(|opt_index| opt_index.map(|index| self.value(index))) } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } impl<'a, Ptr, OffsetSize: StringOffsetSizeTrait> FromIterator<&'a Option> diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index d1c33e33dee6..115407636391 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2214,19 +2214,26 @@ macro_rules! compare_dict_op { )); } - let left_iter = $left - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter($left.keys_iter()); - - let right_iter = $right - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter($right.keys_iter()); + // Safety justification: Since the inputs are valid Arrow arrays, all values are + // valid indexes into the dictionary (which is verified during construction) + + let left_iter = unsafe { + $left + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap() + .take_iter_unchecked($left.keys_iter()) + }; + + let right_iter = unsafe { + $right + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap() + .take_iter_unchecked($right.keys_iter()) + }; let result = left_iter .zip(right_iter)