diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs new file mode 100644 index 00000000000..efd27f06e76 --- /dev/null +++ b/src/compute/sort/common.rs @@ -0,0 +1,76 @@ +/// # Safety +/// `indices[i] < values.len()` for all i +/// `limit < values.len()` +#[inline] +unsafe fn k_element_sort_inner( + indices: &mut [i32], + get: G, + descending: bool, + limit: usize, + mut cmp: F, +) where + G: Fn(usize) -> T, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if descending { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }; + before.sort_unstable_by(compare); + } else { + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }; + let (before, _, _) = indices.select_nth_unstable_by(limit, compare); + let compare = |lhs: &i32, rhs: &i32| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }; + before.sort_unstable_by(compare); + } +} + +/// # Safety +/// Safe iff +/// * `indices[i] < values.len()` for all i +/// * `limit < values.len()` +#[inline] +pub(super) unsafe fn sort_unstable_by( + indices: &mut [i32], + get: G, + mut cmp: F, + descending: bool, + limit: usize, +) where + G: Fn(usize) -> T, + F: FnMut(&T, &T) -> std::cmp::Ordering, +{ + if limit != indices.len() { + return k_element_sort_inner(indices, get, descending, limit, cmp); + } + + if descending { + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs).reverse() + }) + } else { + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(*lhs as usize); + let rhs = get(*rhs as usize); + cmp(&lhs, &rhs) + }) + } +} diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 6b30d331c70..53a814ab7d0 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -14,6 +14,7 @@ use crate::buffer::MutableBuffer; use num::ToPrimitive; mod boolean; +mod common; mod lex_sort; mod primitive; diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index 0660310cf19..474ddd430b8 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -5,84 +5,9 @@ use crate::{ types::NativeType, }; +use super::super::common::sort_unstable_by; use super::super::SortOptions; -/// # Safety -/// `indices[i] < values.len()` for all i -#[inline] -unsafe fn k_element_sort_inner( - indices: &mut [i32], - values: &[T], - descending: bool, - limit: usize, - mut cmp: F, -) where - T: NativeType, - F: FnMut(&T, &T) -> std::cmp::Ordering, -{ - if descending { - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs).reverse() - }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs).reverse() - }; - before.sort_unstable_by(compare); - } else { - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs) - }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &i32, rhs: &i32| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs) - }; - before.sort_unstable_by(compare); - } -} - -/// # Safety -/// Safe iff -/// * `indices[i] < values.len()` for all i -/// * `limit < values.len()` -#[inline] -unsafe fn sort_unstable_by( - indices: &mut [i32], - values: &[T], - mut cmp: F, - descending: bool, - limit: usize, -) where - T: NativeType, - F: FnMut(&T, &T) -> std::cmp::Ordering, -{ - if limit != indices.len() { - return k_element_sort_inner(indices, values, descending, limit, cmp); - } - - if descending { - indices.sort_unstable_by(|lhs, rhs| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs).reverse() - }) - } else { - indices.sort_unstable_by(|lhs, rhs| { - let lhs = values.get_unchecked(*lhs as usize); - let rhs = values.get_unchecked(*rhs as usize); - cmp(lhs, rhs) - }) - } -} - /// Unstable sort of indices. pub fn indices_sorted_unstable_by( array: &PrimitiveArray, @@ -129,7 +54,15 @@ where // limit is by construction < indices.len() let limit = limit - validity.null_count(); let indices = &mut indices.as_mut_slice()[validity.null_count()..]; - unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) } + unsafe { + sort_unstable_by( + indices, + |x: usize| *values.as_slice().get_unchecked(x), + cmp, + options.descending, + limit, + ) + } } } else { let last_valid_index = array.len() - validity.null_count(); @@ -153,7 +86,15 @@ where // limit is by construction <= values.len() let limit = limit.min(last_valid_index); let indices = &mut indices.as_mut_slice()[..last_valid_index]; - unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) }; + unsafe { + sort_unstable_by( + indices, + |x: usize| *values.as_slice().get_unchecked(x), + cmp, + options.descending, + limit, + ) + }; } indices.truncate(limit); @@ -167,7 +108,15 @@ where // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` - unsafe { sort_unstable_by(&mut indices, values, cmp, descending, limit) }; + unsafe { + sort_unstable_by( + &mut indices, + |x: usize| *values.as_slice().get_unchecked(x), + cmp, + descending, + limit, + ) + }; indices.truncate(limit); indices.shrink_to_fit();