Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Generalized sort to accept indices other than i32. #220

Merged
merged 1 commit into from
Jul 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/array/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,22 @@ use crate::{
};

/// Trait describing any type that can be used to index a slot of an array.
pub trait Index: NativeType {
pub trait Index: NativeType + NaturalDataType {
fn to_usize(&self) -> usize;
fn from_usize(index: usize) -> Option<Self>;
}

/// Trait describing types that can be used as offsets as per Arrow specification.
/// This trait is only implemented for `i32` and `i64`, the two offset sizes that are part of the specification.
/// # Safety
/// Do not implement.
pub unsafe trait Offset:
Index + NaturalDataType + Num + Ord + std::ops::AddAssign + std::ops::Sub + num::CheckedAdd
Index + Num + Ord + std::ops::AddAssign + std::ops::Sub + num::CheckedAdd
{
fn is_large() -> bool;

fn to_isize(&self) -> isize;

fn from_usize(value: usize) -> Option<Self>;

fn from_isize(value: isize) -> Option<Self>;
}

Expand All @@ -34,11 +33,6 @@ unsafe impl Offset for i32 {
false
}

#[inline]
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

#[inline]
fn from_isize(value: isize) -> Option<Self> {
Self::try_from(value).ok()
Expand All @@ -56,11 +50,6 @@ unsafe impl Offset for i64 {
true
}

#[inline]
fn from_usize(value: usize) -> Option<Self> {
Some(value as i64)
}

#[inline]
fn from_isize(value: isize) -> Option<Self> {
Self::try_from(value).ok()
Expand All @@ -77,27 +66,47 @@ impl Index for i32 {
fn to_usize(&self) -> usize {
*self as usize
}

#[inline]
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}
}

impl Index for i64 {
    /// Converts this index to `usize` with a plain `as` cast.
    ///
    /// The cast is unchecked: a negative value is not rejected, it is
    /// converted with the wrapping/truncating semantics of `as`.
    #[inline]
    fn to_usize(&self) -> usize {
        let index = *self;
        index as usize
    }

    /// Builds an `i64` index from `value`, or returns `None` when `value`
    /// cannot be represented as an `i64`.
    #[inline]
    fn from_usize(value: usize) -> Option<Self> {
        match Self::try_from(value) {
            Ok(index) => Some(index),
            Err(_) => None,
        }
    }
}

impl Index for u32 {
    /// Converts this index to `usize` with a plain `as` cast.
    ///
    /// The cast is lossless whenever `usize` is at least 32 bits wide.
    #[inline]
    fn to_usize(&self) -> usize {
        let index = *self;
        index as usize
    }

    /// Builds a `u32` index from `value`, or returns `None` when `value`
    /// does not fit in a `u32`.
    #[inline]
    fn from_usize(value: usize) -> Option<Self> {
        if let Ok(index) = Self::try_from(value) {
            Some(index)
        } else {
            None
        }
    }
}

impl Index for u64 {
    /// Converts this index to `usize` with a plain `as` cast.
    ///
    /// The cast is unchecked: on a target whose `usize` is narrower than
    /// 64 bits the value is truncated rather than rejected.
    #[inline]
    fn to_usize(&self) -> usize {
        let index = *self;
        index as usize
    }

    /// Builds a `u64` index from `value`, or returns `None` if the checked
    /// conversion fails.
    #[inline]
    fn from_usize(value: usize) -> Option<Self> {
        let converted = Self::try_from(value);
        converted.ok()
    }
}

#[inline]
Expand Down
19 changes: 9 additions & 10 deletions src/compute/sort/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
use crate::{
array::{Array, BooleanArray, Int32Array},
array::{Array, BooleanArray, Index, PrimitiveArray},
buffer::MutableBuffer,
datatypes::DataType,
};

use super::SortOptions;

/// Returns the indices that would sort a [`BooleanArray`].
pub fn sort_boolean(
pub fn sort_boolean<I: Index>(
values: &BooleanArray,
value_indices: Vec<i32>,
null_indices: Vec<i32>,
value_indices: Vec<I>,
null_indices: Vec<I>,
options: &SortOptions,
limit: Option<usize>,
) -> Int32Array {
) -> PrimitiveArray<I> {
let descending = options.descending;

// create tuples that are used for sorting
let mut valids = value_indices
.into_iter()
.map(|index| (index, values.value(index as usize)))
.collect::<Vec<(i32, bool)>>();
.map(|index| (index, values.value(index.to_usize())))
.collect::<Vec<(I, bool)>>();

let mut nulls = null_indices;

Expand All @@ -32,7 +31,7 @@ pub fn sort_boolean(
nulls.reverse();
}

let mut values = MutableBuffer::<i32>::with_capacity(values.len());
let mut values = MutableBuffer::<I>::with_capacity(values.len());

if options.nulls_first {
values.extend_from_slice(nulls.as_slice());
Expand All @@ -48,5 +47,5 @@ pub fn sort_boolean(
values.truncate(limit);
}

Int32Array::from_data(DataType::Int32, values.into(), None)
PrimitiveArray::<I>::from_data(I::DATA_TYPE, values.into(), None)
}
90 changes: 48 additions & 42 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType};
use crate::{
array::{Index, PrimitiveArray},
bitmap::Bitmap,
buffer::MutableBuffer,
};

use super::SortOptions;

Expand All @@ -7,8 +11,8 @@ use super::SortOptions;
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn k_element_sort_inner<T, G, F>(
indices: &mut [i32],
fn k_element_sort_inner<I: Index, T, G, F>(
indices: &mut [I],
get: G,
descending: bool,
limit: usize,
Expand All @@ -18,28 +22,28 @@ fn k_element_sort_inner<T, G, F>(
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if descending {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
};
before.sort_unstable_by(compare);
} else {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
};
before.sort_unstable_by(compare);
Expand All @@ -51,13 +55,14 @@ fn k_element_sort_inner<T, G, F>(
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn sort_unstable_by<T, G, F>(
indices: &mut [i32],
fn sort_unstable_by<I, T, G, F>(
indices: &mut [I],
get: G,
mut cmp: F,
descending: bool,
limit: usize,
) where
I: Index,
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
Expand All @@ -67,14 +72,14 @@ fn sort_unstable_by<T, G, F>(

if descending {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
})
} else {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
})
}
Expand All @@ -85,15 +90,16 @@ fn sort_unstable_by<T, G, F>(
/// * `get` is only called for `0 <= i < length`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
pub(super) fn indices_sorted_unstable_by<T, G, F>(
pub(super) fn indices_sorted_unstable_by<I, T, G, F>(
validity: &Option<Bitmap>,
get: G,
cmp: F,
length: usize,
options: &SortOptions,
limit: Option<usize>,
) -> PrimitiveArray<i32>
) -> PrimitiveArray<I>
where
I: Index,
G: Fn(usize) -> T,
F: Fn(&T, &T) -> std::cmp::Ordering,
{
Expand All @@ -104,20 +110,20 @@ where
let limit = limit.min(length);

let indices = if let Some(validity) = validity {
let mut indices = MutableBuffer::<i32>::from_len_zeroed(length);
let mut indices = MutableBuffer::<I>::from_len_zeroed(length);

if options.nulls_first {
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(0..length as i32)
.zip(0..length)
.for_each(|(is_valid, index)| {
if is_valid {
indices[validity.null_count() + valids] = index;
indices[validity.null_count() + valids] = I::from_usize(index).unwrap();
valids += 1;
} else {
indices[nulls] = index;
indices[nulls] = I::from_usize(index).unwrap();
nulls += 1;
}
});
Expand All @@ -136,18 +142,15 @@ where
let last_valid_index = length - validity.null_count();
let mut nulls = 0;
let mut valids = 0;
validity
.iter()
.zip(0..length as i32)
.for_each(|(x, index)| {
if x {
indices[valids] = index;
valids += 1;
} else {
indices[last_valid_index + nulls] = index;
nulls += 1;
}
});
validity.iter().zip(0..length).for_each(|(x, index)| {
if x {
indices[valids] = I::from_usize(index).unwrap();
valids += 1;
} else {
indices[last_valid_index + nulls] = I::from_usize(index).unwrap();
nulls += 1;
}
});

// Soundness:
// all indices in `indices` are by construction `< array.len() == values.len()`
Expand All @@ -162,8 +165,11 @@ where

indices
} else {
let mut indices =
unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) };
let mut indices = unsafe {
MutableBuffer::from_trusted_len_iter_unchecked(
(0..length).map(|x| I::from_usize(x).unwrap()),
)
};

// Soundness:
// indices are by construction `< values.len()`
Expand All @@ -175,5 +181,5 @@ where

indices
};
PrimitiveArray::<i32>::from_data(DataType::Int32, indices.into(), None)
PrimitiveArray::<I>::from_data(I::DATA_TYPE, indices.into(), None)
}
Loading