From 03664af141734701c8afe8f24b44b770697345ba Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 24 Jul 2021 17:42:59 +0000 Subject: [PATCH] Generalized sort to accept indices other than i32. --- src/array/specification.rs | 37 ++++--- src/compute/sort/boolean.rs | 19 ++-- src/compute/sort/common.rs | 90 ++++++++-------- src/compute/sort/lex_sort.rs | 37 ++++--- src/compute/sort/mod.rs | 146 ++++++++++++++------------ src/compute/sort/primitive/indices.rs | 10 +- src/compute/sort/utf8.rs | 10 +- 7 files changed, 188 insertions(+), 161 deletions(-) diff --git a/src/array/specification.rs b/src/array/specification.rs index 5acfc9e1b5a..70576cb8621 100644 --- a/src/array/specification.rs +++ b/src/array/specification.rs @@ -8,8 +8,9 @@ use crate::{ }; /// Trait describing any type that can be used to index a slot of an array. -pub trait Index: NativeType { +pub trait Index: NativeType + NaturalDataType { fn to_usize(&self) -> usize; + fn from_usize(index: usize) -> Option; } /// Trait describing types that can be used as offsets as per Arrow specification. @@ -17,14 +18,12 @@ pub trait Index: NativeType { /// # Safety /// Do not implement. pub unsafe trait Offset: - Index + NaturalDataType + Num + Ord + std::ops::AddAssign + std::ops::Sub + num::CheckedAdd + Index + Num + Ord + std::ops::AddAssign + std::ops::Sub + num::CheckedAdd { fn is_large() -> bool; fn to_isize(&self) -> isize; - fn from_usize(value: usize) -> Option; - fn from_isize(value: isize) -> Option; } @@ -34,11 +33,6 @@ unsafe impl Offset for i32 { false } - #[inline] - fn from_usize(value: usize) -> Option { - Self::try_from(value).ok() - } - #[inline] fn from_isize(value: isize) -> Option { Self::try_from(value).ok() @@ -56,11 +50,6 @@ unsafe impl Offset for i64 { true } - #[inline] - fn from_usize(value: usize) -> Option { - Some(value as i64) - } - #[inline] fn from_isize(value: isize) -> Option { Self::try_from(value).ok() @@ -77,6 +66,11 @@ impl Index for i32 { fn to_usize(&self) -> usize { *self as usize } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } } impl Index for i64 { @@ -84,6 +78,11 @@ impl Index for i64 { fn to_usize(&self) -> usize { *self as usize } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } } impl Index for u32 { @@ -91,6 +90,11 @@ impl Index for u32 { fn to_usize(&self) -> usize { *self as usize } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } } impl Index for u64 { @@ -98,6 +102,11 @@ impl Index for u64 { fn to_usize(&self) -> usize { *self as usize } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } } #[inline] diff --git a/src/compute/sort/boolean.rs b/src/compute/sort/boolean.rs index 385c7c21355..2e7082a5901 100644 --- a/src/compute/sort/boolean.rs +++ b/src/compute/sort/boolean.rs @@ -1,26 +1,25 @@ use crate::{ - array::{Array, BooleanArray, Int32Array}, + array::{Array, BooleanArray, Index, PrimitiveArray}, buffer::MutableBuffer, - datatypes::DataType, }; use super::SortOptions; /// Returns the indices that would sort a [`BooleanArray`]. -pub fn sort_boolean( +pub fn sort_boolean( values: &BooleanArray, - value_indices: Vec, - null_indices: Vec, + value_indices: Vec, + null_indices: Vec, options: &SortOptions, limit: Option, -) -> Int32Array { +) -> PrimitiveArray { let descending = options.descending; // create tuples that are used for sorting let mut valids = value_indices .into_iter() - .map(|index| (index, values.value(index as usize))) - .collect::>(); + .map(|index| (index, values.value(index.to_usize()))) + .collect::>(); let mut nulls = null_indices; @@ -32,7 +31,7 @@ pub fn sort_boolean( nulls.reverse(); } - let mut values = MutableBuffer::::with_capacity(values.len()); + let mut values = MutableBuffer::::with_capacity(values.len()); if options.nulls_first { values.extend_from_slice(nulls.as_slice()); @@ -48,5 +47,5 @@ pub fn sort_boolean( values.truncate(limit); } - Int32Array::from_data(DataType::Int32, values.into(), None) + PrimitiveArray::::from_data(I::DATA_TYPE, values.into(), None) } diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index da62d2af387..c36763831bb 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -1,4 +1,8 @@ -use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType}; +use crate::{ + array::{Index, PrimitiveArray}, + bitmap::Bitmap, + buffer::MutableBuffer, +}; use super::SortOptions; @@ -7,8 +11,8 @@ use super::SortOptions; /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn k_element_sort_inner( - indices: &mut [i32], +fn k_element_sort_inner( + indices: &mut [I], get: G, descending: bool, limit: usize, @@ -18,28 +22,28 @@ fn k_element_sort_inner( F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let compare = |lhs: &i32, rhs: &i32| { - let lhs = get(*lhs as usize); - let rhs = get(*rhs as usize); + let compare = |lhs: &I, rhs: &I| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); cmp(&lhs, &rhs).reverse() }; let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &i32, rhs: &i32| { - let lhs = get(*lhs as usize); - let rhs = get(*rhs as usize); + let compare = |lhs: &I, rhs: &I| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); cmp(&lhs, &rhs).reverse() }; before.sort_unstable_by(compare); } else { - let compare = |lhs: &i32, rhs: &i32| { - let lhs = get(*lhs as usize); - let rhs = get(*rhs as usize); + let compare = |lhs: &I, rhs: &I| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); cmp(&lhs, &rhs) }; let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &i32, rhs: &i32| { - let lhs = get(*lhs as usize); - let rhs = get(*rhs as usize); + let compare = |lhs: &I, rhs: &I| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); cmp(&lhs, &rhs) }; before.sort_unstable_by(compare); @@ -51,13 +55,14 @@ fn k_element_sort_inner( /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn sort_unstable_by( - indices: &mut [i32], +fn sort_unstable_by( + indices: &mut [I], get: G, mut cmp: F, descending: bool, limit: usize, ) where + I: Index, G: Fn(usize) -> T, F: FnMut(&T, &T) -> std::cmp::Ordering, { @@ -67,14 +72,14 @@ fn sort_unstable_by( if descending { indices.sort_unstable_by(|lhs, rhs| { - let lhs = get(*lhs as usize); - let rhs = get(*rhs as usize); + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); cmp(&lhs, &rhs).reverse() }) } else { indices.sort_unstable_by(|lhs, rhs| { - let lhs = get(*lhs as usize); - let rhs = get(*rhs as usize); + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); cmp(&lhs, &rhs) }) } @@ -85,15 +90,16 @@ fn sort_unstable_by( /// * `get` is only called for `0 <= i < length` /// * `cmp` is only called from the co-domain of `get`. #[inline] -pub(super) fn indices_sorted_unstable_by( +pub(super) fn indices_sorted_unstable_by( validity: &Option, get: G, cmp: F, length: usize, options: &SortOptions, limit: Option, -) -> PrimitiveArray +) -> PrimitiveArray where + I: Index, G: Fn(usize) -> T, F: Fn(&T, &T) -> std::cmp::Ordering, { @@ -104,20 +110,20 @@ where let limit = limit.min(length); let indices = if let Some(validity) = validity { - let mut indices = MutableBuffer::::from_len_zeroed(length); + let mut indices = MutableBuffer::::from_len_zeroed(length); if options.nulls_first { let mut nulls = 0; let mut valids = 0; validity .iter() - .zip(0..length as i32) + .zip(0..length) .for_each(|(is_valid, index)| { if is_valid { - indices[validity.null_count() + valids] = index; + indices[validity.null_count() + valids] = I::from_usize(index).unwrap(); valids += 1; } else { - indices[nulls] = index; + indices[nulls] = I::from_usize(index).unwrap(); nulls += 1; } }); @@ -136,18 +142,15 @@ where let last_valid_index = length - validity.null_count(); let mut nulls = 0; let mut valids = 0; - validity - .iter() - .zip(0..length as i32) - .for_each(|(x, index)| { - if x { - indices[valids] = index; - valids += 1; - } else { - indices[last_valid_index + nulls] = index; - nulls += 1; - } - }); + validity.iter().zip(0..length).for_each(|(x, index)| { + if x { + indices[valids] = I::from_usize(index).unwrap(); + valids += 1; + } else { + indices[last_valid_index + nulls] = I::from_usize(index).unwrap(); + nulls += 1; + } + }); // Soundness: // all indices in `indices` are by construction `< array.len() == values.len()` @@ -162,8 +165,11 @@ where indices } else { - let mut indices = - unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) }; + let mut indices = unsafe { + MutableBuffer::from_trusted_len_iter_unchecked( + (0..length).map(|x| I::from_usize(x).unwrap()), + ) + }; // Soundness: // indices are by construction `< values.len()` @@ -175,5 +181,5 @@ where indices }; - PrimitiveArray::::from_data(DataType::Int32, indices.into(), None) + PrimitiveArray::::from_data(I::DATA_TYPE, indices.into(), None) } diff --git a/src/compute/sort/lex_sort.rs b/src/compute/sort/lex_sort.rs index 7ca8b4c6c72..6cde2cc4dd6 100644 --- a/src/compute/sort/lex_sort.rs +++ b/src/compute/sort/lex_sort.rs @@ -3,9 +3,8 @@ use std::cmp::Ordering; use crate::compute::take; use crate::error::{ArrowError, Result}; use crate::{ - array::{ord, Array, PrimitiveArray}, + array::{ord, Array, Index, PrimitiveArray}, buffer::MutableBuffer, - datatypes::DataType, }; use super::{sort_to_indices, SortOptions}; @@ -55,8 +54,11 @@ pub struct SortColumn<'a> { /// assert_eq!(sorted.value(1), -64); /// assert!(sorted.is_null(0)); /// ``` -pub fn lexsort(columns: &[SortColumn], limit: Option) -> Result>> { - let indices = lexsort_to_indices(columns, limit)?; +pub fn lexsort( + columns: &[SortColumn], + limit: Option, +) -> Result>> { + let indices = lexsort_to_indices::(columns, limit)?; columns .iter() .map(|c| take::take(c.values, &indices)) @@ -118,12 +120,12 @@ pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Resu }) } -/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray`] +/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray`] /// representing the indices that would sort the columns. -pub fn lexsort_to_indices( +pub fn lexsort_to_indices( columns: &[SortColumn], limit: Option, -) -> Result> { +) -> Result> { if columns.is_empty() { return Err(ArrowError::InvalidArgumentError( "Sort requires at least one column".to_string(), @@ -150,9 +152,9 @@ pub fn lexsort_to_indices( }) .collect::>>()?; - let lex_comparator = |a_idx: &i32, b_idx: &i32| -> Ordering { - let a_idx = *a_idx as usize; - let b_idx = *b_idx as usize; + let lex_comparator = |a_idx: &I, b_idx: &I| -> Ordering { + let a_idx = a_idx.to_usize(); + let b_idx = b_idx.to_usize(); for comparator in comparators.iter() { match comparator(a_idx, b_idx) { Ordering::Equal => continue, @@ -164,8 +166,11 @@ pub fn lexsort_to_indices( }; // Safety: `0..row_count` is TrustedLen - let mut values = - unsafe { MutableBuffer::::from_trusted_len_iter_unchecked(0..row_count as i32) }; + let mut values = unsafe { + MutableBuffer::from_trusted_len_iter_unchecked( + (0..row_count).map(|x| I::from_usize(x).unwrap()), + ) + }; if let Some(limit) = limit { let limit = limit.min(row_count); @@ -176,8 +181,8 @@ pub fn lexsort_to_indices( values.sort_unstable_by(lex_comparator); } - Ok(PrimitiveArray::::from_data( - DataType::Int32, + Ok(PrimitiveArray::::from_data( + I::DATA_TYPE, values.into(), None, )) @@ -190,10 +195,10 @@ mod tests { use super::*; fn test_lex_sort_arrays(input: Vec, expected: Vec>) { - let sorted = lexsort(&input, None).unwrap(); + let sorted = lexsort::(&input, None).unwrap(); assert_eq!(sorted, expected); - let sorted = lexsort(&input, Some(2)).unwrap(); + let sorted = lexsort::(&input, Some(2)).unwrap(); let expected = expected .into_iter() .map(|x| x.slice(0, 2)) diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 998ddc27392..32cf63166c6 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -10,7 +10,6 @@ use crate::{ }; use crate::buffer::MutableBuffer; -use num::ToPrimitive; mod boolean; mod common; @@ -68,37 +67,41 @@ pub fn sort( dyn_sort!(days_ms, values, ord::total_cmp, options, limit) } _ => { - let indices = sort_to_indices(values, options, limit)?; + let indices = sort_to_indices::(values, options, limit)?; take::take(values, &indices) } } } // partition indices into valid and null indices -fn partition_validity(array: &dyn Array) -> (Vec, Vec) { - let indices = 0..(array.len().to_i32().unwrap()); - indices.partition(|index| array.is_valid(*index as usize)) +fn partition_validity(array: &dyn Array) -> (Vec, Vec) { + let length = array.len(); + let indices = (0..length).map(|x| I::from_usize(x).unwrap()); + if let Some(validity) = array.validity() { + indices.partition(|index| validity.get_bit(index.to_usize())) + } else { + (indices.collect(), vec![]) + } } macro_rules! dyn_sort_indices { - ($ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{ + ($index:ty, $ty:ty, $array:expr, $cmp:expr, $options:expr, $limit:expr) => {{ let array = $array .as_any() .downcast_ref::>() .unwrap(); - Ok(primitive::indices_sorted_unstable_by::<$ty, _>( + Ok(primitive::indices_sorted_unstable_by::<$index, $ty, _>( &array, $cmp, $options, $limit, )) }}; } -/// Sort elements from `values` into [`Int32Array`] of indices. -/// For floating point arrays any NaN values are considered to be greater than any other non-null value -pub fn sort_to_indices( +/// Sort elements from `values` into a non-nullable [`PrimitiveArray`] of indices that sort `values`. +pub fn sort_to_indices( values: &dyn Array, options: &SortOptions, limit: Option, -) -> Result { +) -> Result> { match values.data_type() { DataType::Boolean => { let (v, n) = partition_validity(values); @@ -110,34 +113,36 @@ pub fn sort_to_indices( limit, )) } - DataType::Int8 => dyn_sort_indices!(i8, values, ord::total_cmp, options, limit), - DataType::Int16 => dyn_sort_indices!(i16, values, ord::total_cmp, options, limit), + DataType::Int8 => dyn_sort_indices!(I, i8, values, ord::total_cmp, options, limit), + DataType::Int16 => dyn_sort_indices!(I, i16, values, ord::total_cmp, options, limit), DataType::Int32 | DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_sort_indices!(i32, values, ord::total_cmp, options, limit) + dyn_sort_indices!(I, i32, values, ord::total_cmp, options, limit) } DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, None) - | DataType::Duration(_) => dyn_sort_indices!(i64, values, ord::total_cmp, options, limit), - DataType::UInt8 => dyn_sort_indices!(u8, values, ord::total_cmp, options, limit), - DataType::UInt16 => dyn_sort_indices!(u16, values, ord::total_cmp, options, limit), - DataType::UInt32 => dyn_sort_indices!(u32, values, ord::total_cmp, options, limit), - DataType::UInt64 => dyn_sort_indices!(u64, values, ord::total_cmp, options, limit), - DataType::Float32 => dyn_sort_indices!(f32, values, ord::total_cmp_f32, options, limit), - DataType::Float64 => dyn_sort_indices!(f64, values, ord::total_cmp_f64, options, limit), + | DataType::Duration(_) => { + dyn_sort_indices!(I, i64, values, ord::total_cmp, options, limit) + } + DataType::UInt8 => dyn_sort_indices!(I, u8, values, ord::total_cmp, options, limit), + DataType::UInt16 => dyn_sort_indices!(I, u16, values, ord::total_cmp, options, limit), + DataType::UInt32 => dyn_sort_indices!(I, u32, values, ord::total_cmp, options, limit), + DataType::UInt64 => dyn_sort_indices!(I, u64, values, ord::total_cmp, options, limit), + DataType::Float32 => dyn_sort_indices!(I, f32, values, ord::total_cmp_f32, options, limit), + DataType::Float64 => dyn_sort_indices!(I, f64, values, ord::total_cmp_f64, options, limit), DataType::Interval(IntervalUnit::DayTime) => { - dyn_sort_indices!(days_ms, values, ord::total_cmp, options, limit) + dyn_sort_indices!(I, days_ms, values, ord::total_cmp, options, limit) } - DataType::Utf8 => Ok(utf8::indices_sorted_unstable_by::( + DataType::Utf8 => Ok(utf8::indices_sorted_unstable_by::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::LargeUtf8 => Ok(utf8::indices_sorted_unstable_by::( + DataType::LargeUtf8 => Ok(utf8::indices_sorted_unstable_by::( values.as_any().downcast_ref().unwrap(), options, limit, @@ -145,14 +150,14 @@ pub fn sort_to_indices( DataType::List(field) => { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -162,14 +167,14 @@ pub fn sort_to_indices( DataType::LargeList(field) => { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -179,14 +184,14 @@ pub fn sort_to_indices( DataType::FixedSizeList(field, _) => { let (v, n) = partition_validity(values); match field.data_type() { - DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), - DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), - DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::Int64 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt8 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt16 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt32 => Ok(sort_list::(values, v, n, options, limit)), + DataType::UInt64 => Ok(sort_list::(values, v, n, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for list type {:?}", t @@ -194,8 +199,8 @@ pub fn sort_to_indices( } } DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 => sort_dict::(values, key_type.as_ref(), options, limit), - DataType::LargeUtf8 => sort_dict::(values, key_type.as_ref(), options, limit), + DataType::Utf8 => sort_dict::(values, key_type.as_ref(), options, limit), + DataType::LargeUtf8 => sort_dict::(values, key_type.as_ref(), options, limit), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for dictionary type with keys {:?}", t @@ -208,49 +213,49 @@ pub fn sort_to_indices( } } -fn sort_dict( +fn sort_dict( values: &dyn Array, key_type: &DataType, options: &SortOptions, limit: Option, -) -> Result { +) -> Result> { match key_type { - DataType::Int8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::Int8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::Int16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::Int16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::Int32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::Int32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::Int64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::Int64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::UInt8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::UInt8 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::UInt16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::UInt16 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::UInt32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::UInt32 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, )), - DataType::UInt64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( + DataType::UInt64 => Ok(utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, limit, @@ -346,18 +351,19 @@ impl Default for SortOptions { } } -fn sort_list( +fn sort_list( values: &dyn Array, - value_indices: Vec, - null_indices: Vec, + value_indices: Vec, + null_indices: Vec, options: &SortOptions, limit: Option, -) -> Int32Array +) -> PrimitiveArray where + I: Index, O: Offset, T: NativeType + std::cmp::PartialOrd, { - let mut valids: Vec<(i32, Box)> = values + let mut valids: Vec<(I, Box)> = values .as_any() .downcast_ref::() .map_or_else( @@ -366,14 +372,14 @@ where value_indices .iter() .copied() - .map(|index| (index, values.value(index as usize))) + .map(|index| (index, values.value(index.to_usize()))) .collect() }, |values| { value_indices .iter() .copied() - .map(|index| (index, values.value(index as usize))) + .map(|index| (index, values.value(index.to_usize()))) .collect() }, ); @@ -387,18 +393,18 @@ where let values = valids.iter().map(|tuple| tuple.0); let mut values = if options.nulls_first { - let mut buffer = MutableBuffer::::from_trusted_len_iter(null_indices.into_iter()); + let mut buffer = MutableBuffer::::from_trusted_len_iter(null_indices.into_iter()); buffer.extend(values); buffer } else { - let mut buffer = MutableBuffer::::from_trusted_len_iter(values); + let mut buffer = MutableBuffer::::from_trusted_len_iter(values); buffer.extend(null_indices); buffer }; values.truncate(limit.unwrap_or_else(|| values.len())); - PrimitiveArray::::from_data(DataType::Int32, values.into(), None) + PrimitiveArray::::from_data(I::DATA_TYPE, values.into(), None) } /// Compare two `Array`s based on the ordering defined in [ord](crate::array::ord). diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index b5673b19dfe..e7c79947b4c 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -1,5 +1,5 @@ use crate::{ - array::{Array, PrimitiveArray}, + array::{Array, Index, PrimitiveArray}, types::NativeType, }; @@ -7,13 +7,14 @@ use super::super::common; use super::super::SortOptions; /// Unstable sort of indices. -pub fn indices_sorted_unstable_by( +pub fn indices_sorted_unstable_by( array: &PrimitiveArray, cmp: F, options: &SortOptions, limit: Option, -) -> PrimitiveArray +) -> PrimitiveArray where + I: Index, T: NativeType, F: Fn(&T, &T) -> std::cmp::Ordering, { @@ -47,7 +48,8 @@ mod tests { { let input = PrimitiveArray::::from(data).to(data_type); let expected = Int32Array::from_slice(&expected_data); - let output = indices_sorted_unstable_by(&input, ord::total_cmp, &options, limit); + let output = + indices_sorted_unstable_by::(&input, ord::total_cmp, &options, limit); assert_eq!(output, expected) } diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs index 8e83517ec53..20f2d65fea7 100644 --- a/src/compute/sort/utf8.rs +++ b/src/compute/sort/utf8.rs @@ -1,24 +1,24 @@ -use crate::array::{Array, Int32Array, Offset, Utf8Array}; +use crate::array::{Array, Index, Offset, PrimitiveArray, Utf8Array}; use crate::array::{DictionaryArray, DictionaryKey}; use super::common; use super::SortOptions; -pub(super) fn indices_sorted_unstable_by( +pub(super) fn indices_sorted_unstable_by( array: &Utf8Array, options: &SortOptions, limit: Option, -) -> Int32Array { +) -> PrimitiveArray { let get = |idx| unsafe { array.value_unchecked(idx as usize) }; let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) } -pub(super) fn indices_sorted_unstable_by_dictionary( +pub(super) fn indices_sorted_unstable_by_dictionary( array: &DictionaryArray, options: &SortOptions, limit: Option, -) -> Int32Array { +) -> PrimitiveArray { let keys = array.keys(); let dict = array