From 9c150405841f17feeaa6f5dbd91e5e42a9142fe3 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 23 Jul 2021 05:44:11 +0000 Subject: [PATCH 1/5] Added limit to sort. --- benches/sort_kernel.rs | 4 +- src/buffer/mutable.rs | 9 + src/compute/merge_sort/mod.rs | 4 +- src/compute/sort/boolean.rs | 52 +++++ src/compute/sort/lex_sort.rs | 51 ++--- src/compute/sort/mod.rs | 261 +++++++++++-------------- src/compute/sort/primitive/indices.rs | 262 +++++++++++++++++++++----- src/compute/sort/primitive/mod.rs | 19 +- src/compute/sort/primitive/sort.rs | 153 +++++++++++---- 9 files changed, 534 insertions(+), 281 deletions(-) create mode 100644 src/compute/sort/boolean.rs diff --git a/benches/sort_kernel.rs b/benches/sort_kernel.rs index a55a6971fae..048cf40e403 100644 --- a/benches/sort_kernel.rs +++ b/benches/sort_kernel.rs @@ -35,11 +35,11 @@ fn bench_lexsort(arr_a: &dyn Array, array_b: &dyn Array) { }, ]; - criterion::black_box(lexsort(&columns).unwrap()); + criterion::black_box(lexsort(&columns, None).unwrap()); } fn bench_sort(arr_a: &dyn Array) { - sort(criterion::black_box(arr_a), &SortOptions::default()).unwrap(); + sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap(); } fn add_benchmark(c: &mut Criterion) { diff --git a/src/buffer/mutable.rs b/src/buffer/mutable.rs index 96e19879602..971c276c582 100644 --- a/src/buffer/mutable.rs +++ b/src/buffer/mutable.rs @@ -186,6 +186,15 @@ impl MutableBuffer { self.len = 0 } + /// Shortens the buffer. + /// If `len` is greater or equal to the buffers' current length, this has no effect. + #[inline] + pub fn truncate(&mut self, len: usize) { + if len < self.len { + self.len = len; + } + } + /// Returns the data stored in this buffer as a slice. #[inline] pub fn as_slice(&self) -> &[T] { diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index 339d07aca32..5c4fddfaaca 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -637,8 +637,8 @@ mod tests { let options = SortOptions::default(); // sort individually, potentially in parallel. - let a0 = sort(a0, &options)?; - let a1 = sort(a1, &options)?; + let a0 = sort(a0, &options, None)?; + let a1 = sort(a1, &options, None)?; // merge then. If multiple arrays, this can be applied in parallel. let result = merge_sort(a0.as_ref(), a1.as_ref(), &options)?; diff --git a/src/compute/sort/boolean.rs b/src/compute/sort/boolean.rs new file mode 100644 index 00000000000..385c7c21355 --- /dev/null +++ b/src/compute/sort/boolean.rs @@ -0,0 +1,52 @@ +use crate::{ + array::{Array, BooleanArray, Int32Array}, + buffer::MutableBuffer, + datatypes::DataType, +}; + +use super::SortOptions; + +/// Returns the indices that would sort a [`BooleanArray`]. +pub fn sort_boolean( + values: &BooleanArray, + value_indices: Vec, + null_indices: Vec, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let descending = options.descending; + + // create tuples that are used for sorting + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + + let mut nulls = null_indices; + + if !descending { + valids.sort_by(|a, b| a.1.cmp(&b.1)); + } else { + valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + + let mut values = MutableBuffer::::with_capacity(values.len()); + + if options.nulls_first { + values.extend_from_slice(nulls.as_slice()); + valids.iter().for_each(|x| values.push(x.0)); + } else { + // nulls last + valids.iter().for_each(|x| values.push(x.0)); + values.extend_from_slice(nulls.as_slice()); + } + + // un-efficient; there are much more performant ways of sorting nulls above, anyways. + if let Some(limit) = limit { + values.truncate(limit); + } + + Int32Array::from_data(DataType::Int32, values.into(), None) +} diff --git a/src/compute/sort/lex_sort.rs b/src/compute/sort/lex_sort.rs index 44f0616931f..0ac05f8df04 100644 --- a/src/compute/sort/lex_sort.rs +++ b/src/compute/sort/lex_sort.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use std::cmp::Ordering; use crate::compute::take; @@ -72,8 +55,8 @@ pub struct SortColumn<'a> { /// assert_eq!(sorted.value(1), -64); /// assert!(sorted.is_null(0)); /// ``` -pub fn lexsort(columns: &[SortColumn]) -> Result>> { - let indices = lexsort_to_indices(columns)?; +pub fn lexsort(columns: &[SortColumn], limit: Option) -> Result>> { + let indices = lexsort_to_indices(columns, limit)?; columns .iter() .map(|c| take::take(c.values, &indices)) @@ -135,9 +118,12 @@ pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Resu }) } -/// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer -/// [`Int32Array`] of indices. -pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result> { +/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray`] +/// representing the indices that would sort the columns. +pub fn lexsort_to_indices( + columns: &[SortColumn], + limit: Option, +) -> Result> { if columns.is_empty() { return Err(ArrowError::InvalidArgumentError( "Sort requires at least one column".to_string(), @@ -146,7 +132,7 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result> if columns.len() == 1 { // fallback to non-lexical sort let column = &columns[0]; - return sort_to_indices(column.values, &column.options.unwrap_or_default()); + return sort_to_indices(column.values, &column.options.unwrap_or_default(), limit); } let row_count = columns[0].values.len(); @@ -180,7 +166,15 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result> // Safety: `0..row_count` is TrustedLen let mut values = unsafe { MutableBuffer::::from_trusted_len_iter_unchecked(0..row_count as i32) }; - values.sort_unstable_by(lex_comparator); + + if let Some(limit) = limit { + let limit = limit.min(row_count); + let (before, _, _) = values.select_nth_unstable_by(limit, lex_comparator); + before.sort_unstable_by(lex_comparator); + values.truncate(limit); + } else { + values.sort_unstable_by(lex_comparator); + } Ok(PrimitiveArray::::from_data( DataType::Int32, @@ -196,7 +190,14 @@ mod tests { use super::*; fn test_lex_sort_arrays(input: Vec, expected: Vec>) { - let sorted = lexsort(&input).unwrap(); + let sorted = lexsort(&input, None).unwrap(); + assert_eq!(sorted, expected); + + let sorted = lexsort(&input, Some(2)).unwrap(); + let expected = expected + .into_iter() + .map(|x| x.slice(0, 2)) + .collect::>(); assert_eq!(sorted, expected); } diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 4c28712a327..6b30d331c70 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -1,22 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines sort kernel for `ArrayRef` - use std::cmp::{Ordering, Reverse}; use crate::array::ord; @@ -32,6 +13,7 @@ use crate::{ use crate::buffer::MutableBuffer; use num::ToPrimitive; +mod boolean; mod lex_sort; mod primitive; @@ -39,50 +21,53 @@ pub(crate) use lex_sort::{build_compare, Compare}; pub use lex_sort::{lexsort, lexsort_to_indices, SortColumn}; macro_rules! dyn_sort { - ($ty:ty, $array:expr, $cmp:expr, $options:expr) => {{ + ($ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{ let array = $array .as_any() .downcast_ref::>() .unwrap(); Ok(Box::new(primitive::sort_by::<$ty, _>( - &array, $cmp, $options, + &array, $cmp, $options, $limit, ))) }}; } -/// Sort the `ArrayRef` using `SortOptions`. +/// Sort the [`Array`] using [`SortOptions`]. /// -/// Performs a stable sort on values and indices. Nulls are ordered according to the `nulls_first` flag in `options`. +/// Performs an unstable sort on values and indices. Nulls are ordered according to the `nulls_first` flag in `options`. /// Floats are sorted using IEEE 754 totalOrder -/// -/// Returns an `ArrowError::ComputeError(String)` if the array type is either unsupported by `sort_to_indices` or `take`. -/// -pub fn sort(values: &dyn Array, options: &SortOptions) -> Result> { +/// # Errors +/// Errors if the [`DataType`] is not supported. +pub fn sort( + values: &dyn Array, + options: &SortOptions, + limit: Option, +) -> Result> { match values.data_type() { - DataType::Int8 => dyn_sort!(i8, values, ord::total_cmp, options), - DataType::Int16 => dyn_sort!(i16, values, ord::total_cmp, options), + DataType::Int8 => dyn_sort!(i8, values, ord::total_cmp, options, limit), + DataType::Int16 => dyn_sort!(i16, values, ord::total_cmp, options, limit), DataType::Int32 | DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_sort!(i32, values, ord::total_cmp, options) + dyn_sort!(i32, values, ord::total_cmp, options, limit) } DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, None) - | DataType::Duration(_) => dyn_sort!(i64, values, ord::total_cmp, options), - DataType::UInt8 => dyn_sort!(u8, values, ord::total_cmp, options), - DataType::UInt16 => dyn_sort!(u16, values, ord::total_cmp, options), - DataType::UInt32 => dyn_sort!(u32, values, ord::total_cmp, options), - DataType::UInt64 => dyn_sort!(u64, values, ord::total_cmp, options), - DataType::Float32 => dyn_sort!(f32, values, ord::total_cmp_f32, options), - DataType::Float64 => dyn_sort!(f64, values, ord::total_cmp_f64, options), + | DataType::Duration(_) => dyn_sort!(i64, values, ord::total_cmp, options, limit), + DataType::UInt8 => dyn_sort!(u8, values, ord::total_cmp, options, limit), + DataType::UInt16 => dyn_sort!(u16, values, ord::total_cmp, options, limit), + DataType::UInt32 => dyn_sort!(u32, values, ord::total_cmp, options, limit), + DataType::UInt64 => dyn_sort!(u64, values, ord::total_cmp, options, limit), + DataType::Float32 => dyn_sort!(f32, values, ord::total_cmp_f32, options, limit), + DataType::Float64 => dyn_sort!(f64, values, ord::total_cmp_f64, options, limit), DataType::Interval(IntervalUnit::DayTime) => { - dyn_sort!(days_ms, values, ord::total_cmp, options) + dyn_sort!(days_ms, values, ord::total_cmp, options, limit) } _ => { - let indices = sort_to_indices(values, options)?; + let indices = sort_to_indices(values, options, limit)?; take::take(values, &indices) } } @@ -95,66 +80,76 @@ fn partition_validity(array: &dyn Array) -> (Vec, Vec) { } macro_rules! dyn_sort_indices { - ($ty:ty, $array:expr, $cmp:expr, $options:expr) => {{ + ($ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{ let array = $array .as_any() .downcast_ref::>() .unwrap(); - Ok(primitive::indices_sorted_by::<$ty, _>( - &array, $cmp, $options, + Ok(primitive::indices_sorted_unstable_by::<$ty, _>( + &array, $cmp, $options, $limit, )) }}; } -/// Sort elements from `ArrayRef` into an unsigned integer (`UInt32Array`) of indices. +/// Sort elements from `values` into [`Int32Array`] of indices. /// For floating point arrays any NaN values are considered to be greater than any other non-null value -pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result { +pub fn sort_to_indices( + values: &dyn Array, + options: &SortOptions, + limit: Option, +) -> Result { match values.data_type() { DataType::Boolean => { let (v, n) = partition_validity(values); - Ok(sort_boolean(values, v, n, options)) + Ok(boolean::sort_boolean( + values.as_any().downcast_ref().unwrap(), + v, + n, + options, + limit, + )) } - DataType::Int8 => dyn_sort_indices!(i8, values, ord::total_cmp, options), - DataType::Int16 => dyn_sort_indices!(i16, values, ord::total_cmp, options), + DataType::Int8 => dyn_sort_indices!(i8, values, ord::total_cmp, options, limit), + DataType::Int16 => dyn_sort_indices!(i16, values, ord::total_cmp, options, limit), DataType::Int32 | DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_sort_indices!(i32, values, ord::total_cmp, options) + dyn_sort_indices!(i32, values, ord::total_cmp, options, limit) } DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, None) - | DataType::Duration(_) => dyn_sort_indices!(i64, values, ord::total_cmp, options), - DataType::UInt8 => dyn_sort_indices!(u8, values, ord::total_cmp, options), - DataType::UInt16 => dyn_sort_indices!(u16, values, ord::total_cmp, options), - DataType::UInt32 => dyn_sort_indices!(u32, values, ord::total_cmp, options), - DataType::UInt64 => dyn_sort_indices!(u64, values, ord::total_cmp, options), - DataType::Float32 => dyn_sort_indices!(f32, values, ord::total_cmp_f32, options), - DataType::Float64 => dyn_sort_indices!(f64, values, ord::total_cmp_f64, options), + | DataType::Duration(_) => dyn_sort_indices!(i64, values, ord::total_cmp, options, limit), + DataType::UInt8 => dyn_sort_indices!(u8, values, ord::total_cmp, options, limit), + DataType::UInt16 => dyn_sort_indices!(u16, values, ord::total_cmp, options, limit), + DataType::UInt32 => dyn_sort_indices!(u32, values, ord::total_cmp, options, limit), + DataType::UInt64 => dyn_sort_indices!(u64, values, ord::total_cmp, options, limit), + DataType::Float32 => dyn_sort_indices!(f32, values, ord::total_cmp_f32, options, limit), + DataType::Float64 => dyn_sort_indices!(f64, values, ord::total_cmp_f64, options, limit), DataType::Interval(IntervalUnit::DayTime) => { - dyn_sort_indices!(days_ms, values, ord::total_cmp, options) + dyn_sort_indices!(days_ms, values, ord::total_cmp, options, limit) } DataType::Utf8 => { let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options)) + Ok(sort_utf8::(values, v, n, options, limit)) } DataType::LargeUtf8 => { let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options)) + Ok(sort_utf8::(values, v, n, options, limit)) } DataType::List(field) => { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options)), - DataType::Int16 => Ok(sort_list::(values, v, n, options)), - DataType::Int32 => Ok(sort_list::(values, v, n, options)), - DataType::Int64 => Ok(sort_list::(values, v, n, options)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -164,14 +159,14 @@ pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options)), - DataType::Int16 => Ok(sort_list::(values, v, n, options)), - DataType::Int32 => Ok(sort_list::(values, v, n, options)), - DataType::Int64 => Ok(sort_list::(values, v, n, options)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -181,14 +176,14 @@ pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options)), - DataType::Int16 => Ok(sort_list::(values, v, n, options)), - DataType::Int32 => Ok(sort_list::(values, v, n, options)), - DataType::Int64 => Ok(sort_list::(values, v, n, options)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -198,14 +193,14 @@ pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result { let (v, n) = partition_validity(values); match key_type.as_ref() { - DataType::Int8 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::Int16 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::Int32 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::Int64 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt8 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt16 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt32 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt64 => Ok(sort_string_dictionary::(values, v, n, options)), + DataType::Int8 => Ok(sort_string_dictionary::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_string_dictionary::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_string_dictionary::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_string_dictionary::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_string_dictionary::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_string_dictionary::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_string_dictionary::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_string_dictionary::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for dictionary key type {:?}", t @@ -303,55 +298,13 @@ impl Default for SortOptions { } } -/// Sort primitive values -fn sort_boolean( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, -) -> Int32Array { - let values = values - .as_any() - .downcast_ref::() - .expect("Unable to downcast to boolean array"); - let descending = options.descending; - - // create tuples that are used for sorting - let mut valids = value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .collect::>(); - - let mut nulls = null_indices; - - if !descending { - valids.sort_by(|a, b| a.1.cmp(&b.1)); - } else { - valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - - let mut values = MutableBuffer::::with_capacity(values.len()); - - if options.nulls_first { - values.extend_from_slice(nulls.as_slice()); - valids.iter().for_each(|x| values.push(x.0)); - } else { - // nulls last - valids.iter().for_each(|x| values.push(x.0)); - values.extend_from_slice(nulls.as_slice()); - } - - Int32Array::from_data(DataType::Int32, values.into(), None) -} - /// Sort strings fn sort_utf8( values: &dyn Array, value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Int32Array { let values = values.as_any().downcast_ref::>().unwrap(); @@ -360,6 +313,7 @@ fn sort_utf8( value_indices, null_indices, options, + limit, |array, idx| array.value(idx as usize), ) } @@ -370,6 +324,7 @@ fn sort_string_dictionary( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Int32Array { let values: &DictionaryArray = values .as_any() @@ -386,6 +341,7 @@ fn sort_string_dictionary( value_indices, null_indices, options, + limit, |array: &PrimitiveArray, idx| -> &str { let key: T = array.value(idx as usize); dict.value(key.to_usize().unwrap()) @@ -400,6 +356,7 @@ fn sort_string_helper<'a, A: Array, F>( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, value_fn: F, ) -> Int32Array where @@ -420,10 +377,15 @@ where let valids = valids.iter().map(|tuple| tuple.0); let values = if options.nulls_first { - let values = nulls.into_iter().chain(valids); + let values = nulls + .into_iter() + .chain(valids) + .take(limit.unwrap_or_else(|| values.len())); Buffer::::from_trusted_len_iter(values) } else { - let values = valids.chain(nulls.into_iter()); + let values = valids + .chain(nulls.into_iter()) + .take(limit.unwrap_or_else(|| values.len())); Buffer::::from_trusted_len_iter(values) }; @@ -435,6 +397,7 @@ fn sort_list( value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Int32Array where O: Offset, @@ -469,17 +432,19 @@ where let values = valids.iter().map(|tuple| tuple.0); - let values = if options.nulls_first { + let mut values = if options.nulls_first { let mut buffer = MutableBuffer::::from_trusted_len_iter(null_indices.into_iter()); - values.for_each(|x| buffer.push(x)); - buffer.into() + buffer.extend(values); + buffer } else { let mut buffer = MutableBuffer::::from_trusted_len_iter(values); - null_indices.iter().for_each(|x| buffer.push(*x)); - buffer.into() + buffer.extend(null_indices); + buffer }; - PrimitiveArray::::from_data(DataType::Int32, values, None) + values.truncate(limit.unwrap_or_else(|| values.len())); + + PrimitiveArray::::from_data(DataType::Int32, values.into(), None) } /// Compare two `Array`s based on the ordering defined in [ord](crate::array::ord). @@ -507,7 +472,7 @@ mod tests { ) { let output = BooleanArray::from(data); let expected = Int32Array::from_slice(expected_data); - let output = sort_to_indices(&output, &options).unwrap(); + let output = sort_to_indices(&output, &options, None).unwrap(); assert_eq!(output, expected) } @@ -521,7 +486,7 @@ mod tests { { let input = PrimitiveArray::::from(data).to(data_type.clone()); let expected = PrimitiveArray::::from(expected_data).to(data_type); - let output = sort(&input, &options).unwrap(); + let output = sort(&input, &options, None).unwrap(); assert_eq!(expected, output.as_ref()) } @@ -532,7 +497,7 @@ mod tests { ) { let input = Utf8Array::::from(&data.to_vec()); let expected = Int32Array::from_slice(expected_data); - let output = sort_to_indices(&input, &options).unwrap(); + let output = sort_to_indices(&input, &options, None).unwrap(); assert_eq!(output, expected) } @@ -543,7 +508,7 @@ mod tests { ) { let input = Utf8Array::::from(&data.to_vec()); let expected = Utf8Array::::from(&expected_data.to_vec()); - let output = sort(&input, &options).unwrap(); + let output = sort(&input, &options, None).unwrap(); assert_eq!(expected, output.as_ref()) } @@ -560,7 +525,7 @@ mod tests { expected.try_extend(expected_data.iter().copied()).unwrap(); let expected = expected.into_arc(); - let output = sort(input.as_ref(), &options).unwrap(); + let output = sort(input.as_ref(), &options, None).unwrap(); assert_eq!(expected.as_ref(), output.as_ref()) } @@ -1100,9 +1065,9 @@ mod tests { nulls_first: true, }; if can_sort(&d1) { - assert!(sort(array.as_ref(), &options).is_ok()); + assert!(sort(array.as_ref(), &options, None).is_ok()); } else { - assert!(sort(array.as_ref(), &options).is_err()); + assert!(sort(array.as_ref(), &options, None).is_err()); } }); } diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index 41dea70c579..0660310cf19 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use crate::{ array::{Array, PrimitiveArray}, buffer::MutableBuffer, @@ -27,19 +10,72 @@ use super::super::SortOptions; /// # Safety /// `indices[i] < values.len()` for all i #[inline] -unsafe fn sort_inner(indices: &mut [i32], values: &[T], mut cmp: F, descending: bool) -where +unsafe fn k_element_sort_inner( + indices: &mut [i32], + values: &[T], + descending: bool, + limit: usize, + mut cmp: F, +) where T: NativeType, F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - indices.sort_by(|lhs, rhs| { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = values.get_unchecked(*lhs as usize); + let rhs = values.get_unchecked(*rhs as usize); + cmp(lhs, rhs).reverse() + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = values.get_unchecked(*lhs as usize); + let rhs = values.get_unchecked(*rhs as usize); + cmp(lhs, rhs).reverse() + }; + before.sort_unstable_by(compare); + } else { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = values.get_unchecked(*lhs as usize); + let rhs = values.get_unchecked(*rhs as usize); + cmp(lhs, rhs) + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = values.get_unchecked(*lhs as usize); + let rhs = values.get_unchecked(*rhs as usize); + cmp(lhs, rhs) + }; + before.sort_unstable_by(compare); + } +} + +/// # Safety +/// Safe iff +/// * `indices[i] < values.len()` for all i +/// * `limit < values.len()` +#[inline] +unsafe fn sort_unstable_by( + indices: &mut [i32], + values: &[T], + mut cmp: F, + descending: bool, + limit: usize, +) where + T: NativeType, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if limit != indices.len() { + return k_element_sort_inner(indices, values, descending, limit, cmp); + } + + if descending { + indices.sort_unstable_by(|lhs, rhs| { let lhs = values.get_unchecked(*lhs as usize); let rhs = values.get_unchecked(*rhs as usize); cmp(lhs, rhs).reverse() }) } else { - indices.sort_by(|lhs, rhs| { + indices.sort_unstable_by(|lhs, rhs| { let lhs = values.get_unchecked(*lhs as usize); let rhs = values.get_unchecked(*rhs as usize); cmp(lhs, rhs) @@ -47,10 +83,12 @@ where } } -pub fn indices_sorted_by( +/// Unstable sort of indices. +pub fn indices_sorted_unstable_by( array: &PrimitiveArray, cmp: F, options: &SortOptions, + limit: Option, ) -> PrimitiveArray where T: NativeType, @@ -60,7 +98,11 @@ where let values = array.values(); let validity = array.validity(); - if let Some(validity) = validity { + let limit = limit.unwrap_or_else(|| array.len()); + // Safety: without this, we go out of bounds when limit >= array.len(). + let limit = limit.min(array.len()); + + let indices = if let Some(validity) = validity { let mut indices = MutableBuffer::::from_len_zeroed(array.len()); if options.nulls_first { @@ -69,8 +111,8 @@ where validity .iter() .zip(0..array.len() as i32) - .for_each(|(x, index)| { - if x { + .for_each(|(is_valid, index)| { + if is_valid { indices[validity.null_count() + valids] = index; valids += 1; } else { @@ -78,15 +120,16 @@ where nulls += 1; } }); - // Soundness: - // all indices in `indices` are by construction `< array.len() == values.len()` - unsafe { - sort_inner( - &mut indices.as_mut_slice()[validity.null_count()..], - values, - cmp, - options.descending, - ) + + if limit > validity.null_count() { + // when limit is larger, we must sort values: + + // Soundness: + // all indices in `indices` are by construction `< array.len() == values.len()` + // limit is by construction < indices.len() + let limit = limit - validity.null_count(); + let indices = &mut indices.as_mut_slice()[validity.null_count()..]; + unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) } } } else { let last_valid_index = array.len() - validity.null_count(); @@ -107,27 +150,31 @@ where // Soundness: // all indices in `indices` are by construction `< array.len() == values.len()` - unsafe { - sort_inner( - &mut indices.as_mut_slice()[..last_valid_index], - values, - cmp, - options.descending, - ) - }; + // limit is by construction <= values.len() + let limit = limit.min(last_valid_index); + let indices = &mut indices.as_mut_slice()[..last_valid_index]; + unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) }; } - PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) + indices.truncate(limit); + indices.shrink_to_fit(); + + indices } else { let mut indices = unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..values.len() as i32) }; // Soundness: // indices are by construction `< values.len()` - unsafe { sort_inner(&mut indices, values, cmp, descending) }; + // limit is by construction `< values.len()` + unsafe { sort_unstable_by(&mut indices, values, cmp, descending, limit) }; - PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) - } + indices.truncate(limit); + indices.shrink_to_fit(); + + indices + }; + PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) } #[cfg(test)] @@ -136,13 +183,18 @@ mod tests { use crate::array::ord; use crate::array::*; - fn test(data: &[Option], data_type: DataType, options: SortOptions, expected_data: &[i32]) - where + fn test( + data: &[Option], + data_type: DataType, + options: SortOptions, + limit: Option, + expected_data: &[i32], + ) where T: NativeType + std::cmp::Ord, { let input = PrimitiveArray::::from(data).to(data_type); let expected = Int32Array::from_slice(&expected_data); - let output = indices_sorted_by(&input, ord::total_cmp, &options); + let output = indices_sorted_unstable_by(&input, ord::total_cmp, &options, limit); assert_eq!(output, expected) } @@ -155,6 +207,7 @@ mod tests { descending: false, nulls_first: true, }, + None, &[0, 5, 3, 1, 4, 2], ); } @@ -168,6 +221,7 @@ mod tests { descending: false, nulls_first: false, }, + None, &[3, 1, 4, 2, 0, 5], ); } @@ -181,6 +235,7 @@ mod tests { descending: true, nulls_first: true, }, + None, &[0, 5, 2, 1, 4, 3], ); } @@ -194,7 +249,116 @@ mod tests { descending: true, nulls_first: false, }, + None, &[2, 1, 4, 3, 0, 5], ); } + + #[test] + fn limit_ascending_nulls_first() { + // nulls sorted + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: true, + }, + Some(2), + &[0, 5], + ); + + // nulls and values sorted + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: true, + }, + Some(4), + &[0, 5, 3, 1], + ); + } + + #[test] + fn limit_ascending_nulls_last() { + // values + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: false, + }, + Some(2), + &[3, 1], + ); + + // values and nulls + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: false, + }, + Some(5), + &[3, 1, 4, 2, 0], + ); + } + + #[test] + fn limit_descending_nulls_first() { + // nulls + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: true, + }, + Some(2), + &[0, 5], + ); + + // nulls and values + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: true, + }, + Some(4), + &[0, 5, 2, 1], + ); + } + + #[test] + fn limit_descending_nulls_last() { + // values + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: false, + }, + Some(2), + &[2, 1], + ); + + // values and nulls + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: false, + }, + Some(5), + &[2, 1, 4, 3, 0], + ); + } } diff --git a/src/compute/sort/primitive/mod.rs b/src/compute/sort/primitive/mod.rs index 6be3f81c96c..ccfeb15b868 100644 --- a/src/compute/sort/primitive/mod.rs +++ b/src/compute/sort/primitive/mod.rs @@ -1,22 +1,5 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - mod indices; mod sort; -pub use indices::indices_sorted_by; +pub use indices::indices_sorted_unstable_by; pub use sort::sort_by; diff --git a/src/compute/sort/primitive/sort.rs b/src/compute/sort/primitive/sort.rs index 34dbf373136..197ea299416 100644 --- a/src/compute/sort/primitive/sort.rs +++ b/src/compute/sort/primitive/sort.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::MutableBuffer; +use crate::bitmap::Bitmap; +use crate::buffer::{Buffer, MutableBuffer}; use crate::{ array::{Array, PrimitiveArray}, bitmap::{utils::SlicesIterator, MutableBitmap}, @@ -24,11 +25,32 @@ use crate::{ use super::super::SortOptions; -fn sort_inner(values: &mut [T], mut cmp: F, descending: bool) +/// # Safety +/// `indices[i] < values.len()` for all i +#[inline] +fn k_element_sort_inner(values: &mut [T], descending: bool, limit: usize, mut cmp: F) where T: NativeType, F: FnMut(&T, &T) -> std::cmp::Ordering, { + if descending { + let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y).reverse()); + before.sort_unstable_by(|x, y| cmp(x, y)); + } else { + let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y)); + before.sort_unstable_by(|x, y| cmp(x, y)); + } +} + +fn sort_values(values: &mut [T], mut cmp: F, descending: bool, limit: usize) +where + T: NativeType, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if limit != values.len() { + return k_element_sort_inner(values, descending, limit, cmp); + } + if descending { values.sort_unstable_by(|x, y| cmp(x, y).reverse()); } else { @@ -36,54 +58,106 @@ where }; } +fn sort_nullable( + values: &[T], + validity: &Bitmap, + cmp: F, + options: &SortOptions, + limit: usize, +) -> (Buffer, Option) +where + T: NativeType, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + assert!(limit <= values.len()); + if options.nulls_first && limit < validity.null_count() { + let mut buffer = MutableBuffer::::with_capacity(limit); + buffer.extend_constant(limit, T::default()); + let bitmap = MutableBitmap::from_trusted_len_iter(std::iter::repeat(false).take(limit)); + return (buffer.into(), bitmap.into()); + } + + let nulls = std::iter::repeat(false).take(validity.null_count()); + let valids = std::iter::repeat(true).take(values.len() - validity.null_count()); + + let mut buffer = MutableBuffer::::with_capacity(values.len()); + let mut new_validity = MutableBitmap::with_capacity(values.len()); + let slices = SlicesIterator::new(validity); + + if options.nulls_first { + // validity is [0,0,0,...,1,1,1,1] + new_validity.extend_from_trusted_len_iter(nulls.chain(valids).take(limit)); + + // extend buffer with constants followed by non-null values + buffer.extend_constant(validity.null_count(), T::default()); + for (start, len) in slices { + buffer.extend_from_slice(&values[start..start + len]) + } + + // sort values + sort_values( + &mut buffer.as_mut_slice()[validity.null_count()..], + cmp, + options.descending, + limit - validity.null_count(), + ); + } else { + // validity is [1,1,1,...,0,0,0,0] + new_validity.extend_from_trusted_len_iter(valids.chain(nulls).take(limit)); + + // extend buffer with non-null values + for (start, len) in slices { + buffer.extend_from_slice(&values[start..start + len]) + } + + // sort all non-null values + sort_values( + &mut buffer.as_mut_slice(), + cmp, + options.descending, + limit - validity.null_count(), + ); + + if limit > values.len() - validity.null_count() { + // extend remaining with nulls + buffer.extend_constant(validity.null_count(), T::default()); + } + }; + // values are sorted, we can now truncate the remaining. + buffer.truncate(limit); + + (buffer.into(), new_validity.into()) +} + /// Sorts a [`PrimitiveArray`] according to `cmp` comparator and [`SortOptions`]. -pub fn sort_by(array: &PrimitiveArray, cmp: F, options: &SortOptions) -> PrimitiveArray +pub fn sort_by( + array: &PrimitiveArray, + cmp: F, + options: &SortOptions, + limit: Option, +) -> PrimitiveArray where T: NativeType, F: FnMut(&T, &T) -> std::cmp::Ordering, { + let limit = limit.unwrap_or_else(|| array.len()); + let limit = limit.min(array.len()); + let values = array.values(); let validity = array.validity(); let (buffer, validity) = if let Some(validity) = validity { - let nulls = std::iter::repeat(false).take(validity.null_count()); - let valids = std::iter::repeat(true).take(array.len() - validity.null_count()); - - let mut buffer = MutableBuffer::::with_capacity(array.len()); - let mut new_validity = MutableBitmap::with_capacity(array.len()); - let slices = SlicesIterator::new(validity); - - if options.nulls_first { - new_validity.extend_from_trusted_len_iter(nulls.chain(valids)); - (0..validity.null_count()).for_each(|_| buffer.push(T::default())); - for (start, len) in slices { - buffer.extend_from_slice(&values[start..start + len]) - } - sort_inner( - &mut buffer.as_mut_slice()[validity.null_count()..], - cmp, - options.descending, - ) - } else { - new_validity.extend_from_trusted_len_iter(valids.chain(nulls)); - for (start, len) in slices { - buffer.extend_from_slice(&values[start..start + len]) - } - sort_inner(&mut buffer.as_mut_slice(), cmp, options.descending); - - (0..validity.null_count()).for_each(|_| buffer.push(T::default())); - }; - - (buffer, new_validity.into()) + sort_nullable(values, validity, cmp, options, limit) } else { let mut buffer = MutableBuffer::::new(); buffer.extend_from_slice(values); - sort_inner(&mut buffer.as_mut_slice(), cmp, options.descending); + sort_values(&mut buffer.as_mut_slice(), cmp, options.descending, limit); + buffer.truncate(limit); - (buffer, None) + (buffer.into(), None) }; - PrimitiveArray::::from_data(array.data_type().clone(), buffer.into(), validity) + PrimitiveArray::::from_data(array.data_type().clone(), buffer, validity) } #[cfg(test)] @@ -103,8 +177,13 @@ mod tests { T: NativeType + std::cmp::Ord, { let input = PrimitiveArray::::from(data).to(data_type.clone()); - let expected = PrimitiveArray::::from(expected_data).to(data_type); - let output = sort_by(&input, ord::total_cmp, &options); + let expected = PrimitiveArray::::from(expected_data).to(data_type.clone()); + let output = sort_by(&input, ord::total_cmp, &options, None); + assert_eq!(expected, output); + + // with limit + let expected = PrimitiveArray::::from(&expected_data[..3]).to(data_type); + let output = sort_by(&input, ord::total_cmp, &options, Some(3)); assert_eq!(expected, output) } From e7b33f1b3baebfdcbee9b94398dcd89278c07a61 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 23 Jul 2021 16:04:47 +0000 Subject: [PATCH 2/5] Added bench for utf8 --- benches/sort_kernel.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/benches/sort_kernel.rs b/benches/sort_kernel.rs index 048cf40e403..ed48c4c8e42 100644 --- a/benches/sort_kernel.rs +++ b/benches/sort_kernel.rs @@ -66,6 +66,11 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function(&format!("lexsort null 2^{} f32", log2_size), |b| { b.iter(|| bench_lexsort(&arr_a, &arr_b)) }); + + let arr_a = create_string_array::(size, 0.1); + c.bench_function(&format!("sort utf8 null 2^{}", log2_size), |b| { + b.iter(|| bench_sort(&arr_a)) + }); }); } From 128be64aa08badc934e9d3e8678a2ebee2529c3d Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 23 Jul 2021 16:18:54 +0000 Subject: [PATCH 3/5] Generalized function. --- src/compute/sort/common.rs | 76 ++++++++++++++++++ src/compute/sort/mod.rs | 1 + src/compute/sort/primitive/indices.rs | 107 +++++++------------------- 3 files changed, 105 insertions(+), 79 deletions(-) create mode 100644 src/compute/sort/common.rs diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs new file mode 100644 index 00000000000..efd27f06e76 --- /dev/null +++ b/src/compute/sort/common.rs @@ -0,0 +1,76 @@ +/// # Safety +/// `indices[i] < values.len()` for all i +/// `limit < values.len()` +#[inline] +unsafe fn k_element_sort_inner( + indices: &mut [i32], + get: G, + descending: bool, + limit: usize, + mut cmp: F, +) where + G: Fn(usize) -> T, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if descending { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }; + before.sort_unstable_by(compare); + } else { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }; + before.sort_unstable_by(compare); + } +} + +/// # Safety +/// Safe iff +/// * `indices[i] < values.len()` for all i +/// * `limit < values.len()` +#[inline] +pub(super) unsafe fn sort_unstable_by( + indices: &mut [i32], + get: G, + mut cmp: F, + descending: bool, + limit: usize, +) where + G: Fn(usize) -> T, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if limit != indices.len() { + return k_element_sort_inner(indices, get, descending, limit, cmp); + } + + if descending { + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }) + } else { + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }) + } +} diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 6b30d331c70..53a814ab7d0 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -14,6 +14,7 @@ use crate::buffer::MutableBuffer; use num::ToPrimitive; mod boolean; +mod common; mod lex_sort; mod primitive; diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index 0660310cf19..474ddd430b8 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -5,84 +5,9 @@ use crate::{ types::NativeType, }; +use super::super::common::sort_unstable_by; use super::super::SortOptions; -/// # Safety -/// `indices[i] < values.len()` for all i -#[inline] -unsafe fn k_element_sort_inner( - indices: &mut [i32], - values: &[T], - descending: bool, - limit: usize, - mut cmp: F, -) where - T: NativeType, - F: FnMut(&T, &T) -> std::cmp::Ordering, -{ - if descending { - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs).reverse() - }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs).reverse() - }; - before.sort_unstable_by(compare); - } else { - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs) - }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs) - }; - before.sort_unstable_by(compare); - } -} - -/// # Safety -/// Safe iff -/// * `indices[i] < values.len()` for all i -/// * `limit < values.len()` -#[inline] -unsafe fn sort_unstable_by( - indices: &mut [i32], - values: &[T], - mut cmp: F, - descending: bool, - limit: usize, -) where - T: NativeType, - F: FnMut(&T, &T) -> std::cmp::Ordering, -{ - if limit != indices.len() { - return k_element_sort_inner(indices, values, descending, limit, cmp); - } - - if descending { - indices.sort_unstable_by(|lhs, rhs| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs).reverse() - }) - } else { - indices.sort_unstable_by(|lhs, rhs| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs) - }) - } -} - /// Unstable sort of indices. pub fn indices_sorted_unstable_by( array: &PrimitiveArray, @@ -129,7 +54,15 @@ where // limit is by construction < indices.len() let limit = limit - validity.null_count(); let indices = &mut indices.as_mut_slice()[validity.null_count()..]; - unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) } + unsafe { + sort_unstable_by( + indices, + |x: usize| *values.as_slice().get_unchecked(x), + cmp, + options.descending, + limit, + ) + } } } else { let last_valid_index = array.len() - validity.null_count(); @@ -153,7 +86,15 @@ where // limit is by construction <= values.len() let limit = limit.min(last_valid_index); let indices = &mut indices.as_mut_slice()[..last_valid_index]; - unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) }; + unsafe { + sort_unstable_by( + indices, + |x: usize| *values.as_slice().get_unchecked(x), + cmp, + options.descending, + limit, + ) + }; } indices.truncate(limit); @@ -167,7 +108,15 @@ where // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` - unsafe { sort_unstable_by(&mut indices, values, cmp, descending, limit) }; + unsafe { + sort_unstable_by( + &mut indices, + |x: usize| *values.as_slice().get_unchecked(x), + cmp, + descending, + limit, + ) + }; indices.truncate(limit); indices.shrink_to_fit(); From ae41e1f7abe96464b324532b7d18ca25a3bf13a4 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 23 Jul 2021 17:08:56 +0000 Subject: [PATCH 4/5] Optimized sort of utf8 --- src/compute/sort/common.rs | 117 ++++++++++++++- src/compute/sort/mod.rs | 203 ++++++++++---------------- src/compute/sort/primitive/indices.rs | 120 ++------------- src/compute/sort/utf8.rs | 37 +++++ 4 files changed, 239 insertions(+), 238 deletions(-) create mode 100644 src/compute/sort/utf8.rs diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index efd27f06e76..da62d2af387 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -1,8 +1,13 @@ +use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType}; + +use super::SortOptions; + /// # Safety -/// `indices[i] < values.len()` for all i -/// `limit < values.len()` +/// This function guarantees that: +/// * `get` is only called for `0 <= i < limit` +/// * `cmp` is only called from the co-domain of `get`. #[inline] -unsafe fn k_element_sort_inner( +fn k_element_sort_inner( indices: &mut [i32], get: G, descending: bool, @@ -42,11 +47,11 @@ unsafe fn k_element_sort_inner( } /// # Safety -/// Safe iff -/// * `indices[i] < values.len()` for all i -/// * `limit < values.len()` +/// This function guarantees that: +/// * `get` is only called for `0 <= i < limit` +/// * `cmp` is only called from the co-domain of `get`. #[inline] -pub(super) unsafe fn sort_unstable_by( +fn sort_unstable_by( indices: &mut [i32], get: G, mut cmp: F, @@ -74,3 +79,101 @@ pub(super) unsafe fn sort_unstable_by( }) } } + +/// # Safety +/// This function guarantees that: +/// * `get` is only called for `0 <= i < length` +/// * `cmp` is only called from the co-domain of `get`. +#[inline] +pub(super) fn indices_sorted_unstable_by( + validity: &Option, + get: G, + cmp: F, + length: usize, + options: &SortOptions, + limit: Option, +) -> PrimitiveArray +where + G: Fn(usize) -> T, + F: Fn(&T, &T) -> std::cmp::Ordering, +{ + let descending = options.descending; + + let limit = limit.unwrap_or(length); + // Safety: without this, we go out of bounds when limit >= length. + let limit = limit.min(length); + + let indices = if let Some(validity) = validity { + let mut indices = MutableBuffer::::from_len_zeroed(length); + + if options.nulls_first { + let mut nulls = 0; + let mut valids = 0; + validity + .iter() + .zip(0..length as i32) + .for_each(|(is_valid, index)| { + if is_valid { + indices[validity.null_count() + valids] = index; + valids += 1; + } else { + indices[nulls] = index; + nulls += 1; + } + }); + + if limit > validity.null_count() { + // when limit is larger, we must sort values: + + // Soundness: + // all indices in `indices` are by construction `< array.len() == values.len()` + // limit is by construction < indices.len() + let limit = limit - validity.null_count(); + let indices = &mut indices.as_mut_slice()[validity.null_count()..]; + sort_unstable_by(indices, get, cmp, options.descending, limit) + } + } else { + let last_valid_index = length - validity.null_count(); + let mut nulls = 0; + let mut valids = 0; + validity + .iter() + .zip(0..length as i32) + .for_each(|(x, index)| { + if x { + indices[valids] = index; + valids += 1; + } else { + indices[last_valid_index + nulls] = index; + nulls += 1; + } + }); + + // Soundness: + // all indices in `indices` are by construction `< array.len() == values.len()` + // limit is by construction <= values.len() + let limit = limit.min(last_valid_index); + let indices = &mut indices.as_mut_slice()[..last_valid_index]; + sort_unstable_by(indices, get, cmp, options.descending, limit); + } + + indices.truncate(limit); + indices.shrink_to_fit(); + + indices + } else { + let mut indices = + unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) }; + + // Soundness: + // indices are by construction `< values.len()` + // limit is by construction `< values.len()` + sort_unstable_by(&mut indices, get, cmp, descending, limit); + + indices.truncate(limit); + indices.shrink_to_fit(); + + indices + }; + PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) +} diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 53a814ab7d0..998ddc27392 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -1,4 +1,4 @@ -use std::cmp::{Ordering, Reverse}; +use std::cmp::Ordering; use crate::array::ord; use crate::compute::take; @@ -6,7 +6,6 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::{ array::*, - buffer::Buffer, types::{days_ms, NativeType}, }; @@ -17,6 +16,7 @@ mod boolean; mod common; mod lex_sort; mod primitive; +mod utf8; pub(crate) use lex_sort::{build_compare, Compare}; pub use lex_sort::{lexsort, lexsort_to_indices, SortColumn}; @@ -132,14 +132,16 @@ pub fn sort_to_indices( DataType::Interval(IntervalUnit::DayTime) => { dyn_sort_indices!(days_ms, values, ord::total_cmp, options, limit) } - DataType::Utf8 => { - let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options, limit)) - } - DataType::LargeUtf8 => { - let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options, limit)) - } + DataType::Utf8 => Ok(utf8::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::LargeUtf8 => Ok(utf8::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), DataType::List(field) => { let (v, n) = partition_validity(values); match field.data_type() { @@ -191,23 +193,14 @@ pub fn sort_to_indices( ))), } } - DataType::Dictionary(key_type, value_type) if *value_type.as_ref() == DataType::Utf8 => { - let (v, n) = partition_validity(values); - match key_type.as_ref() { - DataType::Int8 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::Int16 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::Int32 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::Int64 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt8 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt16 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt32 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt64 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - t => Err(ArrowError::NotYetImplemented(format!( - "Sort not supported for dictionary key type {:?}", - t - ))), - } - } + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 => sort_dict::(values, key_type.as_ref(), options, limit), + DataType::LargeUtf8 => sort_dict::(values, key_type.as_ref(), options, limit), + t => Err(ArrowError::NotYetImplemented(format!( + "Sort not supported for dictionary type with keys {:?}", + t + ))), + }, t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for data type {:?}", t @@ -215,6 +208,60 @@ pub fn sort_to_indices( } } +fn sort_dict( + values: &dyn Array, + key_type: &DataType, + options: &SortOptions, + limit: Option, +) -> Result { + match key_type { + DataType::Int8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + t => Err(ArrowError::NotYetImplemented(format!( + "Sort not supported for dictionary key type {:?}", + t + ))), + } +} + /// Checks if an array of type `datatype` can be sorted /// /// # Examples @@ -299,100 +346,6 @@ impl Default for SortOptions { } } -/// Sort strings -fn sort_utf8( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> Int32Array { - let values = values.as_any().downcast_ref::>().unwrap(); - - sort_string_helper( - values, - value_indices, - null_indices, - options, - limit, - |array, idx| array.value(idx as usize), - ) -} - -/// Sort dictionary encoded strings -fn sort_string_dictionary( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> Int32Array { - let values: &DictionaryArray = values - .as_any() - .downcast_ref::>() - .unwrap(); - - let keys = values.keys(); - - let dict = values.values(); - let dict = dict.as_any().downcast_ref::>().unwrap(); - - sort_string_helper( - keys, - value_indices, - null_indices, - options, - limit, - |array: &PrimitiveArray, idx| -> &str { - let key: T = array.value(idx as usize); - dict.value(key.to_usize().unwrap()) - }, - ) -} - -/// shared implementation between dictionary encoded and plain string arrays -#[inline] -fn sort_string_helper<'a, A: Array, F>( - values: &'a A, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, - value_fn: F, -) -> Int32Array -where - F: Fn(&'a A, i32) -> &str, -{ - let mut valids = value_indices - .into_iter() - .map(|index| (index, value_fn(values, index))) - .collect::>(); - let mut nulls = null_indices; - if !options.descending { - valids.sort_by_key(|a| a.1); - } else { - valids.sort_by_key(|a| Reverse(a.1)); - nulls.reverse(); - } - - let valids = valids.iter().map(|tuple| tuple.0); - - let values = if options.nulls_first { - let values = nulls - .into_iter() - .chain(valids) - .take(limit.unwrap_or_else(|| values.len())); - Buffer::::from_trusted_len_iter(values) - } else { - let values = valids - .chain(nulls.into_iter()) - .take(limit.unwrap_or_else(|| values.len())); - Buffer::::from_trusted_len_iter(values) - }; - - PrimitiveArray::::from_data(DataType::Int32, values, None) -} - fn sort_list( values: &dyn Array, value_indices: Vec, @@ -670,6 +623,7 @@ mod tests { descending: false, nulls_first: true, }, + // &[3, 0, 5, 1, 4, 2] is also valid &[0, 3, 5, 1, 4, 2], ); @@ -686,7 +640,8 @@ mod tests { descending: true, nulls_first: false, }, - &[2, 4, 1, 5, 3, 0], + // &[2, 4, 1, 5, 3, 0] is also valid + &[2, 4, 1, 5, 0, 3], ); test_sort_to_indices_string_arrays( @@ -702,6 +657,7 @@ mod tests { descending: false, nulls_first: true, }, + // &[3, 0, 5, 1, 4, 2] is also valid &[0, 3, 5, 1, 4, 2], ); @@ -718,7 +674,8 @@ mod tests { descending: true, nulls_first: true, }, - &[3, 0, 2, 4, 1, 5], + // &[3, 0, 2, 4, 1, 5] is also valid + &[0, 3, 2, 4, 1, 5], ); } diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index 474ddd430b8..b5673b19dfe 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -1,11 +1,9 @@ use crate::{ array::{Array, PrimitiveArray}, - buffer::MutableBuffer, - datatypes::DataType, types::NativeType, }; -use super::super::common::sort_unstable_by; +use super::super::common; use super::super::SortOptions; /// Unstable sort of indices. @@ -19,111 +17,16 @@ where T: NativeType, F: Fn(&T, &T) -> std::cmp::Ordering, { - let descending = options.descending; - let values = array.values(); - let validity = array.validity(); - - let limit = limit.unwrap_or_else(|| array.len()); - // Safety: without this, we go out of bounds when limit >= array.len(). - let limit = limit.min(array.len()); - - let indices = if let Some(validity) = validity { - let mut indices = MutableBuffer::::from_len_zeroed(array.len()); - - if options.nulls_first { - let mut nulls = 0; - let mut valids = 0; - validity - .iter() - .zip(0..array.len() as i32) - .for_each(|(is_valid, index)| { - if is_valid { - indices[validity.null_count() + valids] = index; - valids += 1; - } else { - indices[nulls] = index; - nulls += 1; - } - }); - - if limit > validity.null_count() { - // when limit is larger, we must sort values: - - // Soundness: - // all indices in `indices` are by construction `< array.len() == values.len()` - // limit is by construction < indices.len() - let limit = limit - validity.null_count(); - let indices = &mut indices.as_mut_slice()[validity.null_count()..]; - unsafe { - sort_unstable_by( - indices, - |x: usize| *values.as_slice().get_unchecked(x), - cmp, - options.descending, - limit, - ) - } - } - } else { - let last_valid_index = array.len() - validity.null_count(); - let mut nulls = 0; - let mut valids = 0; - validity - .iter() - .zip(0..array.len() as i32) - .for_each(|(x, index)| { - if x { - indices[valids] = index; - valids += 1; - } else { - indices[last_valid_index + nulls] = index; - nulls += 1; - } - }); - - // Soundness: - // all indices in `indices` are by construction `< array.len() == values.len()` - // limit is by construction <= values.len() - let limit = limit.min(last_valid_index); - let indices = &mut indices.as_mut_slice()[..last_valid_index]; - unsafe { - sort_unstable_by( - indices, - |x: usize| *values.as_slice().get_unchecked(x), - cmp, - options.descending, - limit, - ) - }; - } - - indices.truncate(limit); - indices.shrink_to_fit(); - - indices - } else { - let mut indices = - unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..values.len() as i32) }; - - // Soundness: - // indices are by construction `< values.len()` - // limit is by construction `< values.len()` - unsafe { - sort_unstable_by( - &mut indices, - |x: usize| *values.as_slice().get_unchecked(x), - cmp, - descending, - limit, - ) - }; - - indices.truncate(limit); - indices.shrink_to_fit(); - - indices - }; - PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) + unsafe { + common::indices_sorted_unstable_by( + array.validity(), + |x: usize| *array.values().as_slice().get_unchecked(x), + cmp, + array.len(), + options, + limit, + ) + } } #[cfg(test)] @@ -131,6 +34,7 @@ mod tests { use super::*; use crate::array::ord; use crate::array::*; + use crate::datatypes::DataType; fn test( data: &[Option], diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs new file mode 100644 index 00000000000..8e83517ec53 --- /dev/null +++ b/src/compute/sort/utf8.rs @@ -0,0 +1,37 @@ +use crate::array::{Array, Int32Array, Offset, Utf8Array}; +use crate::array::{DictionaryArray, DictionaryKey}; + +use super::common; +use super::SortOptions; + +pub(super) fn indices_sorted_unstable_by( + array: &Utf8Array, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let get = |idx| unsafe { array.value_unchecked(idx as usize) }; + let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) +} + +pub(super) fn indices_sorted_unstable_by_dictionary( + array: &DictionaryArray, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let keys = array.keys(); + + let dict = array + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + + let get = |idx| unsafe { + let index = keys.value_unchecked(idx as usize); + // Note: there is no check that the keys are within bounds of the dictionary. + dict.value(index.to_usize().unwrap()) + }; + let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) +} From 25e5ce09e61caebd6f0c507da2ef1a68d378c3ac Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 23 Jul 2021 17:38:56 +0000 Subject: [PATCH 5/5] Fixed example. --- src/compute/sort/lex_sort.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compute/sort/lex_sort.rs b/src/compute/sort/lex_sort.rs index 0ac05f8df04..7ca8b4c6c72 100644 --- a/src/compute/sort/lex_sort.rs +++ b/src/compute/sort/lex_sort.rs @@ -49,7 +49,7 @@ pub struct SortColumn<'a> { /// nulls_first: false, /// }), /// }, -/// ]).unwrap(); +/// ], None).unwrap(); /// /// let sorted = sorted_columns[0].as_any().downcast_ref::().unwrap(); /// assert_eq!(sorted.value(1), -64);