Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Improved performance of comparison with SIMD feature flag (2x-3.5x) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Aug 20, 2021
1 parent 4f20537 commit 2873e1a
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 74 deletions.
2 changes: 1 addition & 1 deletion benches/comparison_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ where

fn bench_op_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T, op: Operator)
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
primitive_compare_scalar(
criterion::black_box(arr_a),
Expand Down
3 changes: 3 additions & 0 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ mod boolean;
mod primitive;
mod utf8;

mod simd;
pub use simd::{Simd8, Simd8Lanes};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
Lt,
Expand Down
122 changes: 52 additions & 70 deletions src/compute/comparison/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,33 @@ use crate::{
error::{ArrowError, Result},
};

use super::simd::{Simd8, Simd8Lanes};
use super::{super::utils::combine_validities, Operator};

pub(crate) fn compare_values_op<T, F>(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap
where
T: NativeType,
F: Fn(T, T) -> bool,
T: NativeType + Simd8,
F: Fn(T::Simd, T::Simd) -> u8,
{
assert_eq!(lhs.len(), rhs.len());
let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8);

let lhs_chunks_iter = lhs.chunks_exact(8);
let lhs_remainder = lhs_chunks_iter.remainder();
let rhs_chunks_iter = rhs.chunks_exact(8);
let rhs_remainder = rhs_chunks_iter.remainder();

let chunks = lhs.len() / 8;

values[..chunks]
.iter_mut()
.zip(lhs_chunks_iter)
.zip(rhs_chunks_iter)
.for_each(|((byte, lhs), rhs)| {
lhs.iter()
.zip(rhs.iter())
.enumerate()
.for_each(|(i, (&lhs, &rhs))| {
*byte |= if op(lhs, rhs) { 1 << i } else { 0 };
});
});
let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8);
let iterator = lhs_chunks_iter.zip(rhs_chunks_iter).map(|(lhs, rhs)| {
let lhs = T::Simd::from_chunk(lhs);
let rhs = T::Simd::from_chunk(rhs);
op(lhs, rhs)
});
values.extend_from_trusted_len_iter(iterator);

if !lhs_remainder.is_empty() {
let last = &mut values[chunks];
lhs_remainder
.iter()
.zip(rhs_remainder.iter())
.enumerate()
.for_each(|(i, (&lhs, &rhs))| {
*last |= if op(lhs, rhs) { 1 << i } else { 0 };
});
let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default());
let rhs = T::Simd::from_incomplete_chunk(rhs_remainder, T::default());
values.push(op(lhs, rhs))
};
MutableBitmap::from_buffer(values, lhs.len())
}
Expand All @@ -70,8 +58,8 @@ where
/// comparison function.
fn compare_op<T, F>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>, op: F) -> Result<BooleanArray>
where
T: NativeType,
F: Fn(T, T) -> bool,
T: NativeType + Simd8,
F: Fn(T::Simd, T::Simd) -> u8,
{
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
Expand All @@ -90,31 +78,25 @@ where
/// a specified comparison function.
pub fn compare_op_scalar<T, F>(lhs: &PrimitiveArray<T>, rhs: T, op: F) -> Result<BooleanArray>
where
T: NativeType,
F: Fn(T, T) -> bool,
T: NativeType + Simd8,
F: Fn(T::Simd, T::Simd) -> u8,
{
let validity = lhs.validity().clone();

let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8);
let rhs = T::Simd::from_chunk(&[rhs; 8]);

let lhs_chunks_iter = lhs.values().chunks_exact(8);
let lhs_remainder = lhs_chunks_iter.remainder();
let chunks = lhs.len() / 8;

values[..chunks]
.iter_mut()
.zip(lhs_chunks_iter)
.for_each(|(byte, chunk)| {
chunk.iter().enumerate().for_each(|(i, &c_i)| {
*byte |= if op(c_i, rhs) { 1 << i } else { 0 };
});
});
let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8);
let iterator = lhs_chunks_iter.map(|lhs| {
let lhs = T::Simd::from_chunk(lhs);
op(lhs, rhs)
});
values.extend_from_trusted_len_iter(iterator);

if !lhs_remainder.is_empty() {
let last = &mut values[chunks];
lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| {
*last |= if op(lhs, rhs) { 1 << i } else { 0 };
});
let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default());
values.push(op(lhs, rhs))
};

Ok(BooleanArray::from_data(
Expand All @@ -126,105 +108,105 @@ where
/// Perform `lhs == rhs` operation on two arrays.
pub fn eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a == b)
compare_op(lhs, rhs, |a, b| a.eq(b))
}

/// Perform `left == right` operation on an array and a scalar value.
pub fn eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a == b)
compare_op_scalar(lhs, rhs, |a, b| a.eq(b))
}

/// Perform `left != right` operation on two arrays.
pub fn neq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a != b)
compare_op(lhs, rhs, |a, b| a.neq(b))
}

/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a != b)
compare_op_scalar(lhs, rhs, |a, b| a.neq(b))
}

/// Perform `left < right` operation on two arrays.
pub fn lt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a < b)
compare_op(lhs, rhs, |a, b| a.lt(b))
}

/// Perform `left < right` operation on an array and a scalar value.
pub fn lt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a < b)
compare_op_scalar(lhs, rhs, |a, b| a.lt(b))
}

/// Perform `left <= right` operation on two arrays.
pub fn lt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a <= b)
compare_op(lhs, rhs, |a, b| a.lt_eq(b))
}

/// Perform `left <= right` operation on an array and a scalar value.
/// Null values are less than non-null values.
pub fn lt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a <= b)
compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b))
}

/// Perform `left > right` operation on two arrays. Non-null values are greater than null
/// values.
pub fn gt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a > b)
compare_op(lhs, rhs, |a, b| a.gt(b))
}

/// Perform `left > right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a > b)
compare_op_scalar(lhs, rhs, |a, b| a.gt(b))
}

/// Perform `left >= right` operation on two arrays. Non-null values are greater than null
/// values.
pub fn gt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a >= b)
compare_op(lhs, rhs, |a, b| a.gt_eq(b))
}

/// Perform `left >= right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a >= b)
compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b))
}

pub fn compare<T: NativeType + std::cmp::PartialOrd>(
pub fn compare<T: NativeType + Simd8>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
op: Operator,
Expand All @@ -239,7 +221,7 @@ pub fn compare<T: NativeType + std::cmp::PartialOrd>(
}
}

pub fn compare_scalar<T: NativeType + std::cmp::PartialOrd>(
pub fn compare_scalar<T: NativeType + Simd8>(
lhs: &PrimitiveArray<T>,
rhs: T,
op: Operator,
Expand Down
91 changes: 91 additions & 0 deletions src/compute/comparison/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use crate::types::NativeType;

/// [`NativeType`] that supports a representation of 8 lanes
pub trait Simd8: NativeType {
type Simd: Simd8Lanes<Self>;
}

pub trait Simd8Lanes<T>: Copy {
fn from_chunk(v: &[T]) -> Self;
fn from_incomplete_chunk(v: &[T], remaining: T) -> Self;
fn eq(self, other: Self) -> u8;
fn neq(self, other: Self) -> u8;
fn lt_eq(self, other: Self) -> u8;
fn lt(self, other: Self) -> u8;
fn gt(self, other: Self) -> u8;
fn gt_eq(self, other: Self) -> u8;
}

#[inline]
pub(super) fn set<T: Copy, F: Fn(T, T) -> bool>(lhs: [T; 8], rhs: [T; 8], op: F) -> u8 {
let mut byte = 0u8;
lhs.iter()
.zip(rhs.iter())
.enumerate()
.for_each(|(i, (lhs, rhs))| {
byte |= if op(*lhs, *rhs) { 1 << i } else { 0 };
});
byte
}

macro_rules! simd8_native {
($type:ty) => {
impl Simd8 for $type {
type Simd = [$type; 8];
}

impl Simd8Lanes<$type> for [$type; 8] {
#[inline]
fn from_chunk(v: &[$type]) -> Self {
v.try_into().unwrap()
}

#[inline]
fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self {
let mut a = [remaining; 8];
a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b);
a
}

#[inline]
fn eq(self, other: Self) -> u8 {
set(self, other, |x, y| x == y)
}

#[inline]
fn neq(self, other: Self) -> u8 {
#[allow(clippy::float_cmp)]
set(self, other, |x, y| x != y)
}

#[inline]
fn lt_eq(self, other: Self) -> u8 {
set(self, other, |x, y| x <= y)
}

#[inline]
fn lt(self, other: Self) -> u8 {
set(self, other, |x, y| x < y)
}

#[inline]
fn gt_eq(self, other: Self) -> u8 {
set(self, other, |x, y| x >= y)
}

#[inline]
fn gt(self, other: Self) -> u8 {
set(self, other, |x, y| x > y)
}
}
};
}

#[cfg(not(feature = "simd"))]
mod native;
#[cfg(not(feature = "simd"))]
pub use native::*;
#[cfg(feature = "simd")]
mod packed;
#[cfg(feature = "simd")]
pub use packed::*;
15 changes: 15 additions & 0 deletions src/compute/comparison/simd/native.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use std::convert::TryInto;

use super::{set, Simd8, Simd8Lanes};

simd8_native!(u8);
simd8_native!(u16);
simd8_native!(u32);
simd8_native!(u64);
simd8_native!(i8);
simd8_native!(i16);
simd8_native!(i32);
simd8_native!(i128);
simd8_native!(i64);
simd8_native!(f32);
simd8_native!(f64);
Loading

0 comments on commit 2873e1a

Please sign in to comment.