From a6e8b6961301c87e4f7adfeb5fea82e1bab31834 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Sat, 24 Jul 2021 14:32:09 +0200 Subject: [PATCH] Added support for limited sort (#218) --- benches/sort_kernel.rs | 9 +- src/buffer/mutable.rs | 9 + src/compute/merge_sort/mod.rs | 4 +- src/compute/sort/boolean.rs | 52 ++++ src/compute/sort/common.rs | 179 +++++++++++ src/compute/sort/lex_sort.rs | 53 ++-- src/compute/sort/mod.rs | 421 +++++++++++--------------- src/compute/sort/primitive/indices.rs | 253 ++++++++-------- src/compute/sort/primitive/mod.rs | 19 +- src/compute/sort/primitive/sort.rs | 153 +++++++--- src/compute/sort/utf8.rs | 37 +++ 11 files changed, 737 insertions(+), 452 deletions(-) create mode 100644 src/compute/sort/boolean.rs create mode 100644 src/compute/sort/common.rs create mode 100644 src/compute/sort/utf8.rs diff --git a/benches/sort_kernel.rs b/benches/sort_kernel.rs index a55a6971fae..ed48c4c8e42 100644 --- a/benches/sort_kernel.rs +++ b/benches/sort_kernel.rs @@ -35,11 +35,11 @@ fn bench_lexsort(arr_a: &dyn Array, array_b: &dyn Array) { }, ]; - criterion::black_box(lexsort(&columns).unwrap()); + criterion::black_box(lexsort(&columns, None).unwrap()); } fn bench_sort(arr_a: &dyn Array) { - sort(criterion::black_box(arr_a), &SortOptions::default()).unwrap(); + sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap(); } fn add_benchmark(c: &mut Criterion) { @@ -66,6 +66,11 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function(&format!("lexsort null 2^{} f32", log2_size), |b| { b.iter(|| bench_lexsort(&arr_a, &arr_b)) }); + + let arr_a = create_string_array::(size, 0.1); + c.bench_function(&format!("sort utf8 null 2^{}", log2_size), |b| { + b.iter(|| bench_sort(&arr_a)) + }); }); } diff --git a/src/buffer/mutable.rs b/src/buffer/mutable.rs index 96e19879602..971c276c582 100644 --- a/src/buffer/mutable.rs +++ b/src/buffer/mutable.rs @@ -186,6 +186,15 @@ impl MutableBuffer { self.len = 0 } + /// Shortens the buffer. + /// If `len` is greater or equal to the buffers' current length, this has no effect. + #[inline] + pub fn truncate(&mut self, len: usize) { + if len < self.len { + self.len = len; + } + } + /// Returns the data stored in this buffer as a slice. #[inline] pub fn as_slice(&self) -> &[T] { diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index 339d07aca32..5c4fddfaaca 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -637,8 +637,8 @@ mod tests { let options = SortOptions::default(); // sort individually, potentially in parallel. - let a0 = sort(a0, &options)?; - let a1 = sort(a1, &options)?; + let a0 = sort(a0, &options, None)?; + let a1 = sort(a1, &options, None)?; // merge then. If multiple arrays, this can be applied in parallel. let result = merge_sort(a0.as_ref(), a1.as_ref(), &options)?; diff --git a/src/compute/sort/boolean.rs b/src/compute/sort/boolean.rs new file mode 100644 index 00000000000..385c7c21355 --- /dev/null +++ b/src/compute/sort/boolean.rs @@ -0,0 +1,52 @@ +use crate::{ + array::{Array, BooleanArray, Int32Array}, + buffer::MutableBuffer, + datatypes::DataType, +}; + +use super::SortOptions; + +/// Returns the indices that would sort a [`BooleanArray`]. +pub fn sort_boolean( + values: &BooleanArray, + value_indices: Vec, + null_indices: Vec, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let descending = options.descending; + + // create tuples that are used for sorting + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + + let mut nulls = null_indices; + + if !descending { + valids.sort_by(|a, b| a.1.cmp(&b.1)); + } else { + valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); + // reverse to keep a stable ordering + nulls.reverse(); + } + + let mut values = MutableBuffer::::with_capacity(values.len()); + + if options.nulls_first { + values.extend_from_slice(nulls.as_slice()); + valids.iter().for_each(|x| values.push(x.0)); + } else { + // nulls last + valids.iter().for_each(|x| values.push(x.0)); + values.extend_from_slice(nulls.as_slice()); + } + + // un-efficient; there are much more performant ways of sorting nulls above, anyways. + if let Some(limit) = limit { + values.truncate(limit); + } + + Int32Array::from_data(DataType::Int32, values.into(), None) +} diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs new file mode 100644 index 00000000000..da62d2af387 --- /dev/null +++ b/src/compute/sort/common.rs @@ -0,0 +1,179 @@ +use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType}; + +use super::SortOptions; + +/// # Safety +/// This function guarantees that: +/// * `get` is only called for `0 <= i < limit` +/// * `cmp` is only called from the co-domain of `get`. +#[inline] +fn k_element_sort_inner( + indices: &mut [i32], + get: G, + descending: bool, + limit: usize, + mut cmp: F, +) where + G: Fn(usize) -> T, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if descending { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }; + before.sort_unstable_by(compare); + } else { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }; + before.sort_unstable_by(compare); + } +} + +/// # Safety +/// This function guarantees that: +/// * `get` is only called for `0 <= i < limit` +/// * `cmp` is only called from the co-domain of `get`. +#[inline] +fn sort_unstable_by( + indices: &mut [i32], + get: G, + mut cmp: F, + descending: bool, + limit: usize, +) where + G: Fn(usize) -> T, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if limit != indices.len() { + return k_element_sort_inner(indices, get, descending, limit, cmp); + } + + if descending { + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }) + } else { + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }) + } +} + +/// # Safety +/// This function guarantees that: +/// * `get` is only called for `0 <= i < length` +/// * `cmp` is only called from the co-domain of `get`. +#[inline] +pub(super) fn indices_sorted_unstable_by( + validity: &Option, + get: G, + cmp: F, + length: usize, + options: &SortOptions, + limit: Option, +) -> PrimitiveArray +where + G: Fn(usize) -> T, + F: Fn(&T, &T) -> std::cmp::Ordering, +{ + let descending = options.descending; + + let limit = limit.unwrap_or(length); + // Safety: without this, we go out of bounds when limit >= length. + let limit = limit.min(length); + + let indices = if let Some(validity) = validity { + let mut indices = MutableBuffer::::from_len_zeroed(length); + + if options.nulls_first { + let mut nulls = 0; + let mut valids = 0; + validity + .iter() + .zip(0..length as i32) + .for_each(|(is_valid, index)| { + if is_valid { + indices[validity.null_count() + valids] = index; + valids += 1; + } else { + indices[nulls] = index; + nulls += 1; + } + }); + + if limit > validity.null_count() { + // when limit is larger, we must sort values: + + // Soundness: + // all indices in `indices` are by construction `< array.len() == values.len()` + // limit is by construction < indices.len() + let limit = limit - validity.null_count(); + let indices = &mut indices.as_mut_slice()[validity.null_count()..]; + sort_unstable_by(indices, get, cmp, options.descending, limit) + } + } else { + let last_valid_index = length - validity.null_count(); + let mut nulls = 0; + let mut valids = 0; + validity + .iter() + .zip(0..length as i32) + .for_each(|(x, index)| { + if x { + indices[valids] = index; + valids += 1; + } else { + indices[last_valid_index + nulls] = index; + nulls += 1; + } + }); + + // Soundness: + // all indices in `indices` are by construction `< array.len() == values.len()` + // limit is by construction <= values.len() + let limit = limit.min(last_valid_index); + let indices = &mut indices.as_mut_slice()[..last_valid_index]; + sort_unstable_by(indices, get, cmp, options.descending, limit); + } + + indices.truncate(limit); + indices.shrink_to_fit(); + + indices + } else { + let mut indices = + unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) }; + + // Soundness: + // indices are by construction `< values.len()` + // limit is by construction `< values.len()` + sort_unstable_by(&mut indices, get, cmp, descending, limit); + + indices.truncate(limit); + indices.shrink_to_fit(); + + indices + }; + PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) +} diff --git a/src/compute/sort/lex_sort.rs b/src/compute/sort/lex_sort.rs index 44f0616931f..7ca8b4c6c72 100644 --- a/src/compute/sort/lex_sort.rs +++ b/src/compute/sort/lex_sort.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use std::cmp::Ordering; use crate::compute::take; @@ -66,14 +49,14 @@ pub struct SortColumn<'a> { /// nulls_first: false, /// }), /// }, -/// ]).unwrap(); +/// ], None).unwrap(); /// /// let sorted = sorted_columns[0].as_any().downcast_ref::().unwrap(); /// assert_eq!(sorted.value(1), -64); /// assert!(sorted.is_null(0)); /// ``` -pub fn lexsort(columns: &[SortColumn]) -> Result>> { - let indices = lexsort_to_indices(columns)?; +pub fn lexsort(columns: &[SortColumn], limit: Option) -> Result>> { + let indices = lexsort_to_indices(columns, limit)?; columns .iter() .map(|c| take::take(c.values, &indices)) @@ -135,9 +118,12 @@ pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Resu }) } -/// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer -/// [`Int32Array`] of indices. -pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result> { +/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray`] +/// representing the indices that would sort the columns. +pub fn lexsort_to_indices( + columns: &[SortColumn], + limit: Option, +) -> Result> { if columns.is_empty() { return Err(ArrowError::InvalidArgumentError( "Sort requires at least one column".to_string(), @@ -146,7 +132,7 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result> if columns.len() == 1 { // fallback to non-lexical sort let column = &columns[0]; - return sort_to_indices(column.values, &column.options.unwrap_or_default()); + return sort_to_indices(column.values, &column.options.unwrap_or_default(), limit); } let row_count = columns[0].values.len(); @@ -180,7 +166,15 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result> // Safety: `0..row_count` is TrustedLen let mut values = unsafe { MutableBuffer::::from_trusted_len_iter_unchecked(0..row_count as i32) }; - values.sort_unstable_by(lex_comparator); + + if let Some(limit) = limit { + let limit = limit.min(row_count); + let (before, _, _) = values.select_nth_unstable_by(limit, lex_comparator); + before.sort_unstable_by(lex_comparator); + values.truncate(limit); + } else { + values.sort_unstable_by(lex_comparator); + } Ok(PrimitiveArray::::from_data( DataType::Int32, @@ -196,7 +190,14 @@ mod tests { use super::*; fn test_lex_sort_arrays(input: Vec, expected: Vec>) { - let sorted = lexsort(&input).unwrap(); + let sorted = lexsort(&input, None).unwrap(); + assert_eq!(sorted, expected); + + let sorted = lexsort(&input, Some(2)).unwrap(); + let expected = expected + .into_iter() + .map(|x| x.slice(0, 2)) + .collect::>(); assert_eq!(sorted, expected); } diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 4c28712a327..998ddc27392 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -1,23 +1,4 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines sort kernel for `ArrayRef` - -use std::cmp::{Ordering, Reverse}; +use std::cmp::Ordering; use crate::array::ord; use crate::compute::take; @@ -25,64 +6,69 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::{ array::*, - buffer::Buffer, types::{days_ms, NativeType}, }; use crate::buffer::MutableBuffer; use num::ToPrimitive; +mod boolean; +mod common; mod lex_sort; mod primitive; +mod utf8; pub(crate) use lex_sort::{build_compare, Compare}; pub use lex_sort::{lexsort, lexsort_to_indices, SortColumn}; macro_rules! dyn_sort { - ($ty:ty, $array:expr, $cmp:expr, $options:expr) => {{ + ($ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{ let array = $array .as_any() .downcast_ref::>() .unwrap(); Ok(Box::new(primitive::sort_by::<$ty, _>( - &array, $cmp, $options, + &array, $cmp, $options, $limit, ))) }}; } -/// Sort the `ArrayRef` using `SortOptions`. +/// Sort the [`Array`] using [`SortOptions`]. /// -/// Performs a stable sort on values and indices. Nulls are ordered according to the `nulls_first` flag in `options`. +/// Performs an unstable sort on values and indices. Nulls are ordered according to the `nulls_first` flag in `options`. /// Floats are sorted using IEEE 754 totalOrder -/// -/// Returns an `ArrowError::ComputeError(String)` if the array type is either unsupported by `sort_to_indices` or `take`. -/// -pub fn sort(values: &dyn Array, options: &SortOptions) -> Result> { +/// # Errors +/// Errors if the [`DataType`] is not supported. +pub fn sort( + values: &dyn Array, + options: &SortOptions, + limit: Option, +) -> Result> { match values.data_type() { - DataType::Int8 => dyn_sort!(i8, values, ord::total_cmp, options), - DataType::Int16 => dyn_sort!(i16, values, ord::total_cmp, options), + DataType::Int8 => dyn_sort!(i8, values, ord::total_cmp, options, limit), + DataType::Int16 => dyn_sort!(i16, values, ord::total_cmp, options, limit), DataType::Int32 | DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_sort!(i32, values, ord::total_cmp, options) + dyn_sort!(i32, values, ord::total_cmp, options, limit) } DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, None) - | DataType::Duration(_) => dyn_sort!(i64, values, ord::total_cmp, options), - DataType::UInt8 => dyn_sort!(u8, values, ord::total_cmp, options), - DataType::UInt16 => dyn_sort!(u16, values, ord::total_cmp, options), - DataType::UInt32 => dyn_sort!(u32, values, ord::total_cmp, options), - DataType::UInt64 => dyn_sort!(u64, values, ord::total_cmp, options), - DataType::Float32 => dyn_sort!(f32, values, ord::total_cmp_f32, options), - DataType::Float64 => dyn_sort!(f64, values, ord::total_cmp_f64, options), + | DataType::Duration(_) => dyn_sort!(i64, values, ord::total_cmp, options, limit), + DataType::UInt8 => dyn_sort!(u8, values, ord::total_cmp, options, limit), + DataType::UInt16 => dyn_sort!(u16, values, ord::total_cmp, options, limit), + DataType::UInt32 => dyn_sort!(u32, values, ord::total_cmp, options, limit), + DataType::UInt64 => dyn_sort!(u64, values, ord::total_cmp, options, limit), + DataType::Float32 => dyn_sort!(f32, values, ord::total_cmp_f32, options, limit), + DataType::Float64 => dyn_sort!(f64, values, ord::total_cmp_f64, options, limit), DataType::Interval(IntervalUnit::DayTime) => { - dyn_sort!(days_ms, values, ord::total_cmp, options) + dyn_sort!(days_ms, values, ord::total_cmp, options, limit) } _ => { - let indices = sort_to_indices(values, options)?; + let indices = sort_to_indices(values, options, limit)?; take::take(values, &indices) } } @@ -95,66 +81,78 @@ fn partition_validity(array: &dyn Array) -> (Vec, Vec) { } macro_rules! dyn_sort_indices { - ($ty:ty, $array:expr, $cmp:expr, $options:expr) => {{ + ($ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{ let array = $array .as_any() .downcast_ref::>() .unwrap(); - Ok(primitive::indices_sorted_by::<$ty, _>( - &array, $cmp, $options, + Ok(primitive::indices_sorted_unstable_by::<$ty, _>( + &array, $cmp, $options, $limit, )) }}; } -/// Sort elements from `ArrayRef` into an unsigned integer (`UInt32Array`) of indices. +/// Sort elements from `values` into [`Int32Array`] of indices. /// For floating point arrays any NaN values are considered to be greater than any other non-null value -pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result { +pub fn sort_to_indices( + values: &dyn Array, + options: &SortOptions, + limit: Option, +) -> Result { match values.data_type() { DataType::Boolean => { let (v, n) = partition_validity(values); - Ok(sort_boolean(values, v, n, options)) + Ok(boolean::sort_boolean( + values.as_any().downcast_ref().unwrap(), + v, + n, + options, + limit, + )) } - DataType::Int8 => dyn_sort_indices!(i8, values, ord::total_cmp, options), - DataType::Int16 => dyn_sort_indices!(i16, values, ord::total_cmp, options), + DataType::Int8 => dyn_sort_indices!(i8, values, ord::total_cmp, options, limit), + DataType::Int16 => dyn_sort_indices!(i16, values, ord::total_cmp, options, limit), DataType::Int32 | DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_sort_indices!(i32, values, ord::total_cmp, options) + dyn_sort_indices!(i32, values, ord::total_cmp, options, limit) } DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, None) - | DataType::Duration(_) => dyn_sort_indices!(i64, values, ord::total_cmp, options), - DataType::UInt8 => dyn_sort_indices!(u8, values, ord::total_cmp, options), - DataType::UInt16 => dyn_sort_indices!(u16, values, ord::total_cmp, options), - DataType::UInt32 => dyn_sort_indices!(u32, values, ord::total_cmp, options), - DataType::UInt64 => dyn_sort_indices!(u64, values, ord::total_cmp, options), - DataType::Float32 => dyn_sort_indices!(f32, values, ord::total_cmp_f32, options), - DataType::Float64 => dyn_sort_indices!(f64, values, ord::total_cmp_f64, options), + | DataType::Duration(_) => dyn_sort_indices!(i64, values, ord::total_cmp, options, limit), + DataType::UInt8 => dyn_sort_indices!(u8, values, ord::total_cmp, options, limit), + DataType::UInt16 => dyn_sort_indices!(u16, values, ord::total_cmp, options, limit), + DataType::UInt32 => dyn_sort_indices!(u32, values, ord::total_cmp, options, limit), + DataType::UInt64 => dyn_sort_indices!(u64, values, ord::total_cmp, options, limit), + DataType::Float32 => dyn_sort_indices!(f32, values, ord::total_cmp_f32, options, limit), + DataType::Float64 => dyn_sort_indices!(f64, values, ord::total_cmp_f64, options, limit), DataType::Interval(IntervalUnit::DayTime) => { - dyn_sort_indices!(days_ms, values, ord::total_cmp, options) - } - DataType::Utf8 => { - let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options)) - } - DataType::LargeUtf8 => { - let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options)) + dyn_sort_indices!(days_ms, values, ord::total_cmp, options, limit) } + DataType::Utf8 => Ok(utf8::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::LargeUtf8 => Ok(utf8::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), DataType::List(field) => { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options)), - DataType::Int16 => Ok(sort_list::(values, v, n, options)), - DataType::Int32 => Ok(sort_list::(values, v, n, options)), - DataType::Int64 => Ok(sort_list::(values, v, n, options)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -164,14 +162,14 @@ pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options)), - DataType::Int16 => Ok(sort_list::(values, v, n, options)), - DataType::Int32 => Ok(sort_list::(values, v, n, options)), - DataType::Int64 => Ok(sort_list::(values, v, n, options)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -181,37 +179,28 @@ pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options)), - DataType::Int16 => Ok(sort_list::(values, v, n, options)), - DataType::Int32 => Ok(sort_list::(values, v, n, options)), - DataType::Int64 => Ok(sort_list::(values, v, n, options)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t ))), } } - DataType::Dictionary(key_type, value_type) if *value_type.as_ref() == DataType::Utf8 => { - let (v, n) = partition_validity(values); - match key_type.as_ref() { - DataType::Int8 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::Int16 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::Int32 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::Int64 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt8 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt16 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt32 => Ok(sort_string_dictionary::(values, v, n, options)), - DataType::UInt64 => Ok(sort_string_dictionary::(values, v, n, options)), - t => Err(ArrowError::NotYetImplemented(format!( - "Sort not supported for dictionary key type {:?}", - t - ))), - } - } + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 => sort_dict::(values, key_type.as_ref(), options, limit), + DataType::LargeUtf8 => sort_dict::(values, key_type.as_ref(), options, limit), + t => Err(ArrowError::NotYetImplemented(format!( + "Sort not supported for dictionary type with keys {:?}", + t + ))), + }, t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for data type {:?}", t @@ -219,6 +208,60 @@ pub fn sort_to_indices(values: &dyn Array, options: &SortOptions) -> Result( + values: &dyn Array, + key_type: &DataType, + options: &SortOptions, + limit: Option, +) -> Result { + match key_type { + DataType::Int8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + t => Err(ArrowError::NotYetImplemented(format!( + "Sort not supported for dictionary key type {:?}", + t + ))), + } +} + /// Checks if an array of type `datatype` can be sorted /// /// # Examples @@ -303,138 +346,12 @@ impl Default for SortOptions { } } -/// Sort primitive values -fn sort_boolean( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, -) -> Int32Array { - let values = values - .as_any() - .downcast_ref::() - .expect("Unable to downcast to boolean array"); - let descending = options.descending; - - // create tuples that are used for sorting - let mut valids = value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .collect::>(); - - let mut nulls = null_indices; - - if !descending { - valids.sort_by(|a, b| a.1.cmp(&b.1)); - } else { - valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } - - let mut values = MutableBuffer::::with_capacity(values.len()); - - if options.nulls_first { - values.extend_from_slice(nulls.as_slice()); - valids.iter().for_each(|x| values.push(x.0)); - } else { - // nulls last - valids.iter().for_each(|x| values.push(x.0)); - values.extend_from_slice(nulls.as_slice()); - } - - Int32Array::from_data(DataType::Int32, values.into(), None) -} - -/// Sort strings -fn sort_utf8( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, -) -> Int32Array { - let values = values.as_any().downcast_ref::>().unwrap(); - - sort_string_helper( - values, - value_indices, - null_indices, - options, - |array, idx| array.value(idx as usize), - ) -} - -/// Sort dictionary encoded strings -fn sort_string_dictionary( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, -) -> Int32Array { - let values: &DictionaryArray = values - .as_any() - .downcast_ref::>() - .unwrap(); - - let keys = values.keys(); - - let dict = values.values(); - let dict = dict.as_any().downcast_ref::>().unwrap(); - - sort_string_helper( - keys, - value_indices, - null_indices, - options, - |array: &PrimitiveArray, idx| -> &str { - let key: T = array.value(idx as usize); - dict.value(key.to_usize().unwrap()) - }, - ) -} - -/// shared implementation between dictionary encoded and plain string arrays -#[inline] -fn sort_string_helper<'a, A: Array, F>( - values: &'a A, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - value_fn: F, -) -> Int32Array -where - F: Fn(&'a A, i32) -> &str, -{ - let mut valids = value_indices - .into_iter() - .map(|index| (index, value_fn(values, index))) - .collect::>(); - let mut nulls = null_indices; - if !options.descending { - valids.sort_by_key(|a| a.1); - } else { - valids.sort_by_key(|a| Reverse(a.1)); - nulls.reverse(); - } - - let valids = valids.iter().map(|tuple| tuple.0); - - let values = if options.nulls_first { - let values = nulls.into_iter().chain(valids); - Buffer::::from_trusted_len_iter(values) - } else { - let values = valids.chain(nulls.into_iter()); - Buffer::::from_trusted_len_iter(values) - }; - - PrimitiveArray::::from_data(DataType::Int32, values, None) -} - fn sort_list( values: &dyn Array, value_indices: Vec, null_indices: Vec, options: &SortOptions, + limit: Option, ) -> Int32Array where O: Offset, @@ -469,17 +386,19 @@ where let values = valids.iter().map(|tuple| tuple.0); - let values = if options.nulls_first { + let mut values = if options.nulls_first { let mut buffer = MutableBuffer::::from_trusted_len_iter(null_indices.into_iter()); - values.for_each(|x| buffer.push(x)); - buffer.into() + buffer.extend(values); + buffer } else { let mut buffer = MutableBuffer::::from_trusted_len_iter(values); - null_indices.iter().for_each(|x| buffer.push(*x)); - buffer.into() + buffer.extend(null_indices); + buffer }; - PrimitiveArray::::from_data(DataType::Int32, values, None) + values.truncate(limit.unwrap_or_else(|| values.len())); + + PrimitiveArray::::from_data(DataType::Int32, values.into(), None) } /// Compare two `Array`s based on the ordering defined in [ord](crate::array::ord). @@ -507,7 +426,7 @@ mod tests { ) { let output = BooleanArray::from(data); let expected = Int32Array::from_slice(expected_data); - let output = sort_to_indices(&output, &options).unwrap(); + let output = sort_to_indices(&output, &options, None).unwrap(); assert_eq!(output, expected) } @@ -521,7 +440,7 @@ mod tests { { let input = PrimitiveArray::::from(data).to(data_type.clone()); let expected = PrimitiveArray::::from(expected_data).to(data_type); - let output = sort(&input, &options).unwrap(); + let output = sort(&input, &options, None).unwrap(); assert_eq!(expected, output.as_ref()) } @@ -532,7 +451,7 @@ mod tests { ) { let input = Utf8Array::::from(&data.to_vec()); let expected = Int32Array::from_slice(expected_data); - let output = sort_to_indices(&input, &options).unwrap(); + let output = sort_to_indices(&input, &options, None).unwrap(); assert_eq!(output, expected) } @@ -543,7 +462,7 @@ mod tests { ) { let input = Utf8Array::::from(&data.to_vec()); let expected = Utf8Array::::from(&expected_data.to_vec()); - let output = sort(&input, &options).unwrap(); + let output = sort(&input, &options, None).unwrap(); assert_eq!(expected, output.as_ref()) } @@ -560,7 +479,7 @@ mod tests { expected.try_extend(expected_data.iter().copied()).unwrap(); let expected = expected.into_arc(); - let output = sort(input.as_ref(), &options).unwrap(); + let output = sort(input.as_ref(), &options, None).unwrap(); assert_eq!(expected.as_ref(), output.as_ref()) } @@ -704,6 +623,7 @@ mod tests { descending: false, nulls_first: true, }, + // &[3, 0, 5, 1, 4, 2] is also valid &[0, 3, 5, 1, 4, 2], ); @@ -720,7 +640,8 @@ mod tests { descending: true, nulls_first: false, }, - &[2, 4, 1, 5, 3, 0], + // &[2, 4, 1, 5, 3, 0] is also valid + &[2, 4, 1, 5, 0, 3], ); test_sort_to_indices_string_arrays( @@ -736,6 +657,7 @@ mod tests { descending: false, nulls_first: true, }, + // &[3, 0, 5, 1, 4, 2] is also valid &[0, 3, 5, 1, 4, 2], ); @@ -752,7 +674,8 @@ mod tests { descending: true, nulls_first: true, }, - &[3, 0, 2, 4, 1, 5], + // &[3, 0, 2, 4, 1, 5] is also valid + &[0, 3, 2, 4, 1, 5], ); } @@ -1100,9 +1023,9 @@ mod tests { nulls_first: true, }; if can_sort(&d1) { - assert!(sort(array.as_ref(), &options).is_ok()); + assert!(sort(array.as_ref(), &options, None).is_ok()); } else { - assert!(sort(array.as_ref(), &options).is_err()); + assert!(sort(array.as_ref(), &options, None).is_err()); } }); } diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index 41dea70c579..b5673b19dfe 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -1,132 +1,31 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use crate::{ array::{Array, PrimitiveArray}, - buffer::MutableBuffer, - datatypes::DataType, types::NativeType, }; +use super::super::common; use super::super::SortOptions; -/// # Safety -/// `indices[i] < values.len()` for all i -#[inline] -unsafe fn sort_inner(indices: &mut [i32], values: &[T], mut cmp: F, descending: bool) -where - T: NativeType, - F: FnMut(&T, &T) -> std::cmp::Ordering, -{ - if descending { - indices.sort_by(|lhs, rhs| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs).reverse() - }) - } else { - indices.sort_by(|lhs, rhs| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs) - }) - } -} - -pub fn indices_sorted_by( +/// Unstable sort of indices. +pub fn indices_sorted_unstable_by( array: &PrimitiveArray, cmp: F, options: &SortOptions, + limit: Option, ) -> PrimitiveArray where T: NativeType, F: Fn(&T, &T) -> std::cmp::Ordering, { - let descending = options.descending; - let values = array.values(); - let validity = array.validity(); - - if let Some(validity) = validity { - let mut indices = MutableBuffer::::from_len_zeroed(array.len()); - - if options.nulls_first { - let mut nulls = 0; - let mut valids = 0; - validity - .iter() - .zip(0..array.len() as i32) - .for_each(|(x, index)| { - if x { - indices[validity.null_count() + valids] = index; - valids += 1; - } else { - indices[nulls] = index; - nulls += 1; - } - }); - // Soundness: - // all indices in `indices` are by construction `< array.len() == values.len()` - unsafe { - sort_inner( - &mut indices.as_mut_slice()[validity.null_count()..], - values, - cmp, - options.descending, - ) - } - } else { - let last_valid_index = array.len() - validity.null_count(); - let mut nulls = 0; - let mut valids = 0; - validity - .iter() - .zip(0..array.len() as i32) - .for_each(|(x, index)| { - if x { - indices[valids] = index; - valids += 1; - } else { - indices[last_valid_index + nulls] = index; - nulls += 1; - } - }); - - // Soundness: - // all indices in `indices` are by construction `< array.len() == values.len()` - unsafe { - sort_inner( - &mut indices.as_mut_slice()[..last_valid_index], - values, - cmp, - options.descending, - ) - }; - } - - PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) - } else { - let mut indices = - unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..values.len() as i32) }; - - // Soundness: - // indices are by construction `< values.len()` - unsafe { sort_inner(&mut indices, values, cmp, descending) }; - - PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) + unsafe { + common::indices_sorted_unstable_by( + array.validity(), + |x: usize| *array.values().as_slice().get_unchecked(x), + cmp, + array.len(), + options, + limit, + ) } } @@ -135,14 +34,20 @@ mod tests { use super::*; use crate::array::ord; use crate::array::*; - - fn test(data: &[Option], data_type: DataType, options: SortOptions, expected_data: &[i32]) - where + use crate::datatypes::DataType; + + fn test( + data: &[Option], + data_type: DataType, + options: SortOptions, + limit: Option, + expected_data: &[i32], + ) where T: NativeType + std::cmp::Ord, { let input = PrimitiveArray::::from(data).to(data_type); let expected = Int32Array::from_slice(&expected_data); - let output = indices_sorted_by(&input, ord::total_cmp, &options); + let output = indices_sorted_unstable_by(&input, ord::total_cmp, &options, limit); assert_eq!(output, expected) } @@ -155,6 +60,7 @@ mod tests { descending: false, nulls_first: true, }, + None, &[0, 5, 3, 1, 4, 2], ); } @@ -168,6 +74,7 @@ mod tests { descending: false, nulls_first: false, }, + None, &[3, 1, 4, 2, 0, 5], ); } @@ -181,6 +88,7 @@ mod tests { descending: true, nulls_first: true, }, + None, &[0, 5, 2, 1, 4, 3], ); } @@ -194,7 +102,116 @@ mod tests { descending: true, nulls_first: false, }, + None, &[2, 1, 4, 3, 0, 5], ); } + + #[test] + fn limit_ascending_nulls_first() { + // nulls sorted + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: true, + }, + Some(2), + &[0, 5], + ); + + // nulls and values sorted + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: true, + }, + Some(4), + &[0, 5, 3, 1], + ); + } + + #[test] + fn limit_ascending_nulls_last() { + // values + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: false, + }, + Some(2), + &[3, 1], + ); + + // values and nulls + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: false, + nulls_first: false, + }, + Some(5), + &[3, 1, 4, 2, 0], + ); + } + + #[test] + fn limit_descending_nulls_first() { + // nulls + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: true, + }, + Some(2), + &[0, 5], + ); + + // nulls and values + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: true, + }, + Some(4), + &[0, 5, 2, 1], + ); + } + + #[test] + fn limit_descending_nulls_last() { + // values + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: false, + }, + Some(2), + &[2, 1], + ); + + // values and nulls + test::( + &[None, Some(3), Some(5), Some(2), Some(3), None], + DataType::Int8, + SortOptions { + descending: true, + nulls_first: false, + }, + Some(5), + &[2, 1, 4, 3, 0], + ); + } } diff --git a/src/compute/sort/primitive/mod.rs b/src/compute/sort/primitive/mod.rs index 6be3f81c96c..ccfeb15b868 100644 --- a/src/compute/sort/primitive/mod.rs +++ b/src/compute/sort/primitive/mod.rs @@ -1,22 +1,5 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - mod indices; mod sort; -pub use indices::indices_sorted_by; +pub use indices::indices_sorted_unstable_by; pub use sort::sort_by; diff --git a/src/compute/sort/primitive/sort.rs b/src/compute/sort/primitive/sort.rs index 34dbf373136..197ea299416 100644 --- a/src/compute/sort/primitive/sort.rs +++ b/src/compute/sort/primitive/sort.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::MutableBuffer; +use crate::bitmap::Bitmap; +use crate::buffer::{Buffer, MutableBuffer}; use crate::{ array::{Array, PrimitiveArray}, bitmap::{utils::SlicesIterator, MutableBitmap}, @@ -24,11 +25,32 @@ use crate::{ use super::super::SortOptions; -fn sort_inner(values: &mut [T], mut cmp: F, descending: bool) +/// # Safety +/// `indices[i] < values.len()` for all i +#[inline] +fn k_element_sort_inner(values: &mut [T], descending: bool, limit: usize, mut cmp: F) where T: NativeType, F: FnMut(&T, &T) -> std::cmp::Ordering, { + if descending { + let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y).reverse()); + before.sort_unstable_by(|x, y| cmp(x, y)); + } else { + let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y)); + before.sort_unstable_by(|x, y| cmp(x, y)); + } +} + +fn sort_values(values: &mut [T], mut cmp: F, descending: bool, limit: usize) +where + T: NativeType, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if limit != values.len() { + return k_element_sort_inner(values, descending, limit, cmp); + } + if descending { values.sort_unstable_by(|x, y| cmp(x, y).reverse()); } else { @@ -36,54 +58,106 @@ where }; } +fn sort_nullable( + values: &[T], + validity: &Bitmap, + cmp: F, + options: &SortOptions, + limit: usize, +) -> (Buffer, Option) +where + T: NativeType, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + assert!(limit <= values.len()); + if options.nulls_first && limit < validity.null_count() { + let mut buffer = MutableBuffer::::with_capacity(limit); + buffer.extend_constant(limit, T::default()); + let bitmap = MutableBitmap::from_trusted_len_iter(std::iter::repeat(false).take(limit)); + return (buffer.into(), bitmap.into()); + } + + let nulls = std::iter::repeat(false).take(validity.null_count()); + let valids = std::iter::repeat(true).take(values.len() - validity.null_count()); + + let mut buffer = MutableBuffer::::with_capacity(values.len()); + let mut new_validity = MutableBitmap::with_capacity(values.len()); + let slices = SlicesIterator::new(validity); + + if options.nulls_first { + // validity is [0,0,0,...,1,1,1,1] + new_validity.extend_from_trusted_len_iter(nulls.chain(valids).take(limit)); + + // extend buffer with constants followed by non-null values + buffer.extend_constant(validity.null_count(), T::default()); + for (start, len) in slices { + buffer.extend_from_slice(&values[start..start + len]) + } + + // sort values + sort_values( + &mut buffer.as_mut_slice()[validity.null_count()..], + cmp, + options.descending, + limit - validity.null_count(), + ); + } else { + // validity is [1,1,1,...,0,0,0,0] + new_validity.extend_from_trusted_len_iter(valids.chain(nulls).take(limit)); + + // extend buffer with non-null values + for (start, len) in slices { + buffer.extend_from_slice(&values[start..start + len]) + } + + // sort all non-null values + sort_values( + &mut buffer.as_mut_slice(), + cmp, + options.descending, + limit - validity.null_count(), + ); + + if limit > values.len() - validity.null_count() { + // extend remaining with nulls + buffer.extend_constant(validity.null_count(), T::default()); + } + }; + // values are sorted, we can now truncate the remaining. + buffer.truncate(limit); + + (buffer.into(), new_validity.into()) +} + /// Sorts a [`PrimitiveArray`] according to `cmp` comparator and [`SortOptions`]. -pub fn sort_by(array: &PrimitiveArray, cmp: F, options: &SortOptions) -> PrimitiveArray +pub fn sort_by( + array: &PrimitiveArray, + cmp: F, + options: &SortOptions, + limit: Option, +) -> PrimitiveArray where T: NativeType, F: FnMut(&T, &T) -> std::cmp::Ordering, { + let limit = limit.unwrap_or_else(|| array.len()); + let limit = limit.min(array.len()); + let values = array.values(); let validity = array.validity(); let (buffer, validity) = if let Some(validity) = validity { - let nulls = std::iter::repeat(false).take(validity.null_count()); - let valids = std::iter::repeat(true).take(array.len() - validity.null_count()); - - let mut buffer = MutableBuffer::::with_capacity(array.len()); - let mut new_validity = MutableBitmap::with_capacity(array.len()); - let slices = SlicesIterator::new(validity); - - if options.nulls_first { - new_validity.extend_from_trusted_len_iter(nulls.chain(valids)); - (0..validity.null_count()).for_each(|_| buffer.push(T::default())); - for (start, len) in slices { - buffer.extend_from_slice(&values[start..start + len]) - } - sort_inner( - &mut buffer.as_mut_slice()[validity.null_count()..], - cmp, - options.descending, - ) - } else { - new_validity.extend_from_trusted_len_iter(valids.chain(nulls)); - for (start, len) in slices { - buffer.extend_from_slice(&values[start..start + len]) - } - sort_inner(&mut buffer.as_mut_slice(), cmp, options.descending); - - (0..validity.null_count()).for_each(|_| buffer.push(T::default())); - }; - - (buffer, new_validity.into()) + sort_nullable(values, validity, cmp, options, limit) } else { let mut buffer = MutableBuffer::::new(); buffer.extend_from_slice(values); - sort_inner(&mut buffer.as_mut_slice(), cmp, options.descending); + sort_values(&mut buffer.as_mut_slice(), cmp, options.descending, limit); + buffer.truncate(limit); - (buffer, None) + (buffer.into(), None) }; - PrimitiveArray::::from_data(array.data_type().clone(), buffer.into(), validity) + PrimitiveArray::::from_data(array.data_type().clone(), buffer, validity) } #[cfg(test)] @@ -103,8 +177,13 @@ mod tests { T: NativeType + std::cmp::Ord, { let input = PrimitiveArray::::from(data).to(data_type.clone()); - let expected = PrimitiveArray::::from(expected_data).to(data_type); - let output = sort_by(&input, ord::total_cmp, &options); + let expected = PrimitiveArray::::from(expected_data).to(data_type.clone()); + let output = sort_by(&input, ord::total_cmp, &options, None); + assert_eq!(expected, output); + + // with limit + let expected = PrimitiveArray::::from(&expected_data[..3]).to(data_type); + let output = sort_by(&input, ord::total_cmp, &options, Some(3)); assert_eq!(expected, output) } diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs new file mode 100644 index 00000000000..8e83517ec53 --- /dev/null +++ b/src/compute/sort/utf8.rs @@ -0,0 +1,37 @@ +use crate::array::{Array, Int32Array, Offset, Utf8Array}; +use crate::array::{DictionaryArray, DictionaryKey}; + +use super::common; +use super::SortOptions; + +pub(super) fn indices_sorted_unstable_by( + array: &Utf8Array, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let get = |idx| unsafe { array.value_unchecked(idx as usize) }; + let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) +} + +pub(super) fn indices_sorted_unstable_by_dictionary( + array: &DictionaryArray, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let keys = array.keys(); + + let dict = array + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + + let get = |idx| unsafe { + let index = keys.value_unchecked(idx as usize); + // Note: there is no check that the keys are within bounds of the dictionary. + dict.value(index.to_usize().unwrap()) + }; + let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) +}