Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Made DynComparator Send+Sync #414

Merged
merged 1 commit into from
Sep 16, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 29 additions & 39 deletions src/array/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::error::{ArrowError, Result};
use crate::{array::*, types::NativeType};

/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;

/// implements comparison using IEEE 754 total ordering for f32
// Original implementation from https://doc.rust-lang.org/std/primitive.f32.html#method.total_cmp
Expand Down Expand Up @@ -47,66 +47,57 @@ where
l.cmp(r)
}

fn compare_primitives<'a, T: NativeType + Ord>(
left: &'a dyn Array,
right: &'a dyn Array,
) -> DynComparator<'a> {
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let left = left.values();
let right = right.values();
Box::new(move |i, j| total_cmp(&left[i], &right[j]))
fn compare_primitives<T: NativeType + Ord>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap().clone();
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap().clone();
Box::new(move |i, j| total_cmp(&left.value(i), &right.value(j)))
}

fn compare_boolean<'a>(left: &'a dyn Array, right: &'a dyn Array) -> DynComparator<'a> {
let left = left.as_any().downcast_ref::<BooleanArray>().unwrap();
let right = right.as_any().downcast_ref::<BooleanArray>().unwrap();
fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<BooleanArray>().unwrap().clone();
let right = right.as_any().downcast_ref::<BooleanArray>().unwrap().clone();
Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}

fn compare_f32<'a>(left: &'a dyn Array, right: &'a dyn Array) -> DynComparator<'a> {
let left = left.as_any().downcast_ref::<PrimitiveArray<f32>>().unwrap();
fn compare_f32(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<f32>>().unwrap().clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<f32>>()
.unwrap();
let left = left.values();
let right = right.values();
Box::new(move |i, j| total_cmp_f32(&left[i], &right[j]))
.unwrap().clone();
Box::new(move |i, j| total_cmp_f32(&left.value(i), &right.value(j)))
}

fn compare_f64<'a>(left: &'a dyn Array, right: &'a dyn Array) -> DynComparator<'a> {
let left = left.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
fn compare_f64(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap().clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap();
let left = left.values();
let right = right.values();
Box::new(move |i, j| total_cmp_f64(&left[i], &right[j]))
.unwrap().clone();
Box::new(move |i, j| total_cmp_f64(&left.value(i), &right.value(j)))
}

fn compare_string<'a, O: Offset>(left: &'a dyn Array, right: &'a dyn Array) -> DynComparator<'a> {
let left = left.as_any().downcast_ref::<Utf8Array<O>>().unwrap();
let right = right.as_any().downcast_ref::<Utf8Array<O>>().unwrap();
fn compare_string<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<Utf8Array<O>>().unwrap().clone();
let right = right.as_any().downcast_ref::<Utf8Array<O>>().unwrap().clone();
Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}

fn compare_binary<'a, O: Offset>(left: &'a dyn Array, right: &'a dyn Array) -> DynComparator<'a> {
let left = left.as_any().downcast_ref::<BinaryArray<O>>().unwrap();
let right = right.as_any().downcast_ref::<BinaryArray<O>>().unwrap();
fn compare_binary<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<BinaryArray<O>>().unwrap().clone();
let right = right.as_any().downcast_ref::<BinaryArray<O>>().unwrap().clone();
Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}

fn compare_dict<'a, K>(
left: &'a DictionaryArray<K>,
right: &'a DictionaryArray<K>,
) -> Result<DynComparator<'a>>
fn compare_dict<K>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
) -> Result<DynComparator>
where
K: DictionaryKey,
{
let left_keys = left.keys().values();
let right_keys = right.keys().values();
let left_keys = left.keys().values().clone();
let right_keys = right.keys().values().clone();

let comparator = build_compare(left.values().as_ref(), right.values().as_ref())?;

Expand Down Expand Up @@ -145,8 +136,7 @@ macro_rules! dyn_dict {
/// # Error
/// The arrays' [`DataType`] must be equal and the types must have a natural order.
// This is a factory of comparisons.
// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
pub fn build_compare<'a>(left: &'a dyn Array, right: &'a dyn Array) -> Result<DynComparator<'a>> {
pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparator> {
use DataType::*;
use IntervalUnit::*;
use TimeUnit::*;
Expand Down