diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index efd27f06e76..da62d2af387 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -1,8 +1,13 @@ +use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType}; + +use super::SortOptions; + /// # Safety -/// `indices[i] < values.len()` for all i -/// `limit < values.len()` +/// This function guarantees that: +/// * `get` is only called for `0 <= i < limit` +/// * `cmp` is only called from the co-domain of `get`. #[inline] -unsafe fn k_element_sort_inner( +fn k_element_sort_inner( indices: &mut [i32], get: G, descending: bool, @@ -42,11 +47,11 @@ unsafe fn k_element_sort_inner( } /// # Safety -/// Safe iff -/// * `indices[i] < values.len()` for all i -/// * `limit < values.len()` +/// This function guarantees that: +/// * `get` is only called for `0 <= i < limit` +/// * `cmp` is only called from the co-domain of `get`. #[inline] -pub(super) unsafe fn sort_unstable_by( +fn sort_unstable_by( indices: &mut [i32], get: G, mut cmp: F, @@ -74,3 +79,101 @@ pub(super) unsafe fn sort_unstable_by( }) } } + +/// # Safety +/// This function guarantees that: +/// * `get` is only called for `0 <= i < length` +/// * `cmp` is only called from the co-domain of `get`. +#[inline] +pub(super) fn indices_sorted_unstable_by( + validity: &Option, + get: G, + cmp: F, + length: usize, + options: &SortOptions, + limit: Option, +) -> PrimitiveArray +where + G: Fn(usize) -> T, + F: Fn(&T, &T) -> std::cmp::Ordering, +{ + let descending = options.descending; + + let limit = limit.unwrap_or(length); + // Safety: without this, we go out of bounds when limit >= length. + let limit = limit.min(length); + + let indices = if let Some(validity) = validity { + let mut indices = MutableBuffer::::from_len_zeroed(length); + + if options.nulls_first { + let mut nulls = 0; + let mut valids = 0; + validity + .iter() + .zip(0..length as i32) + .for_each(|(is_valid, index)| { + if is_valid { + indices[validity.null_count() + valids] = index; + valids += 1; + } else { + indices[nulls] = index; + nulls += 1; + } + }); + + if limit > validity.null_count() { + // when limit is larger, we must sort values: + + // Soundness: + // all indices in `indices` are by construction `< array.len() == values.len()` + // limit is by construction < indices.len() + let limit = limit - validity.null_count(); + let indices = &mut indices.as_mut_slice()[validity.null_count()..]; + sort_unstable_by(indices, get, cmp, options.descending, limit) + } + } else { + let last_valid_index = length - validity.null_count(); + let mut nulls = 0; + let mut valids = 0; + validity + .iter() + .zip(0..length as i32) + .for_each(|(x, index)| { + if x { + indices[valids] = index; + valids += 1; + } else { + indices[last_valid_index + nulls] = index; + nulls += 1; + } + }); + + // Soundness: + // all indices in `indices` are by construction `< array.len() == values.len()` + // limit is by construction <= values.len() + let limit = limit.min(last_valid_index); + let indices = &mut indices.as_mut_slice()[..last_valid_index]; + sort_unstable_by(indices, get, cmp, options.descending, limit); + } + + indices.truncate(limit); + indices.shrink_to_fit(); + + indices + } else { + let mut indices = + unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) }; + + // Soundness: + // indices are by construction `< values.len()` + // limit is by construction `< values.len()` + sort_unstable_by(&mut indices, get, cmp, descending, limit); + + indices.truncate(limit); + indices.shrink_to_fit(); + + indices + }; + PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) +} diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 53a814ab7d0..998ddc27392 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -1,4 +1,4 @@ -use std::cmp::{Ordering, Reverse}; +use std::cmp::Ordering; use crate::array::ord; use crate::compute::take; @@ -6,7 +6,6 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::{ array::*, - buffer::Buffer, types::{days_ms, NativeType}, }; @@ -17,6 +16,7 @@ mod boolean; mod common; mod lex_sort; mod primitive; +mod utf8; pub(crate) use lex_sort::{build_compare, Compare}; pub use lex_sort::{lexsort, lexsort_to_indices, SortColumn}; @@ -132,14 +132,16 @@ pub fn sort_to_indices( DataType::Interval(IntervalUnit::DayTime) => { dyn_sort_indices!(days_ms, values, ord::total_cmp, options, limit) } - DataType::Utf8 => { - let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options, limit)) - } - DataType::LargeUtf8 => { - let (v, n) = partition_validity(values); - Ok(sort_utf8::(values, v, n, options, limit)) - } + DataType::Utf8 => Ok(utf8::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::LargeUtf8 => Ok(utf8::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), DataType::List(field) => { let (v, n) = partition_validity(values); match field.data_type() { @@ -191,23 +193,14 @@ pub fn sort_to_indices( ))), } } - DataType::Dictionary(key_type, value_type) if *value_type.as_ref() == DataType::Utf8 => { - let (v, n) = partition_validity(values); - match key_type.as_ref() { - DataType::Int8 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::Int16 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::Int32 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::Int64 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt8 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt16 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt32 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - DataType::UInt64 => Ok(sort_string_dictionary::(values, v, n, options, limit)), - t => Err(ArrowError::NotYetImplemented(format!( - "Sort not supported for dictionary key type {:?}", - t - ))), - } - } + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 => sort_dict::(values, key_type.as_ref(), options, limit), + DataType::LargeUtf8 => sort_dict::(values, key_type.as_ref(), options, limit), + t => Err(ArrowError::NotYetImplemented(format!( + "Sort not supported for dictionary type with keys {:?}", + t + ))), + }, t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for data type {:?}", t @@ -215,6 +208,60 @@ pub fn sort_to_indices( } } +fn sort_dict( + values: &dyn Array, + key_type: &DataType, + options: &SortOptions, + limit: Option, +) -> Result { + match key_type { + DataType::Int8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::Int64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::UInt64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + t => Err(ArrowError::NotYetImplemented(format!( + "Sort not supported for dictionary key type {:?}", + t + ))), + } +} + /// Checks if an array of type `datatype` can be sorted /// /// # Examples @@ -299,100 +346,6 @@ impl Default for SortOptions { } } -/// Sort strings -fn sort_utf8( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> Int32Array { - let values = values.as_any().downcast_ref::>().unwrap(); - - sort_string_helper( - values, - value_indices, - null_indices, - options, - limit, - |array, idx| array.value(idx as usize), - ) -} - -/// Sort dictionary encoded strings -fn sort_string_dictionary( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> Int32Array { - let values: &DictionaryArray = values - .as_any() - .downcast_ref::>() - .unwrap(); - - let keys = values.keys(); - - let dict = values.values(); - let dict = dict.as_any().downcast_ref::>().unwrap(); - - sort_string_helper( - keys, - value_indices, - null_indices, - options, - limit, - |array: &PrimitiveArray, idx| -> &str { - let key: T = array.value(idx as usize); - dict.value(key.to_usize().unwrap()) - }, - ) -} - -/// shared implementation between dictionary encoded and plain string arrays -#[inline] -fn sort_string_helper<'a, A: Array, F>( - values: &'a A, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, - value_fn: F, -) -> Int32Array -where - F: Fn(&'a A, i32) -> &str, -{ - let mut valids = value_indices - .into_iter() - .map(|index| (index, value_fn(values, index))) - .collect::>(); - let mut nulls = null_indices; - if !options.descending { - valids.sort_by_key(|a| a.1); - } else { - valids.sort_by_key(|a| Reverse(a.1)); - nulls.reverse(); - } - - let valids = valids.iter().map(|tuple| tuple.0); - - let values = if options.nulls_first { - let values = nulls - .into_iter() - .chain(valids) - .take(limit.unwrap_or_else(|| values.len())); - Buffer::::from_trusted_len_iter(values) - } else { - let values = valids - .chain(nulls.into_iter()) - .take(limit.unwrap_or_else(|| values.len())); - Buffer::::from_trusted_len_iter(values) - }; - - PrimitiveArray::::from_data(DataType::Int32, values, None) -} - fn sort_list( values: &dyn Array, value_indices: Vec, @@ -670,6 +623,7 @@ mod tests { descending: false, nulls_first: true, }, + // &[3, 0, 5, 1, 4, 2] is also valid &[0, 3, 5, 1, 4, 2], ); @@ -686,7 +640,8 @@ mod tests { descending: true, nulls_first: false, }, - &[2, 4, 1, 5, 3, 0], + // &[2, 4, 1, 5, 3, 0] is also valid + &[2, 4, 1, 5, 0, 3], ); test_sort_to_indices_string_arrays( @@ -702,6 +657,7 @@ mod tests { descending: false, nulls_first: true, }, + // &[3, 0, 5, 1, 4, 2] is also valid &[0, 3, 5, 1, 4, 2], ); @@ -718,7 +674,8 @@ mod tests { descending: true, nulls_first: true, }, - &[3, 0, 2, 4, 1, 5], + // &[3, 0, 2, 4, 1, 5] is also valid + &[0, 3, 2, 4, 1, 5], ); } diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index 474ddd430b8..b5673b19dfe 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -1,11 +1,9 @@ use crate::{ array::{Array, PrimitiveArray}, - buffer::MutableBuffer, - datatypes::DataType, types::NativeType, }; -use super::super::common::sort_unstable_by; +use super::super::common; use super::super::SortOptions; /// Unstable sort of indices. @@ -19,111 +17,16 @@ where T: NativeType, F: Fn(&T, &T) -> std::cmp::Ordering, { - let descending = options.descending; - let values = array.values(); - let validity = array.validity(); - - let limit = limit.unwrap_or_else(|| array.len()); - // Safety: without this, we go out of bounds when limit >= array.len(). - let limit = limit.min(array.len()); - - let indices = if let Some(validity) = validity { - let mut indices = MutableBuffer::::from_len_zeroed(array.len()); - - if options.nulls_first { - let mut nulls = 0; - let mut valids = 0; - validity - .iter() - .zip(0..array.len() as i32) - .for_each(|(is_valid, index)| { - if is_valid { - indices[validity.null_count() + valids] = index; - valids += 1; - } else { - indices[nulls] = index; - nulls += 1; - } - }); - - if limit > validity.null_count() { - // when limit is larger, we must sort values: - - // Soundness: - // all indices in `indices` are by construction `< array.len() == values.len()` - // limit is by construction < indices.len() - let limit = limit - validity.null_count(); - let indices = &mut indices.as_mut_slice()[validity.null_count()..]; - unsafe { - sort_unstable_by( - indices, - |x: usize| *values.as_slice().get_unchecked(x), - cmp, - options.descending, - limit, - ) - } - } - } else { - let last_valid_index = array.len() - validity.null_count(); - let mut nulls = 0; - let mut valids = 0; - validity - .iter() - .zip(0..array.len() as i32) - .for_each(|(x, index)| { - if x { - indices[valids] = index; - valids += 1; - } else { - indices[last_valid_index + nulls] = index; - nulls += 1; - } - }); - - // Soundness: - // all indices in `indices` are by construction `< array.len() == values.len()` - // limit is by construction <= values.len() - let limit = limit.min(last_valid_index); - let indices = &mut indices.as_mut_slice()[..last_valid_index]; - unsafe { - sort_unstable_by( - indices, - |x: usize| *values.as_slice().get_unchecked(x), - cmp, - options.descending, - limit, - ) - }; - } - - indices.truncate(limit); - indices.shrink_to_fit(); - - indices - } else { - let mut indices = - unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..values.len() as i32) }; - - // Soundness: - // indices are by construction `< values.len()` - // limit is by construction `< values.len()` - unsafe { - sort_unstable_by( - &mut indices, - |x: usize| *values.as_slice().get_unchecked(x), - cmp, - descending, - limit, - ) - }; - - indices.truncate(limit); - indices.shrink_to_fit(); - - indices - }; - PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) + unsafe { + common::indices_sorted_unstable_by( + array.validity(), + |x: usize| *array.values().as_slice().get_unchecked(x), + cmp, + array.len(), + options, + limit, + ) + } } #[cfg(test)] @@ -131,6 +34,7 @@ mod tests { use super::*; use crate::array::ord; use crate::array::*; + use crate::datatypes::DataType; fn test( data: &[Option], diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs new file mode 100644 index 00000000000..8e83517ec53 --- /dev/null +++ b/src/compute/sort/utf8.rs @@ -0,0 +1,37 @@ +use crate::array::{Array, Int32Array, Offset, Utf8Array}; +use crate::array::{DictionaryArray, DictionaryKey}; + +use super::common; +use super::SortOptions; + +pub(super) fn indices_sorted_unstable_by( + array: &Utf8Array, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let get = |idx| unsafe { array.value_unchecked(idx as usize) }; + let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) +} + +pub(super) fn indices_sorted_unstable_by_dictionary( + array: &DictionaryArray, + options: &SortOptions, + limit: Option, +) -> Int32Array { + let keys = array.keys(); + + let dict = array + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + + let get = |idx| unsafe { + let index = keys.value_unchecked(idx as usize); + // Note: there is no check that the keys are within bounds of the dictionary. + dict.value(index.to_usize().unwrap()) + }; + let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) +}