From 3d123b39c0a854b9d42ea54dc3e92a716da601fc Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Thu, 5 Aug 2021 00:17:36 +0800 Subject: [PATCH] Sort improve (#246) --- benches/growable.rs | 8 ++-- benches/sort_kernel.rs | 15 +++++- benches/take_kernels.rs | 2 +- src/array/specification.rs | 67 ++++++++++++++++++++++++++- src/buffer/mutable.rs | 9 ++++ src/compute/merge_sort/mod.rs | 39 ++++++++++++++++ src/compute/sort/boolean.rs | 3 +- src/compute/sort/common.rs | 43 ++++++----------- src/compute/sort/lex_sort.rs | 1 + src/compute/sort/mod.rs | 2 +- src/compute/sort/primitive/indices.rs | 3 +- src/compute/sort/primitive/sort.rs | 6 ++- src/compute/sort/utf8.rs | 1 + 13 files changed, 157 insertions(+), 42 deletions(-) diff --git a/benches/growable.rs b/benches/growable.rs index bedb89bd5a4..e0a4d2426ae 100644 --- a/benches/growable.rs +++ b/benches/growable.rs @@ -14,7 +14,7 @@ fn add_benchmark(c: &mut Criterion) { let i32_array = create_primitive_array::(1026 * 10, DataType::Int32, 0.0); c.bench_function("growable::primitive::non_null::non_null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10); values .clone() .into_iter() @@ -25,7 +25,7 @@ fn add_benchmark(c: &mut Criterion) { let i32_array = create_primitive_array::(1026 * 10, DataType::Int32, 0.0); c.bench_function("growable::primitive::non_null::null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10); values.clone().into_iter().for_each(|start| { if start % 2 == 0 { a.extend_validity(10); @@ -41,7 +41,7 @@ fn add_benchmark(c: &mut Criterion) { let values = values.collect::>(); c.bench_function("growable::primitive::null::non_null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10); values .clone() .into_iter() @@ -50,7 +50,7 @@ fn add_benchmark(c: &mut Criterion) { }); c.bench_function("growable::primitive::null::null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10); values.clone().into_iter().for_each(|start| { if start % 2 == 0 { a.extend_validity(10); diff --git a/benches/sort_kernel.rs b/benches/sort_kernel.rs index 30d89c441cb..e78e4462b04 100644 --- a/benches/sort_kernel.rs +++ b/benches/sort_kernel.rs @@ -19,7 +19,7 @@ extern crate criterion; use criterion::Criterion; -use arrow2::compute::sort::{lexsort, sort, SortColumn, SortOptions}; +use arrow2::compute::sort::{lexsort, sort, sort_to_indices, SortColumn, SortOptions}; use arrow2::util::bench_util::*; use arrow2::{array::*, datatypes::*}; @@ -42,6 +42,15 @@ fn bench_sort(arr_a: &dyn Array) { sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap(); } +fn bench_sort_limit(arr_a: &dyn Array) { + let _: PrimitiveArray = sort_to_indices( + criterion::black_box(arr_a), + &SortOptions::default(), + Some(100), + ) + .unwrap(); +} + fn add_benchmark(c: &mut Criterion) { (10..=20).step_by(2).for_each(|log2_size| { let size = 2usize.pow(log2_size); @@ -51,6 +60,10 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_sort(&arr_a)) }); + c.bench_function(&format!("sort-limit 2^{} f32", log2_size), |b| { + b.iter(|| bench_sort_limit(&arr_a)) + }); + let arr_b = create_primitive_array_with_seed::(size, DataType::Float32, 0.0, 43); c.bench_function(&format!("lexsort 2^{} f32", log2_size), |b| { b.iter(|| bench_lexsort(&arr_a, &arr_b)) diff --git a/benches/take_kernels.rs b/benches/take_kernels.rs index 2ab746b443c..9303a719f34 100644 --- a/benches/take_kernels.rs +++ b/benches/take_kernels.rs @@ -35,7 +35,7 @@ fn create_random_index(size: usize, null_density: f32) -> PrimitiveArray { (0..size) .map(|_| { if rng.gen::() > null_density { - let value = rng.gen_range::(0i32, size as i32); + let value = rng.gen_range::(0i32..size as i32); Some(value) } else { None diff --git a/src/array/specification.rs b/src/array/specification.rs index 70576cb8621..0684d9a6f54 100644 --- a/src/array/specification.rs +++ b/src/array/specification.rs @@ -3,7 +3,7 @@ use std::convert::TryFrom; use num::Num; use crate::{ - buffer::Buffer, + buffer::{Buffer, MutableBuffer}, types::{NativeType, NaturalDataType}, }; @@ -11,6 +11,11 @@ use crate::{ pub trait Index: NativeType + NaturalDataType { fn to_usize(&self) -> usize; fn from_usize(index: usize) -> Option; + fn is_usize() -> bool { + false + } + + fn buffer_from_range(start: usize, end: usize) -> Option>; } /// Trait describing types that can be used as offsets as per Arrow specification. @@ -71,6 +76,19 @@ impl Index for i32 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } impl Index for i64 { @@ -83,6 +101,19 @@ impl Index for i64 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } impl Index for u32 { @@ -95,6 +126,23 @@ impl Index for u32 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn is_usize() -> bool { + std::mem::size_of::() == std::mem::size_of::() + } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } impl Index for u64 { @@ -107,6 +155,23 @@ impl Index for u64 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn is_usize() -> bool { + std::mem::size_of::() == std::mem::size_of::() + } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } #[inline] diff --git a/src/buffer/mutable.rs b/src/buffer/mutable.rs index 7cf5cdae394..e811901f21b 100644 --- a/src/buffer/mutable.rs +++ b/src/buffer/mutable.rs @@ -93,6 +93,15 @@ impl MutableBuffer { } } + /// Allocates a new [MutableBuffer] with `len` and capacity to be at least `len` where + /// all bytes are not initialized + #[inline] + pub unsafe fn from_len(len: usize) -> Self { + let mut buffer = MutableBuffer::with_capacity(len); + buffer.set_len(len); + buffer + } + /// Ensures that this buffer has at least `self.len + additional` bytes. This re-allocates iff /// `self.len + additional > capacity`. /// # Example diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index e3c2d4feddd..3f4736a24aa 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -291,6 +291,29 @@ where None => self.right = None, } } + + /// Collect the MergeSortSlices to be a vec for reusing + #[warn(dead_code)] + pub fn to_vec(self, limit: Option) -> Vec { + match limit { + Some(limit) => { + let mut v = Vec::with_capacity(limit); + let mut current_len = 0; + for (index, start, len) in self { + if len + current_len >= limit { + v.push((index, start, limit - current_len)); + break; + } else { + v.push((index, start, len)); + } + current_len += len; + } + + v + } + None => self.into_iter().collect(), + } + } } impl<'a, L, R> Iterator for MergeSortSlices<'a, L, R> @@ -562,6 +585,22 @@ mod tests { Ok(()) } + #[test] + fn test_merge_slices_to_vec() -> Result<()> { + let a0: &dyn Array = &Int32Array::from_slice(&[0, 2, 4, 6, 8]); + let a1: &dyn Array = &Int32Array::from_slice(&[1, 3, 5, 7, 9]); + + let options = SortOptions::default(); + let arrays = vec![a0, a1]; + let pairs = vec![(arrays.as_ref(), &options)]; + let comparator = build_comparator(&pairs)?; + + let slices = merge_sort_slices(once(&(0, 0, 5)), once(&(1, 0, 5)), &comparator); + let vec = slices.to_vec(Some(5)); + assert_eq!(vec, [(0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1), (0, 2, 1)]); + Ok(()) + } + #[test] fn test_merge_4_i32() -> Result<()> { let a0: &dyn Array = &Int32Array::from_slice(&[0, 1]); diff --git a/src/compute/sort/boolean.rs b/src/compute/sort/boolean.rs index 2e7082a5901..6afdd551ec1 100644 --- a/src/compute/sort/boolean.rs +++ b/src/compute/sort/boolean.rs @@ -26,7 +26,7 @@ pub fn sort_boolean( if !descending { valids.sort_by(|a, b| a.1.cmp(&b.1)); } else { - valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); + valids.sort_by(|a, b| b.1.cmp(&a.1)); // reverse to keep a stable ordering nulls.reverse(); } @@ -45,6 +45,7 @@ pub fn sort_boolean( // un-efficient; there are much more performant ways of sorting nulls above, anyways. if let Some(limit) = limit { values.truncate(limit); + values.shrink_to_fit(); } PrimitiveArray::::from_data(I::DATA_TYPE, values.into(), None) diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index c36763831bb..7e909e01cd6 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -22,31 +22,21 @@ fn k_element_sort_inner( F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let compare = |lhs: &I, rhs: &I| { + let mut compare = |lhs: &I, rhs: &I| { let lhs = get(lhs.to_usize()); let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs).reverse() + cmp(&rhs, &lhs) }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &I, rhs: &I| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs).reverse() - }; - before.sort_unstable_by(compare); + let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare); + before.sort_unstable_by(&mut compare); } else { - let compare = |lhs: &I, rhs: &I| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs) - }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &I, rhs: &I| { + let mut compare = |lhs: &I, rhs: &I| { let lhs = get(lhs.to_usize()); let rhs = get(rhs.to_usize()); cmp(&lhs, &rhs) }; - before.sort_unstable_by(compare); + let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare); + before.sort_unstable_by(&mut compare); } } @@ -74,7 +64,7 @@ fn sort_unstable_by( indices.sort_unstable_by(|lhs, rhs| { let lhs = get(lhs.to_usize()); let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs).reverse() + cmp(&rhs, &lhs) }) } else { indices.sort_unstable_by(|lhs, rhs| { @@ -110,8 +100,7 @@ where let limit = limit.min(length); let indices = if let Some(validity) = validity { - let mut indices = MutableBuffer::::from_len_zeroed(length); - + let mut indices = unsafe { MutableBuffer::::from_len(length) }; if options.nulls_first { let mut nulls = 0; let mut valids = 0; @@ -134,12 +123,12 @@ where // Soundness: // all indices in `indices` are by construction `< array.len() == values.len()` // limit is by construction < indices.len() - let limit = limit - validity.null_count(); + let limit = limit.saturating_sub(validity.null_count()); let indices = &mut indices.as_mut_slice()[validity.null_count()..]; sort_unstable_by(indices, get, cmp, options.descending, limit) } } else { - let last_valid_index = length - validity.null_count(); + let last_valid_index = length.saturating_sub(validity.null_count()); let mut nulls = 0; let mut valids = 0; validity.iter().zip(0..length).for_each(|(x, index)| { @@ -165,21 +154,15 @@ where indices } else { - let mut indices = unsafe { - MutableBuffer::from_trusted_len_iter_unchecked( - (0..length).map(|x| I::from_usize(x).unwrap()), - ) - }; - + let mut indices = Index::buffer_from_range(0, length).unwrap(); // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` sort_unstable_by(&mut indices, get, cmp, descending, limit); - indices.truncate(limit); indices.shrink_to_fit(); - indices }; + PrimitiveArray::::from_data(I::DATA_TYPE, indices.into(), None) } diff --git a/src/compute/sort/lex_sort.rs b/src/compute/sort/lex_sort.rs index 0a596bfd068..3a15a10b675 100644 --- a/src/compute/sort/lex_sort.rs +++ b/src/compute/sort/lex_sort.rs @@ -177,6 +177,7 @@ pub fn lexsort_to_indices( let (before, _, _) = values.select_nth_unstable_by(limit, lex_comparator); before.sort_unstable_by(lex_comparator); values.truncate(limit); + values.shrink_to_fit(); } else { values.sort_unstable_by(lex_comparator); } diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 32cf63166c6..781263f45dc 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -387,7 +387,7 @@ where if !options.descending { valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())) } else { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()) + valids.sort_by(|a, b| cmp_array(b.1.as_ref(), a.1.as_ref())) } let values = valids.iter().map(|tuple| tuple.0); diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index e7c79947b4c..d3d7f76bd36 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -18,10 +18,11 @@ where T: NativeType, F: Fn(&T, &T) -> std::cmp::Ordering, { + let values = array.values().as_slice(); unsafe { common::indices_sorted_unstable_by( array.validity(), - |x: usize| *array.values().as_slice().get_unchecked(x), + |x: usize| *values.get_unchecked(x), cmp, array.len(), options, diff --git a/src/compute/sort/primitive/sort.rs b/src/compute/sort/primitive/sort.rs index 197ea299416..de6aa58867b 100644 --- a/src/compute/sort/primitive/sort.rs +++ b/src/compute/sort/primitive/sort.rs @@ -34,7 +34,7 @@ where F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y).reverse()); + let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(y, x)); before.sort_unstable_by(|x, y| cmp(x, y)); } else { let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y)); @@ -52,7 +52,7 @@ where } if descending { - values.sort_unstable_by(|x, y| cmp(x, y).reverse()); + values.sort_unstable_by(|x, y| cmp(y, x)); } else { values.sort_unstable_by(cmp); }; @@ -125,6 +125,7 @@ where }; // values are sorted, we can now truncate the remaining. buffer.truncate(limit); + buffer.shrink_to_fit(); (buffer.into(), new_validity.into()) } @@ -154,6 +155,7 @@ where sort_values(&mut buffer.as_mut_slice(), cmp, options.descending, limit); buffer.truncate(limit); + buffer.shrink_to_fit(); (buffer.into(), None) }; diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs index 20f2d65fea7..17a75aef4e9 100644 --- a/src/compute/sort/utf8.rs +++ b/src/compute/sort/utf8.rs @@ -32,6 +32,7 @@ pub(super) fn indices_sorted_unstable_by_dictionary