From afaff939154a87e41db0b363c725509e6f479a01 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 20 Aug 2021 14:01:50 +0000 Subject: [PATCH] Added SIMD support for comparison kernel. --- benches/comparison_kernels.rs | 2 +- src/compute/comparison/primitive.rs | 77 ++++++++++++--------------- src/compute/comparison/simd/mod.rs | 36 ++++++------- src/compute/comparison/simd/packed.rs | 2 + src/trusted_len.rs | 2 + 5 files changed, 56 insertions(+), 63 deletions(-) diff --git a/benches/comparison_kernels.rs b/benches/comparison_kernels.rs index a5b6a8293dc..b1643bb1aae 100644 --- a/benches/comparison_kernels.rs +++ b/benches/comparison_kernels.rs @@ -32,7 +32,7 @@ where fn bench_op_scalar(arr_a: &PrimitiveArray, value_b: T, op: Operator) where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { primitive_compare_scalar( criterion::black_box(arr_a), diff --git a/src/compute/comparison/primitive.rs b/src/compute/comparison/primitive.rs index bcd7fdfe35d..04dc5aa078b 100644 --- a/src/compute/comparison/primitive.rs +++ b/src/compute/comparison/primitive.rs @@ -32,29 +32,24 @@ where F: Fn(T::Simd, T::Simd) -> u8, { assert_eq!(lhs.len(), rhs.len()); - let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8); let lhs_chunks_iter = lhs.chunks_exact(8); let lhs_remainder = lhs_chunks_iter.remainder(); let rhs_chunks_iter = rhs.chunks_exact(8); let rhs_remainder = rhs_chunks_iter.remainder(); - let chunks = lhs.len() / 8; - - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .zip(rhs_chunks_iter) - .for_each(|((byte, lhs), rhs)| { - let lhs = T::Simd::from_chunk(lhs); - let rhs = T::Simd::from_chunk(rhs); - *byte = op(lhs, rhs); - }); + let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.zip(rhs_chunks_iter).map(|(lhs, rhs)| { + let lhs = T::Simd::from_chunk(lhs); + let rhs = T::Simd::from_chunk(rhs); + op(lhs, rhs) + }); + values.extend_from_trusted_len_iter(iterator); if !lhs_remainder.is_empty() { let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); let rhs = T::Simd::from_incomplete_chunk(rhs_remainder, T::default()); - values[chunks] = op(lhs, rhs); + values.push(op(lhs, rhs)) }; MutableBitmap::from_buffer(values, lhs.len()) } @@ -83,31 +78,25 @@ where /// a specified comparison function. pub fn compare_op_scalar(lhs: &PrimitiveArray, rhs: T, op: F) -> Result where - T: NativeType, - F: Fn(T, T) -> bool, + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, { let validity = lhs.validity().clone(); - - let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8); + let rhs = T::Simd::from_chunk(&[rhs; 8]); let lhs_chunks_iter = lhs.values().chunks_exact(8); let lhs_remainder = lhs_chunks_iter.remainder(); - let chunks = lhs.len() / 8; - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .for_each(|(byte, chunk)| { - chunk.iter().enumerate().for_each(|(i, &c_i)| { - *byte |= if op(c_i, rhs) { 1 << i } else { 0 }; - }); - }); + let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.map(|lhs| { + let lhs = T::Simd::from_chunk(lhs); + op(lhs, rhs) + }); + values.extend_from_trusted_len_iter(iterator); if !lhs_remainder.is_empty() { - let last = &mut values[chunks]; - lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| { - *last |= if op(lhs, rhs) { 1 << i } else { 0 }; - }); + let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); + values.push(op(lhs, rhs)) }; Ok(BooleanArray::from_data( @@ -127,9 +116,9 @@ where /// Perform `left == right` operation on an array and a scalar value. pub fn eq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a == b) + compare_op_scalar(lhs, rhs, |a, b| a.eq(b)) } /// Perform `left != right` operation on two arrays. @@ -143,9 +132,9 @@ where /// Perform `left != right` operation on an array and a scalar value. pub fn neq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a != b) + compare_op_scalar(lhs, rhs, |a, b| a.neq(b)) } /// Perform `left < right` operation on two arrays. @@ -159,9 +148,9 @@ where /// Perform `left < right` operation on an array and a scalar value. pub fn lt_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a < b) + compare_op_scalar(lhs, rhs, |a, b| a.lt(b)) } /// Perform `left <= right` operation on two arrays. @@ -176,9 +165,9 @@ where /// Null values are less than non-null values. pub fn lt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a <= b) + compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b)) } /// Perform `left > right` operation on two arrays. Non-null values are greater than null @@ -194,9 +183,9 @@ where /// Non-null values are greater than null values. pub fn gt_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a > b) + compare_op_scalar(lhs, rhs, |a, b| a.gt(b)) } /// Perform `left >= right` operation on two arrays. Non-null values are greater than null @@ -212,12 +201,12 @@ where /// Non-null values are greater than null values. pub fn gt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a >= b) + compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b)) } -pub fn compare( +pub fn compare( lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: Operator, @@ -232,7 +221,7 @@ pub fn compare( } } -pub fn compare_scalar( +pub fn compare_scalar( lhs: &PrimitiveArray, rhs: T, op: Operator, diff --git a/src/compute/comparison/simd/mod.rs b/src/compute/comparison/simd/mod.rs index 9768ef1ae8d..3fe82b81dbb 100644 --- a/src/compute/comparison/simd/mod.rs +++ b/src/compute/comparison/simd/mod.rs @@ -1,3 +1,21 @@ +use crate::types::NativeType; + +/// [`NativeType`] that supports a representation of 8 lanes +pub trait Simd8: NativeType { + type Simd: Simd8Lanes; +} + +pub trait Simd8Lanes: Copy { + fn from_chunk(v: &[T]) -> Self; + fn from_incomplete_chunk(v: &[T], remaining: T) -> Self; + fn eq(self, other: Self) -> u8; + fn neq(self, other: Self) -> u8; + fn lt_eq(self, other: Self) -> u8; + fn lt(self, other: Self) -> u8; + fn gt(self, other: Self) -> u8; + fn gt_eq(self, other: Self) -> u8; +} + #[inline] pub(super) fn set bool>(lhs: [T; 8], rhs: [T; 8], op: F) -> u8 { let mut byte = 0u8; @@ -71,21 +89,3 @@ pub use native::*; mod packed; #[cfg(feature = "simd")] pub use packed::*; - -use crate::types::NativeType; - -/// [`NativeType`] that supports a representation of 8 lanes -pub trait Simd8: NativeType { - type Simd: Simd8Lanes; -} - -pub trait Simd8Lanes { - fn from_chunk(v: &[T]) -> Self; - fn from_incomplete_chunk(v: &[T], remaining: T) -> Self; - fn eq(self, other: Self) -> u8; - fn neq(self, other: Self) -> u8; - fn lt_eq(self, other: Self) -> u8; - fn lt(self, other: Self) -> u8; - fn gt(self, other: Self) -> u8; - fn gt_eq(self, other: Self) -> u8; -} diff --git a/src/compute/comparison/simd/packed.rs b/src/compute/comparison/simd/packed.rs index 2a87fd70767..1e9b250c73b 100644 --- a/src/compute/comparison/simd/packed.rs +++ b/src/compute/comparison/simd/packed.rs @@ -1,3 +1,5 @@ +use std::convert::TryInto; + use super::{set, Simd8, Simd8Lanes}; use packed_simd::*; diff --git a/src/trusted_len.rs b/src/trusted_len.rs index 877a180c9a7..8c52c0bf861 100644 --- a/src/trusted_len.rs +++ b/src/trusted_len.rs @@ -26,6 +26,8 @@ where { } +unsafe impl TrustedLen for std::slice::ChunksExact<'_, T> {} + unsafe impl TrustedLen for std::slice::Windows<'_, T> {} unsafe impl TrustedLen for std::iter::Chain