Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Generalized function.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Jul 23, 2021
1 parent e7b33f1 commit 128be64
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 79 deletions.
76 changes: 76 additions & 0 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/// # Safety
/// `indices[i] < values.len()` for all i
/// `limit < values.len()`
#[inline]
unsafe fn k_element_sort_inner<T, G, F>(
indices: &mut [i32],
get: G,
descending: bool,
limit: usize,
mut cmp: F,
) where
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if descending {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs).reverse()
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs).reverse()
};
before.sort_unstable_by(compare);
} else {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs)
};
before.sort_unstable_by(compare);
}
}

/// # Safety
/// Safe iff
/// * `indices[i] < values.len()` for all i
/// * `limit < values.len()`
#[inline]
pub(super) unsafe fn sort_unstable_by<T, G, F>(
indices: &mut [i32],
get: G,
mut cmp: F,
descending: bool,
limit: usize,
) where
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if limit != indices.len() {
return k_element_sort_inner(indices, get, descending, limit, cmp);
}

if descending {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs).reverse()
})
} else {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(*lhs as usize);
let rhs = get(*rhs as usize);
cmp(&lhs, &rhs)
})
}
}
1 change: 1 addition & 0 deletions src/compute/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::buffer::MutableBuffer;
use num::ToPrimitive;

mod boolean;
mod common;
mod lex_sort;
mod primitive;

Expand Down
107 changes: 28 additions & 79 deletions src/compute/sort/primitive/indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,9 @@ use crate::{
types::NativeType,
};

use super::super::common::sort_unstable_by;
use super::super::SortOptions;

/// # Safety
/// `indices[i] < values.len()` for all i
#[inline]
unsafe fn k_element_sort_inner<T, F>(
indices: &mut [i32],
values: &[T],
descending: bool,
limit: usize,
mut cmp: F,
) where
T: NativeType,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if descending {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = values.get_unchecked(*lhs as usize);
let rhs = values.get_unchecked(*rhs as usize);
cmp(lhs, rhs).reverse()
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = values.get_unchecked(*lhs as usize);
let rhs = values.get_unchecked(*rhs as usize);
cmp(lhs, rhs).reverse()
};
before.sort_unstable_by(compare);
} else {
let compare = |lhs: &i32, rhs: &i32| {
let lhs = values.get_unchecked(*lhs as usize);
let rhs = values.get_unchecked(*rhs as usize);
cmp(lhs, rhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &i32, rhs: &i32| {
let lhs = values.get_unchecked(*lhs as usize);
let rhs = values.get_unchecked(*rhs as usize);
cmp(lhs, rhs)
};
before.sort_unstable_by(compare);
}
}

/// # Safety
/// Safe iff
/// * `indices[i] < values.len()` for all i
/// * `limit < values.len()`
#[inline]
unsafe fn sort_unstable_by<T, F>(
indices: &mut [i32],
values: &[T],
mut cmp: F,
descending: bool,
limit: usize,
) where
T: NativeType,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if limit != indices.len() {
return k_element_sort_inner(indices, values, descending, limit, cmp);
}

if descending {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = values.get_unchecked(*lhs as usize);
let rhs = values.get_unchecked(*rhs as usize);
cmp(lhs, rhs).reverse()
})
} else {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = values.get_unchecked(*lhs as usize);
let rhs = values.get_unchecked(*rhs as usize);
cmp(lhs, rhs)
})
}
}

/// Unstable sort of indices.
pub fn indices_sorted_unstable_by<T, F>(
array: &PrimitiveArray<T>,
Expand Down Expand Up @@ -129,7 +54,15 @@ where
// limit is by construction < indices.len()
let limit = limit - validity.null_count();
let indices = &mut indices.as_mut_slice()[validity.null_count()..];
unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) }
unsafe {
sort_unstable_by(
indices,
|x: usize| *values.as_slice().get_unchecked(x),
cmp,
options.descending,
limit,
)
}
}
} else {
let last_valid_index = array.len() - validity.null_count();
Expand All @@ -153,7 +86,15 @@ where
// limit is by construction <= values.len()
let limit = limit.min(last_valid_index);
let indices = &mut indices.as_mut_slice()[..last_valid_index];
unsafe { sort_unstable_by(indices, values, cmp, options.descending, limit) };
unsafe {
sort_unstable_by(
indices,
|x: usize| *values.as_slice().get_unchecked(x),
cmp,
options.descending,
limit,
)
};
}

indices.truncate(limit);
Expand All @@ -167,7 +108,15 @@ where
// Soundness:
// indices are by construction `< values.len()`
// limit is by construction `< values.len()`
unsafe { sort_unstable_by(&mut indices, values, cmp, descending, limit) };
unsafe {
sort_unstable_by(
&mut indices,
|x: usize| *values.as_slice().get_unchecked(x),
cmp,
descending,
limit,
)
};

indices.truncate(limit);
indices.shrink_to_fit();
Expand Down

0 comments on commit 128be64

Please sign in to comment.