Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added SIMD support for comparison kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Aug 20, 2021
1 parent 999742b commit 9dc32a9
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 43 deletions.
2 changes: 1 addition & 1 deletion benches/comparison_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ where

fn bench_op_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T, op: Operator)
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
primitive_compare_scalar(
criterion::black_box(arr_a),
Expand Down
46 changes: 22 additions & 24 deletions src/compute/comparison/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ where
/// a specified comparison function.
pub fn compare_op_scalar<T, F>(lhs: &PrimitiveArray<T>, rhs: T, op: F) -> Result<BooleanArray>
where
T: NativeType,
F: Fn(T, T) -> bool,
T: NativeType + Simd8,
F: Fn(T::Simd, T::Simd) -> u8,
{
let validity = lhs.validity().clone();
let rhs = T::Simd::from_chunk(&[rhs; 8]);

let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8);

Expand All @@ -97,17 +98,14 @@ where
values[..chunks]
.iter_mut()
.zip(lhs_chunks_iter)
.for_each(|(byte, chunk)| {
chunk.iter().enumerate().for_each(|(i, &c_i)| {
*byte |= if op(c_i, rhs) { 1 << i } else { 0 };
});
.for_each(|(byte, lhs)| {
let lhs = T::Simd::from_chunk(lhs);
*byte = op(lhs, rhs);
});

if !lhs_remainder.is_empty() {
let last = &mut values[chunks];
lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| {
*last |= if op(lhs, rhs) { 1 << i } else { 0 };
});
let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default());
values[chunks] = op(lhs, rhs);
};

Ok(BooleanArray::from_data(
Expand All @@ -127,9 +125,9 @@ where
/// Perform `left == right` operation on an array and a scalar value.
pub fn eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a == b)
compare_op_scalar(lhs, rhs, |a, b| a.eq(b))
}

/// Perform `left != right` operation on two arrays.
Expand All @@ -143,9 +141,9 @@ where
/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a != b)
compare_op_scalar(lhs, rhs, |a, b| a.neq(b))
}

/// Perform `left < right` operation on two arrays.
Expand All @@ -159,9 +157,9 @@ where
/// Perform `left < right` operation on an array and a scalar value.
pub fn lt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a < b)
compare_op_scalar(lhs, rhs, |a, b| a.lt(b))
}

/// Perform `left <= right` operation on two arrays.
Expand All @@ -176,9 +174,9 @@ where
/// Null values are less than non-null values.
pub fn lt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a <= b)
compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b))
}

/// Perform `left > right` operation on two arrays. Non-null values are greater than null
Expand All @@ -194,9 +192,9 @@ where
/// Non-null values are greater than null values.
pub fn gt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a > b)
compare_op_scalar(lhs, rhs, |a, b| a.gt(b))
}

/// Perform `left >= right` operation on two arrays. Non-null values are greater than null
Expand All @@ -212,12 +210,12 @@ where
/// Non-null values are greater than null values.
pub fn gt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a >= b)
compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b))
}

pub fn compare<T: NativeType + std::cmp::PartialOrd + Simd8>(
pub fn compare<T: NativeType + Simd8>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
op: Operator,
Expand All @@ -232,7 +230,7 @@ pub fn compare<T: NativeType + std::cmp::PartialOrd + Simd8>(
}
}

pub fn compare_scalar<T: NativeType + std::cmp::PartialOrd>(
pub fn compare_scalar<T: NativeType + Simd8>(
lhs: &PrimitiveArray<T>,
rhs: T,
op: Operator,
Expand Down
36 changes: 18 additions & 18 deletions src/compute/comparison/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
use crate::types::NativeType;

/// [`NativeType`] that supports a representation of 8 lanes
pub trait Simd8: NativeType {
type Simd: Simd8Lanes<Self>;
}

pub trait Simd8Lanes<T>: Copy {
fn from_chunk(v: &[T]) -> Self;
fn from_incomplete_chunk(v: &[T], remaining: T) -> Self;
fn eq(self, other: Self) -> u8;
fn neq(self, other: Self) -> u8;
fn lt_eq(self, other: Self) -> u8;
fn lt(self, other: Self) -> u8;
fn gt(self, other: Self) -> u8;
fn gt_eq(self, other: Self) -> u8;
}

#[inline]
pub(super) fn set<T: Copy, F: Fn(T, T) -> bool>(lhs: [T; 8], rhs: [T; 8], op: F) -> u8 {
let mut byte = 0u8;
Expand Down Expand Up @@ -71,21 +89,3 @@ pub use native::*;
mod packed;
#[cfg(feature = "simd")]
pub use packed::*;

use crate::types::NativeType;

/// [`NativeType`] that supports a representation of 8 lanes
pub trait Simd8: NativeType {
type Simd: Simd8Lanes<Self>;
}

pub trait Simd8Lanes<T> {
fn from_chunk(v: &[T]) -> Self;
fn from_incomplete_chunk(v: &[T], remaining: T) -> Self;
fn eq(self, other: Self) -> u8;
fn neq(self, other: Self) -> u8;
fn lt_eq(self, other: Self) -> u8;
fn lt(self, other: Self) -> u8;
fn gt(self, other: Self) -> u8;
fn gt_eq(self, other: Self) -> u8;
}
2 changes: 2 additions & 0 deletions src/compute/comparison/simd/packed.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::TryInto;

use super::{set, Simd8, Simd8Lanes};

use packed_simd::*;
Expand Down

0 comments on commit 9dc32a9

Please sign in to comment.