Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Generalized sort to accept indices other than i32.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Jul 24, 2021
1 parent a6e8b69 commit 03664af
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 161 deletions.
37 changes: 23 additions & 14 deletions src/array/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,22 @@ use crate::{
};

/// Trait describing any type that can be used to index a slot of an array.
pub trait Index: NativeType {
pub trait Index: NativeType + NaturalDataType {
fn to_usize(&self) -> usize;
fn from_usize(index: usize) -> Option<Self>;
}

/// Trait describing types that can be used as offsets as per Arrow specification.
/// This trait is only implemented for `i32` and `i64`, the two offset sizes that are part of the specification.
/// # Safety
/// Do not implement.
pub unsafe trait Offset:
Index + NaturalDataType + Num + Ord + std::ops::AddAssign + std::ops::Sub + num::CheckedAdd
Index + Num + Ord + std::ops::AddAssign + std::ops::Sub + num::CheckedAdd
{
fn is_large() -> bool;

fn to_isize(&self) -> isize;

fn from_usize(value: usize) -> Option<Self>;

fn from_isize(value: isize) -> Option<Self>;
}

Expand All @@ -34,11 +33,6 @@ unsafe impl Offset for i32 {
false
}

#[inline]
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

#[inline]
fn from_isize(value: isize) -> Option<Self> {
Self::try_from(value).ok()
Expand All @@ -56,11 +50,6 @@ unsafe impl Offset for i64 {
true
}

#[inline]
fn from_usize(value: usize) -> Option<Self> {
Some(value as i64)
}

#[inline]
fn from_isize(value: isize) -> Option<Self> {
Self::try_from(value).ok()
Expand All @@ -77,27 +66,47 @@ impl Index for i32 {
fn to_usize(&self) -> usize {
*self as usize
}

#[inline]
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}
}

impl Index for i64 {
    /// Returns this index as a `usize` using a plain `as` cast.
    ///
    /// The cast performs no range check: a value that `usize` cannot
    /// represent (for example a negative one) wraps instead of failing.
    #[inline]
    fn to_usize(&self) -> usize {
        let raw: i64 = *self;
        raw as usize
    }

    /// Builds an `i64` index from `value`, or `None` when `value` does not fit.
    #[inline]
    fn from_usize(value: usize) -> Option<Self> {
        let converted = Self::try_from(value);
        converted.ok()
    }
}

impl Index for u32 {
    /// Returns this index as a `usize` using a plain, unchecked `as` cast.
    #[inline]
    fn to_usize(&self) -> usize {
        let raw: u32 = *self;
        raw as usize
    }

    /// Builds a `u32` index from `value`, or `None` when `value` does not fit.
    #[inline]
    fn from_usize(value: usize) -> Option<Self> {
        let converted = Self::try_from(value);
        converted.ok()
    }
}

impl Index for u64 {
    /// Returns this index as a `usize` using a plain `as` cast.
    ///
    /// The cast performs no range check: on a target whose `usize` is
    /// narrower than 64 bits, larger values are truncated instead of failing.
    #[inline]
    fn to_usize(&self) -> usize {
        let raw: u64 = *self;
        raw as usize
    }

    /// Builds a `u64` index from `value`, or `None` when `value` does not fit.
    #[inline]
    fn from_usize(value: usize) -> Option<Self> {
        let converted = Self::try_from(value);
        converted.ok()
    }
}

#[inline]
Expand Down
19 changes: 9 additions & 10 deletions src/compute/sort/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
use crate::{
array::{Array, BooleanArray, Int32Array},
array::{Array, BooleanArray, Index, PrimitiveArray},
buffer::MutableBuffer,
datatypes::DataType,
};

use super::SortOptions;

/// Returns the indices that would sort a [`BooleanArray`].
pub fn sort_boolean(
pub fn sort_boolean<I: Index>(
values: &BooleanArray,
value_indices: Vec<i32>,
null_indices: Vec<i32>,
value_indices: Vec<I>,
null_indices: Vec<I>,
options: &SortOptions,
limit: Option<usize>,
) -> Int32Array {
) -> PrimitiveArray<I> {
let descending = options.descending;

// create tuples that are used for sorting
let mut valids = value_indices
.into_iter()
.map(|index| (index, values.value(index as usize)))
.collect::<Vec<(i32, bool)>>();
.map(|index| (index, values.value(index.to_usize())))
.collect::<Vec<(I, bool)>>();

let mut nulls = null_indices;

Expand All @@ -32,7 +31,7 @@ pub fn sort_boolean(
nulls.reverse();
}

let mut values = MutableBuffer::<i32>::with_capacity(values.len());
let mut values = MutableBuffer::<I>::with_capacity(values.len());

if options.nulls_first {
values.extend_from_slice(nulls.as_slice());
Expand All @@ -48,5 +47,5 @@ pub fn sort_boolean(
values.truncate(limit);
}

Int32Array::from_data(DataType::Int32, values.into(), None)
PrimitiveArray::<I>::from_data(I::DATA_TYPE, values.into(), None)
}
90 changes: 48 additions & 42 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType};
use crate::{
array::{Index, PrimitiveArray},
bitmap::Bitmap,
buffer::MutableBuffer,
};

use super::SortOptions;

Expand All @@ -7,8 +11,8 @@ use super::SortOptions;
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn k_element_sort_inner<T, G, F>(
indices: &mut [i32],
fn k_element_sort_inner<I: Index, T, G, F>(
indices: &mut [I],
get: G,
descending: bool,
limit: usize,
Expand All @@ -18,28 +22,28 @@ fn k_element_sort_inner<T, G, F>(
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if descending {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
};
before.sort_unstable_by(compare);
} else {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
};
before.sort_unstable_by(compare);
Expand All @@ -51,13 +55,14 @@ fn k_element_sort_inner<T, G, F>(
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn sort_unstable_by<T, G, F>(
indices: &mut [i32],
fn sort_unstable_by<I, T, G, F>(
indices: &mut [I],
get: G,
mut cmp: F,
descending: bool,
limit: usize,
) where
I: Index,
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
Expand All @@ -67,14 +72,14 @@ fn sort_unstable_by<T, G, F>(

if descending {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
})
} else {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
})
}
Expand All @@ -85,15 +90,16 @@ fn sort_unstable_by<T, G, F>(
/// * `get` is only called for `0 <= i < length`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
pub(super) fn indices_sorted_unstable_by<T, G, F>(
pub(super) fn indices_sorted_unstable_by<I, T, G, F>(
validity: &Option<Bitmap>,
get: G,
cmp: F,
length: usize,
options: &SortOptions,
limit: Option<usize>,
) -> PrimitiveArray<i32>
) -> PrimitiveArray<I>
where
I: Index,
G: Fn(usize) -> T,
F: Fn(&T, &T) -> std::cmp::Ordering,
{
Expand All @@ -104,20 +110,20 @@ where
let limit = limit.min(length);

let indices = if let Some(validity) = validity {
let mut indices = MutableBuffer::<i32>::from_len_zeroed(length);
let mut indices = MutableBuffer::<I>::from_len_zeroed(length);

if options.nulls_first {
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(0..length as i32)
.zip(0..length)
.for_each(|(is_valid, index)| {
if is_valid {
indices[validity.null_count() + valids] = index;
indices[validity.null_count() + valids] = I::from_usize(index).unwrap();
valids += 1;
} else {
indices[nulls] = index;
indices[nulls] = I::from_usize(index).unwrap();
nulls += 1;
}
});
Expand All @@ -136,18 +142,15 @@ where
let last_valid_index = length - validity.null_count();
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(0..length as i32)
.for_each(|(x, index)| {
if x {
indices[valids] = index;
valids += 1;
} else {
indices[last_valid_index + nulls] = index;
nulls += 1;
}
});
validity.iter().zip(0..length).for_each(|(x, index)| {
if x {
indices[valids] = I::from_usize(index).unwrap();
valids += 1;
} else {
indices[last_valid_index + nulls] = I::from_usize(index).unwrap();
nulls += 1;
}
});

// Soundness:
// all indices in `indices` are by construction `< array.len() == values.len()`
Expand All @@ -162,8 +165,11 @@ where

indices
} else {
let mut indices =
unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) };
let mut indices = unsafe {
MutableBuffer::from_trusted_len_iter_unchecked(
(0..length).map(|x| I::from_usize(x).unwrap()),
)
};

// Soundness:
// indices are by construction `< values.len()`
Expand All @@ -175,5 +181,5 @@ where

indices
};
PrimitiveArray::<i32>::from_data(DataType::Int32, indices.into(), None)
PrimitiveArray::<I>::from_data(I::DATA_TYPE, indices.into(), None)
}
Loading

0 comments on commit 03664af

Please sign in to comment.