Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Added support for limited sort #218

Merged
merged 5 commits into from
Jul 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions benches/sort_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ fn bench_lexsort(arr_a: &dyn Array, array_b: &dyn Array) {
},
];

criterion::black_box(lexsort(&columns).unwrap());
criterion::black_box(lexsort(&columns, None).unwrap());
}

fn bench_sort(arr_a: &dyn Array) {
sort(criterion::black_box(arr_a), &SortOptions::default()).unwrap();
sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap();
}

fn add_benchmark(c: &mut Criterion) {
Expand All @@ -66,6 +66,11 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function(&format!("lexsort null 2^{} f32", log2_size), |b| {
b.iter(|| bench_lexsort(&arr_a, &arr_b))
});

let arr_a = create_string_array::<i32>(size, 0.1);
c.bench_function(&format!("sort utf8 null 2^{}", log2_size), |b| {
b.iter(|| bench_sort(&arr_a))
});
});
}

Expand Down
9 changes: 9 additions & 0 deletions src/buffer/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,15 @@ impl<T: NativeType> MutableBuffer<T> {
self.len = 0
}

/// Shortens the buffer to at most `len` items by lowering its recorded length.
///
/// When `len` is greater than or equal to the buffer's current length, the
/// buffer is left untouched. Only the length field is updated here.
#[inline]
pub fn truncate(&mut self, len: usize) {
    // `min` keeps the current length whenever `len` would not shrink it.
    self.len = self.len.min(len);
}

/// Returns the data stored in this buffer as a slice.
#[inline]
pub fn as_slice(&self) -> &[T] {
Expand Down
4 changes: 2 additions & 2 deletions src/compute/merge_sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,8 @@ mod tests {
let options = SortOptions::default();

// sort individually, potentially in parallel.
let a0 = sort(a0, &options)?;
let a1 = sort(a1, &options)?;
let a0 = sort(a0, &options, None)?;
let a1 = sort(a1, &options, None)?;

// merge then. If multiple arrays, this can be applied in parallel.
let result = merge_sort(a0.as_ref(), a1.as_ref(), &options)?;
Expand Down
52 changes: 52 additions & 0 deletions src/compute/sort/boolean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use crate::{
array::{Array, BooleanArray, Int32Array},
buffer::MutableBuffer,
datatypes::DataType,
};

use super::SortOptions;

/// Returns the indices that would sort a [`BooleanArray`].
///
/// * `values` - the array whose slots are being ordered.
/// * `value_indices` - indices whose value is read through `values.value` and
///   used as the sort key (presumably the non-null slots; supplied by the caller).
/// * `null_indices` - indices emitted as one contiguous group, before or after
///   the sorted values depending on `options.nulls_first`.
/// * `options` - sort direction (`descending`) and null placement (`nulls_first`).
/// * `limit` - when set, the fully ordered result is truncated to at most this
///   many indices.
///
/// The values are ordered with a stable sort, so indices with equal values keep
/// the relative order they had in `value_indices`.
pub fn sort_boolean(
    values: &BooleanArray,
    value_indices: Vec<i32>,
    null_indices: Vec<i32>,
    options: &SortOptions,
    limit: Option<usize>,
) -> Int32Array {
    // Pair every index with the boolean it points at; the pair is the sort record.
    let mut keyed: Vec<(i32, bool)> = value_indices
        .into_iter()
        .map(|index| (index, values.value(index as usize)))
        .collect();

    let mut null_group = null_indices;

    if options.descending {
        // Swapping the operands is the same as reversing the ascending ordering.
        keyed.sort_by(|lhs, rhs| rhs.1.cmp(&lhs.1));
        // The null group is mirrored as well when the direction is descending.
        null_group.reverse();
    } else {
        keyed.sort_by(|lhs, rhs| lhs.1.cmp(&rhs.1));
    }

    let mut sorted = MutableBuffer::<i32>::with_capacity(values.len());

    // Nulls form a single block placed on the side requested by the options.
    if options.nulls_first {
        sorted.extend_from_slice(null_group.as_slice());
    }
    for &(index, _) in &keyed {
        sorted.push(index);
    }
    if !options.nulls_first {
        sorted.extend_from_slice(null_group.as_slice());
    }

    // Not efficient: everything is ordered first and only then cut down to
    // `limit`. Booleans admit much cheaper strategies than a comparison sort.
    if let Some(limit) = limit {
        sorted.truncate(limit);
    }

    Int32Array::from_data(DataType::Int32, sorted.into(), None)
}
179 changes: 179 additions & 0 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, datatypes::DataType};

use super::SortOptions;

/// Partially sorts `indices` so that its first `limit` entries are the `limit`
/// smallest ones (largest ones when `descending`), in sorted order.
///
/// Entries are compared indirectly: every entry `i` of `indices` is mapped
/// through `get(i as usize)` and the resulting values are ordered with `cmp`
/// (reversed when `descending` is set). The order of the entries past `limit`
/// is unspecified.
///
/// When `limit` is greater than or equal to `indices.len()` the whole slice is
/// sorted; this includes the empty slice, for which nothing happens.
///
/// # Guarantees
/// * `get` is only called with values stored in `indices`.
/// * `cmp` is only called with values returned by `get`.
#[inline]
fn k_element_sort_inner<T, G, F>(
    indices: &mut [i32],
    get: G,
    descending: bool,
    limit: usize,
    mut cmp: F,
) where
    G: Fn(usize) -> T,
    F: FnMut(&T, &T) -> std::cmp::Ordering,
{
    /// Moves the `limit` smallest entries (per `compare`) to the front and sorts them.
    fn select_then_sort<C>(indices: &mut [i32], limit: usize, mut compare: C)
    where
        C: FnMut(&i32, &i32) -> std::cmp::Ordering,
    {
        // `select_nth_unstable_by` panics when the pivot position is not strictly
        // inside the slice, so it is only used when there is something to cut off.
        if limit < indices.len() {
            let _ = indices.select_nth_unstable_by(limit, &mut compare);
        }
        let sorted_len = limit.min(indices.len());
        indices[..sorted_len].sort_unstable_by(compare);
    }

    // The direction is decided once, outside of the comparator, so that the
    // comparator itself stays branch-free.
    if descending {
        select_then_sort(indices, limit, |lhs, rhs| {
            let lhs = get(*lhs as usize);
            let rhs = get(*rhs as usize);
            cmp(&lhs, &rhs).reverse()
        });
    } else {
        select_then_sort(indices, limit, |lhs, rhs| {
            let lhs = get(*lhs as usize);
            let rhs = get(*rhs as usize);
            cmp(&lhs, &rhs)
        });
    }
}

/// Sorts `indices` indirectly, fully or partially depending on `limit`.
///
/// Every entry `i` of `indices` is mapped through `get(i as usize)` and the
/// resulting values are ordered with `cmp` (reversed when `descending`).
///
/// * When `limit` is strictly smaller than `indices.len()`, the work is handed
///   to `k_element_sort_inner`, so only the first `limit` entries are
///   meaningful afterwards.
/// * Otherwise (`limit >= indices.len()`) the whole slice is sorted.
///
/// # Guarantees
/// * `get` is only called with values stored in `indices`.
/// * `cmp` is only called with values returned by `get`.
#[inline]
fn sort_unstable_by<T, G, F>(
    indices: &mut [i32],
    get: G,
    mut cmp: F,
    descending: bool,
    limit: usize,
) where
    G: Fn(usize) -> T,
    F: FnMut(&T, &T) -> std::cmp::Ordering,
{
    // Only a limit strictly inside the slice calls for a partial sort. A limit
    // at or past the end means "sort everything"; the selection step used by the
    // partial path cannot take such a position.
    if limit < indices.len() {
        return k_element_sort_inner(indices, get, descending, limit, cmp);
    }

    // The direction is chosen once so the comparator does not branch per call.
    if descending {
        indices.sort_unstable_by(|lhs, rhs| {
            let lhs = get(*lhs as usize);
            let rhs = get(*rhs as usize);
            cmp(&lhs, &rhs).reverse()
        })
    } else {
        indices.sort_unstable_by(|lhs, rhs| {
            let lhs = get(*lhs as usize);
            let rhs = get(*rhs as usize);
            cmp(&lhs, &rhs)
        })
    }
}

/// Builds the array of indices that orders `length` slots, optionally keeping
/// only the first `limit` of them.
///
/// * `validity` - when present, slots whose bit is unset are treated as nulls:
///   they are never passed to `get` and are grouped before or after the valid
///   slots according to `options.nulls_first`, in ascending index order.
/// * `get` - maps a slot index to the value used for comparison.
/// * `cmp` - orders two values returned by `get`; the ordering is reversed when
///   `options.descending` is set.
/// * `length` - number of slots; indices `0..length` are produced.
/// * `limit` - maximum number of indices to return; values above `length` are
///   clamped to `length`, and `None` means "all of them".
///
/// # Guarantees
/// * `get` is only called for `0 <= i < length`.
/// * `cmp` is only called with values returned by `get`.
#[inline]
pub(super) fn indices_sorted_unstable_by<T, G, F>(
    validity: &Option<Bitmap>,
    get: G,
    cmp: F,
    length: usize,
    options: &SortOptions,
    limit: Option<usize>,
) -> PrimitiveArray<i32>
where
    G: Fn(usize) -> T,
    F: Fn(&T, &T) -> std::cmp::Ordering,
{
    // Clamp the request: the sorting helpers must never see a limit that is
    // larger than the slice they operate on.
    let limit = limit.map_or(length, |requested| requested.min(length));

    let indices = match validity {
        Some(validity) => {
            let null_count = validity.null_count();
            let mut indices = MutableBuffer::<i32>::from_len_zeroed(length);

            if options.nulls_first {
                // Layout: [nulls (null_count) | valids]. Both groups are filled
                // in ascending index order through their own write cursor.
                let mut null_cursor = 0;
                let mut valid_cursor = null_count;
                for (is_valid, index) in validity.iter().zip(0..length as i32) {
                    if is_valid {
                        indices[valid_cursor] = index;
                        valid_cursor += 1;
                    } else {
                        indices[null_cursor] = index;
                        null_cursor += 1;
                    }
                }

                // When the limit is covered by the nulls alone, no valid slot
                // makes it into the output and sorting can be skipped entirely.
                if limit > null_count {
                    // Every stored index is `< length`, and
                    // `limit - null_count <= length - null_count`, which is the
                    // length of the valid region.
                    let valid_region = &mut indices.as_mut_slice()[null_count..];
                    sort_unstable_by(
                        valid_region,
                        get,
                        cmp,
                        options.descending,
                        limit - null_count,
                    )
                }
            } else {
                // Layout: [valids (valid_count) | nulls].
                let valid_count = length - null_count;
                let mut valid_cursor = 0;
                let mut null_cursor = valid_count;
                for (is_valid, index) in validity.iter().zip(0..length as i32) {
                    if is_valid {
                        indices[valid_cursor] = index;
                        valid_cursor += 1;
                    } else {
                        indices[null_cursor] = index;
                        null_cursor += 1;
                    }
                }

                // Every stored index is `< length`; the limit handed to the
                // helper is capped at the size of the valid region.
                let valid_region = &mut indices.as_mut_slice()[..valid_count];
                sort_unstable_by(
                    valid_region,
                    get,
                    cmp,
                    options.descending,
                    limit.min(valid_count),
                );
            }

            indices.truncate(limit);
            indices.shrink_to_fit();

            indices
        }
        None => {
            // No nulls: start from the identity permutation `0..length`.
            // NOTE(review): the unchecked constructor presumably requires an
            // iterator with an exact length, which a `Range<i32>` provides — confirm.
            let mut indices =
                unsafe { MutableBuffer::from_trusted_len_iter_unchecked(0..length as i32) };

            // Every index is `< length` and `limit <= length` after clamping.
            sort_unstable_by(&mut indices, get, cmp, options.descending, limit);

            indices.truncate(limit);
            indices.shrink_to_fit();

            indices
        }
    };

    PrimitiveArray::<i32>::from_data(DataType::Int32, indices.into(), None)
}
53 changes: 27 additions & 26 deletions src/compute/sort/lex_sort.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,3 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;

use crate::compute::take;
Expand Down Expand Up @@ -66,14 +49,14 @@ pub struct SortColumn<'a> {
/// nulls_first: false,
/// }),
/// },
/// ]).unwrap();
/// ], None).unwrap();
///
/// let sorted = sorted_columns[0].as_any().downcast_ref::<Int64Array>().unwrap();
/// assert_eq!(sorted.value(1), -64);
/// assert!(sorted.is_null(0));
/// ```
pub fn lexsort(columns: &[SortColumn]) -> Result<Vec<Box<dyn Array>>> {
let indices = lexsort_to_indices(columns)?;
pub fn lexsort(columns: &[SortColumn], limit: Option<usize>) -> Result<Vec<Box<dyn Array>>> {
let indices = lexsort_to_indices(columns, limit)?;
columns
.iter()
.map(|c| take::take(c.values, &indices))
Expand Down Expand Up @@ -135,9 +118,12 @@ pub(crate) fn build_compare(array: &dyn Array, sort_option: SortOptions) -> Resu
})
}

/// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer
/// [`Int32Array`] of indices.
pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<PrimitiveArray<i32>> {
/// Sorts a list of [`SortColumn`] into a non-nullable [`PrimitiveArray<i32>`]
/// representing the indices that would sort the columns.
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
pub fn lexsort_to_indices(
columns: &[SortColumn],
limit: Option<usize>,
) -> Result<PrimitiveArray<i32>> {
if columns.is_empty() {
return Err(ArrowError::InvalidArgumentError(
"Sort requires at least one column".to_string(),
Expand All @@ -146,7 +132,7 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<PrimitiveArray<i32>>
if columns.len() == 1 {
// fallback to non-lexical sort
let column = &columns[0];
return sort_to_indices(column.values, &column.options.unwrap_or_default());
return sort_to_indices(column.values, &column.options.unwrap_or_default(), limit);
}

let row_count = columns[0].values.len();
Expand Down Expand Up @@ -180,7 +166,15 @@ pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<PrimitiveArray<i32>>
// Safety: `0..row_count` is TrustedLen
let mut values =
unsafe { MutableBuffer::<i32>::from_trusted_len_iter_unchecked(0..row_count as i32) };
values.sort_unstable_by(lex_comparator);

if let Some(limit) = limit {
let limit = limit.min(row_count);
let (before, _, _) = values.select_nth_unstable_by(limit, lex_comparator);
before.sort_unstable_by(lex_comparator);
values.truncate(limit);
} else {
values.sort_unstable_by(lex_comparator);
}

Ok(PrimitiveArray::<i32>::from_data(
DataType::Int32,
Expand All @@ -196,7 +190,14 @@ mod tests {
use super::*;

fn test_lex_sort_arrays(input: Vec<SortColumn>, expected: Vec<Box<dyn Array>>) {
let sorted = lexsort(&input).unwrap();
let sorted = lexsort(&input, None).unwrap();
assert_eq!(sorted, expected);

let sorted = lexsort(&input, Some(2)).unwrap();
let expected = expected
.into_iter()
.map(|x| x.slice(0, 2))
.collect::<Vec<_>>();
assert_eq!(sorted, expected);
}

Expand Down
Loading