From 93bdde80d16e38dcb8707ed5dd80b6c30f783fcc Mon Sep 17 00:00:00 2001 From: baishen Date: Sat, 28 May 2022 11:18:54 +0800 Subject: [PATCH] Added support for custom sort `build_compare_fn` (#1016) --- src/compute/merge_sort/mod.rs | 11 ++++++++++- src/compute/sort/lex_sort.rs | 33 ++++++++++++++++++++++++++++++--- src/compute/sort/mod.rs | 2 +- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index c56a82d3e8d..9c48a29dd00 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -469,6 +469,15 @@ type IsValid<'a> = Box bool + 'a>; /// returns a comparison function between any two arrays of each pair of arrays, according to `SortOptions`. pub fn build_comparator<'a>( pairs: &'a [(&'a [&'a dyn Array], &SortOptions)], +) -> Result> { + build_comparator_impl(pairs, &build_compare) +} + +/// returns a comparison function between any two arrays of each pair of arrays, according to `SortOptions`. +/// Implementing custom `build_compare_fn` for unsupportd data types. +pub fn build_comparator_impl<'a>( + pairs: &'a [(&'a [&'a dyn Array], &SortOptions)], + build_compare_fn: &dyn Fn(&dyn Array, &dyn Array) -> Result, ) -> Result> { // prepare the comparison function of _values_ between all pairs of arrays let indices_pairs = (0..pairs[0].0.len()) @@ -483,7 +492,7 @@ pub fn build_comparator<'a>( Ok(( Box::new(move |row| arrays[lhs_index].is_valid(row)) as IsValid<'a>, Box::new(move |row| arrays[rhs_index].is_valid(row)) as IsValid<'a>, - build_compare(arrays[lhs_index], arrays[rhs_index])?, + build_compare_fn(arrays[lhs_index], arrays[rhs_index])?, )) }) .collect::>>()?; diff --git a/src/compute/sort/lex_sort.rs b/src/compute/sort/lex_sort.rs index 1ab3949f7ac..114093f49a8 100644 --- a/src/compute/sort/lex_sort.rs +++ b/src/compute/sort/lex_sort.rs @@ -79,8 +79,16 @@ fn build_is_valid(array: &dyn Array) -> IsValid { } pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Result { + build_compare_impl(array, sort_option, &ord::build_compare) +} + +pub(crate) fn build_compare_impl( + array: &dyn Array, + sort_option: SortOptions, + build_compare_fn: &dyn Fn(&dyn Array, &dyn Array) -> Result, +) -> Result { let is_valid = build_is_valid(array); - let comparator = ord::build_compare(array, array)?; + let comparator = build_compare_fn(array, array)?; Ok(match (sort_option.descending, sort_option.nulls_first) { (true, true) => Box::new(move |i: usize, j: usize| match (is_valid(i), is_valid(j)) { @@ -127,6 +135,17 @@ pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Resu pub fn lexsort_to_indices( columns: &[SortColumn], limit: Option, +) -> Result> { + lexsort_to_indices_impl(columns, limit, &ord::build_compare) +} + +/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray`] +/// representing the indices that would sort the columns. +/// Implementing custom `build_compare_fn` for unsupportd data types. +pub fn lexsort_to_indices_impl( + columns: &[SortColumn], + limit: Option, + build_compare_fn: &dyn Fn(&dyn Array, &dyn Array) -> Result, ) -> Result> { if columns.is_empty() { return Err(Error::InvalidArgumentError( @@ -136,7 +155,11 @@ pub fn lexsort_to_indices( if columns.len() == 1 { // fallback to non-lexical sort let column = &columns[0]; - return sort_to_indices(column.values, &column.options.unwrap_or_default(), limit); + if let Ok(indices) = + sort_to_indices(column.values, &column.options.unwrap_or_default(), limit) + { + return Ok(indices); + } } let row_count = columns[0].values.len(); @@ -150,7 +173,11 @@ pub fn lexsort_to_indices( let comparators = columns .iter() .map(|column| -> Result { - build_compare(column.values, column.options.unwrap_or_default()) + build_compare_impl( + column.values, + column.options.unwrap_or_default(), + build_compare_fn, + ) }) .collect::>>()?; diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 548575dec82..125e4da5261 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -18,7 +18,7 @@ mod primitive; mod utf8; pub(crate) use lex_sort::build_compare; -pub use lex_sort::{lexsort, lexsort_to_indices, SortColumn}; +pub use lex_sort::{lexsort, lexsort_to_indices, lexsort_to_indices_impl, SortColumn}; macro_rules! dyn_sort { ($ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{