Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Sort improve (#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li authored Aug 4, 2021
1 parent 1908eb7 commit 3d123b3
Show file tree
Hide file tree
Showing 13 changed files with 157 additions and 42 deletions.
8 changes: 4 additions & 4 deletions benches/growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn add_benchmark(c: &mut Criterion) {
let i32_array = create_primitive_array::<i32>(1026 * 10, DataType::Int32, 0.0);
c.bench_function("growable::primitive::non_null::non_null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10);
values
.clone()
.into_iter()
Expand All @@ -25,7 +25,7 @@ fn add_benchmark(c: &mut Criterion) {
let i32_array = create_primitive_array::<i32>(1026 * 10, DataType::Int32, 0.0);
c.bench_function("growable::primitive::non_null::null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10);
values.clone().into_iter().for_each(|start| {
if start % 2 == 0 {
a.extend_validity(10);
Expand All @@ -41,7 +41,7 @@ fn add_benchmark(c: &mut Criterion) {
let values = values.collect::<Vec<_>>();
c.bench_function("growable::primitive::null::non_null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10);
values
.clone()
.into_iter()
Expand All @@ -50,7 +50,7 @@ fn add_benchmark(c: &mut Criterion) {
});
c.bench_function("growable::primitive::null::null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10);
values.clone().into_iter().for_each(|start| {
if start % 2 == 0 {
a.extend_validity(10);
Expand Down
15 changes: 14 additions & 1 deletion benches/sort_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
extern crate criterion;
use criterion::Criterion;

use arrow2::compute::sort::{lexsort, sort, SortColumn, SortOptions};
use arrow2::compute::sort::{lexsort, sort, sort_to_indices, SortColumn, SortOptions};
use arrow2::util::bench_util::*;
use arrow2::{array::*, datatypes::*};

Expand All @@ -42,6 +42,15 @@ fn bench_sort(arr_a: &dyn Array) {
sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap();
}

fn bench_sort_limit(arr_a: &dyn Array) {
let _: PrimitiveArray<u32> = sort_to_indices(
criterion::black_box(arr_a),
&SortOptions::default(),
Some(100),
)
.unwrap();
}

fn add_benchmark(c: &mut Criterion) {
(10..=20).step_by(2).for_each(|log2_size| {
let size = 2usize.pow(log2_size);
Expand All @@ -51,6 +60,10 @@ fn add_benchmark(c: &mut Criterion) {
b.iter(|| bench_sort(&arr_a))
});

c.bench_function(&format!("sort-limit 2^{} f32", log2_size), |b| {
b.iter(|| bench_sort_limit(&arr_a))
});

let arr_b = create_primitive_array_with_seed::<f32>(size, DataType::Float32, 0.0, 43);
c.bench_function(&format!("lexsort 2^{} f32", log2_size), |b| {
b.iter(|| bench_lexsort(&arr_a, &arr_b))
Expand Down
2 changes: 1 addition & 1 deletion benches/take_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn create_random_index(size: usize, null_density: f32) -> PrimitiveArray<i32> {
(0..size)
.map(|_| {
if rng.gen::<f32>() > null_density {
let value = rng.gen_range::<i32, _, _>(0i32, size as i32);
let value = rng.gen_range::<i32, _>(0i32..size as i32);
Some(value)
} else {
None
Expand Down
67 changes: 66 additions & 1 deletion src/array/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ use std::convert::TryFrom;
use num::Num;

use crate::{
buffer::Buffer,
buffer::{Buffer, MutableBuffer},
types::{NativeType, NaturalDataType},
};

/// Trait describing any type that can be used to index a slot of an array.
pub trait Index: NativeType + NaturalDataType {
fn to_usize(&self) -> usize;
fn from_usize(index: usize) -> Option<Self>;
fn is_usize() -> bool {
false
}

fn buffer_from_range(start: usize, end: usize) -> Option<MutableBuffer<Self>>;
}

/// Trait describing types that can be used as offsets as per Arrow specification.
Expand Down Expand Up @@ -71,6 +76,19 @@ impl Index for i32 {
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

fn buffer_from_range(start: usize, end: usize) -> Option<MutableBuffer<Self>> {
let start = Self::from_usize(start);
let end = Self::from_usize(end);
match (start, end) {
(Some(start), Some(end)) => unsafe {
Some(MutableBuffer::<Self>::from_trusted_len_iter_unchecked(
start..end,
))
},
_ => None,
}
}
}

impl Index for i64 {
Expand All @@ -83,6 +101,19 @@ impl Index for i64 {
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

fn buffer_from_range(start: usize, end: usize) -> Option<MutableBuffer<Self>> {
let start = Self::from_usize(start);
let end = Self::from_usize(end);
match (start, end) {
(Some(start), Some(end)) => unsafe {
Some(MutableBuffer::<Self>::from_trusted_len_iter_unchecked(
start..end,
))
},
_ => None,
}
}
}

impl Index for u32 {
Expand All @@ -95,6 +126,23 @@ impl Index for u32 {
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

fn is_usize() -> bool {
std::mem::size_of::<Self>() == std::mem::size_of::<usize>()
}

fn buffer_from_range(start: usize, end: usize) -> Option<MutableBuffer<Self>> {
let start = Self::from_usize(start);
let end = Self::from_usize(end);
match (start, end) {
(Some(start), Some(end)) => unsafe {
Some(MutableBuffer::<Self>::from_trusted_len_iter_unchecked(
start..end,
))
},
_ => None,
}
}
}

impl Index for u64 {
Expand All @@ -107,6 +155,23 @@ impl Index for u64 {
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

fn is_usize() -> bool {
std::mem::size_of::<Self>() == std::mem::size_of::<usize>()
}

fn buffer_from_range(start: usize, end: usize) -> Option<MutableBuffer<Self>> {
let start = Self::from_usize(start);
let end = Self::from_usize(end);
match (start, end) {
(Some(start), Some(end)) => unsafe {
Some(MutableBuffer::<Self>::from_trusted_len_iter_unchecked(
start..end,
))
},
_ => None,
}
}
}

#[inline]
Expand Down
9 changes: 9 additions & 0 deletions src/buffer/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ impl<T: NativeType> MutableBuffer<T> {
}
}

/// Allocates a new [MutableBuffer] with `len` and capacity to be at least `len` where
/// all bytes are not initialized
#[inline]
pub unsafe fn from_len(len: usize) -> Self {
let mut buffer = MutableBuffer::with_capacity(len);
buffer.set_len(len);
buffer
}

/// Ensures that this buffer has at least `self.len + additional` bytes. This re-allocates iff
/// `self.len + additional > capacity`.
/// # Example
Expand Down
39 changes: 39 additions & 0 deletions src/compute/merge_sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,29 @@ where
None => self.right = None,
}
}

/// Collect the MergeSortSlices to be a vec for reusing
#[warn(dead_code)]
pub fn to_vec(self, limit: Option<usize>) -> Vec<MergeSlice> {
match limit {
Some(limit) => {
let mut v = Vec::with_capacity(limit);
let mut current_len = 0;
for (index, start, len) in self {
if len + current_len >= limit {
v.push((index, start, limit - current_len));
break;
} else {
v.push((index, start, len));
}
current_len += len;
}

v
}
None => self.into_iter().collect(),
}
}
}

impl<'a, L, R> Iterator for MergeSortSlices<'a, L, R>
Expand Down Expand Up @@ -562,6 +585,22 @@ mod tests {
Ok(())
}

#[test]
fn test_merge_slices_to_vec() -> Result<()> {
let a0: &dyn Array = &Int32Array::from_slice(&[0, 2, 4, 6, 8]);
let a1: &dyn Array = &Int32Array::from_slice(&[1, 3, 5, 7, 9]);

let options = SortOptions::default();
let arrays = vec![a0, a1];
let pairs = vec![(arrays.as_ref(), &options)];
let comparator = build_comparator(&pairs)?;

let slices = merge_sort_slices(once(&(0, 0, 5)), once(&(1, 0, 5)), &comparator);
let vec = slices.to_vec(Some(5));
assert_eq!(vec, [(0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1), (0, 2, 1)]);
Ok(())
}

#[test]
fn test_merge_4_i32() -> Result<()> {
let a0: &dyn Array = &Int32Array::from_slice(&[0, 1]);
Expand Down
3 changes: 2 additions & 1 deletion src/compute/sort/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn sort_boolean<I: Index>(
if !descending {
valids.sort_by(|a, b| a.1.cmp(&b.1));
} else {
valids.sort_by(|a, b| a.1.cmp(&b.1).reverse());
valids.sort_by(|a, b| b.1.cmp(&a.1));
// reverse to keep a stable ordering
nulls.reverse();
}
Expand All @@ -45,6 +45,7 @@ pub fn sort_boolean<I: Index>(
// un-efficient; there are much more performant ways of sorting nulls above, anyways.
if let Some(limit) = limit {
values.truncate(limit);
values.shrink_to_fit();
}

PrimitiveArray::<I>::from_data(I::DATA_TYPE, values.into(), None)
Expand Down
43 changes: 13 additions & 30 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,21 @@ fn k_element_sort_inner<I: Index, T, G, F>(
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if descending {
let compare = |lhs: &I, rhs: &I| {
let mut compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
cmp(&rhs, &lhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
};
before.sort_unstable_by(compare);
let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare);
before.sort_unstable_by(&mut compare);
} else {
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &I, rhs: &I| {
let mut compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
};
before.sort_unstable_by(compare);
let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare);
before.sort_unstable_by(&mut compare);
}
}

Expand Down Expand Up @@ -74,7 +64,7 @@ fn sort_unstable_by<I, T, G, F>(
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
cmp(&rhs, &lhs)
})
} else {
indices.sort_unstable_by(|lhs, rhs| {
Expand Down Expand Up @@ -110,8 +100,7 @@ where
let limit = limit.min(length);

let indices = if let Some(validity) = validity {
let mut indices = MutableBuffer::<I>::from_len_zeroed(length);

let mut indices = unsafe { MutableBuffer::<I>::from_len(length) };
if options.nulls_first {
let mut nulls = 0;
let mut valids = 0;
Expand All @@ -134,12 +123,12 @@ where
// Soundness:
// all indices in `indices` are by construction `< array.len() == values.len()`
// limit is by construction < indices.len()
let limit = limit - validity.null_count();
let limit = limit.saturating_sub(validity.null_count());
let indices = &mut indices.as_mut_slice()[validity.null_count()..];
sort_unstable_by(indices, get, cmp, options.descending, limit)
}
} else {
let last_valid_index = length - validity.null_count();
let last_valid_index = length.saturating_sub(validity.null_count());
let mut nulls = 0;
let mut valids = 0;
validity.iter().zip(0..length).for_each(|(x, index)| {
Expand All @@ -165,21 +154,15 @@ where

indices
} else {
let mut indices = unsafe {
MutableBuffer::from_trusted_len_iter_unchecked(
(0..length).map(|x| I::from_usize(x).unwrap()),
)
};

let mut indices = Index::buffer_from_range(0, length).unwrap();
// Soundness:
// indices are by construction `< values.len()`
// limit is by construction `< values.len()`
sort_unstable_by(&mut indices, get, cmp, descending, limit);

indices.truncate(limit);
indices.shrink_to_fit();

indices
};

PrimitiveArray::<I>::from_data(I::DATA_TYPE, indices.into(), None)
}
1 change: 1 addition & 0 deletions src/compute/sort/lex_sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ pub fn lexsort_to_indices<I: Index>(
let (before, _, _) = values.select_nth_unstable_by(limit, lex_comparator);
before.sort_unstable_by(lex_comparator);
values.truncate(limit);
values.shrink_to_fit();
} else {
values.sort_unstable_by(lex_comparator);
}
Expand Down
2 changes: 1 addition & 1 deletion src/compute/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ where
if !options.descending {
valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()))
} else {
valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse())
valids.sort_by(|a, b| cmp_array(b.1.as_ref(), a.1.as_ref()))
}

let values = valids.iter().map(|tuple| tuple.0);
Expand Down
Loading

0 comments on commit 3d123b3

Please sign in to comment.