diff --git a/benches/comparison_kernels.rs b/benches/comparison_kernels.rs index a5b6a8293dc..b1643bb1aae 100644 --- a/benches/comparison_kernels.rs +++ b/benches/comparison_kernels.rs @@ -32,7 +32,7 @@ where fn bench_op_scalar(arr_a: &PrimitiveArray, value_b: T, op: Operator) where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { primitive_compare_scalar( criterion::black_box(arr_a), diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index 40bf8e1ca60..69d7ed07b1b 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -30,6 +30,9 @@ mod boolean; mod primitive; mod utf8; +mod simd; +pub use simd::{Simd8, Simd8Lanes}; + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Operator { Lt, diff --git a/src/compute/comparison/primitive.rs b/src/compute/comparison/primitive.rs index e84f7a436b7..04dc5aa078b 100644 --- a/src/compute/comparison/primitive.rs +++ b/src/compute/comparison/primitive.rs @@ -23,45 +23,33 @@ use crate::{ error::{ArrowError, Result}, }; +use super::simd::{Simd8, Simd8Lanes}; use super::{super::utils::combine_validities, Operator}; pub(crate) fn compare_values_op(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap where - T: NativeType, - F: Fn(T, T) -> bool, + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, { assert_eq!(lhs.len(), rhs.len()); - let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8); let lhs_chunks_iter = lhs.chunks_exact(8); let lhs_remainder = lhs_chunks_iter.remainder(); let rhs_chunks_iter = rhs.chunks_exact(8); let rhs_remainder = rhs_chunks_iter.remainder(); - let chunks = lhs.len() / 8; - - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .zip(rhs_chunks_iter) - .for_each(|((byte, lhs), rhs)| { - lhs.iter() - .zip(rhs.iter()) - .enumerate() - .for_each(|(i, (&lhs, &rhs))| { - *byte |= if op(lhs, rhs) { 1 << i } else { 0 }; - }); - }); + let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.zip(rhs_chunks_iter).map(|(lhs, rhs)| { + let lhs = T::Simd::from_chunk(lhs); + let rhs = T::Simd::from_chunk(rhs); + op(lhs, rhs) + }); + values.extend_from_trusted_len_iter(iterator); if !lhs_remainder.is_empty() { - let last = &mut values[chunks]; - lhs_remainder - .iter() - .zip(rhs_remainder.iter()) - .enumerate() - .for_each(|(i, (&lhs, &rhs))| { - *last |= if op(lhs, rhs) { 1 << i } else { 0 }; - }); + let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); + let rhs = T::Simd::from_incomplete_chunk(rhs_remainder, T::default()); + values.push(op(lhs, rhs)) }; MutableBitmap::from_buffer(values, lhs.len()) } @@ -70,8 +58,8 @@ where /// comparison function. fn compare_op(lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: F) -> Result where - T: NativeType, - F: Fn(T, T) -> bool, + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, { if lhs.len() != rhs.len() { return Err(ArrowError::InvalidArgumentError( @@ -90,31 +78,25 @@ where /// a specified comparison function. pub fn compare_op_scalar(lhs: &PrimitiveArray, rhs: T, op: F) -> Result where - T: NativeType, - F: Fn(T, T) -> bool, + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, { let validity = lhs.validity().clone(); - - let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8); + let rhs = T::Simd::from_chunk(&[rhs; 8]); let lhs_chunks_iter = lhs.values().chunks_exact(8); let lhs_remainder = lhs_chunks_iter.remainder(); - let chunks = lhs.len() / 8; - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .for_each(|(byte, chunk)| { - chunk.iter().enumerate().for_each(|(i, &c_i)| { - *byte |= if op(c_i, rhs) { 1 << i } else { 0 }; - }); - }); + let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.map(|lhs| { + let lhs = T::Simd::from_chunk(lhs); + op(lhs, rhs) + }); + values.extend_from_trusted_len_iter(iterator); if !lhs_remainder.is_empty() { - let last = &mut values[chunks]; - lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| { - *last |= if op(lhs, rhs) { 1 << i } else { 0 }; - }); + let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); + values.push(op(lhs, rhs)) }; Ok(BooleanArray::from_data( @@ -126,105 +108,105 @@ where /// Perform `lhs == rhs` operation on two arrays. pub fn eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result where - T: NativeType, + T: NativeType + Simd8, { - compare_op(lhs, rhs, |a, b| a == b) + compare_op(lhs, rhs, |a, b| a.eq(b)) } /// Perform `left == right` operation on an array and a scalar value. pub fn eq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a == b) + compare_op_scalar(lhs, rhs, |a, b| a.eq(b)) } /// Perform `left != right` operation on two arrays. pub fn neq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result where - T: NativeType, + T: NativeType + Simd8, { - compare_op(lhs, rhs, |a, b| a != b) + compare_op(lhs, rhs, |a, b| a.neq(b)) } /// Perform `left != right` operation on an array and a scalar value. pub fn neq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a != b) + compare_op_scalar(lhs, rhs, |a, b| a.neq(b)) } /// Perform `left < right` operation on two arrays. pub fn lt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op(lhs, rhs, |a, b| a < b) + compare_op(lhs, rhs, |a, b| a.lt(b)) } /// Perform `left < right` operation on an array and a scalar value. pub fn lt_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a < b) + compare_op_scalar(lhs, rhs, |a, b| a.lt(b)) } /// Perform `left <= right` operation on two arrays. pub fn lt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op(lhs, rhs, |a, b| a <= b) + compare_op(lhs, rhs, |a, b| a.lt_eq(b)) } /// Perform `left <= right` operation on an array and a scalar value. /// Null values are less than non-null values. pub fn lt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a <= b) + compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b)) } /// Perform `left > right` operation on two arrays. Non-null values are greater than null /// values. pub fn gt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op(lhs, rhs, |a, b| a > b) + compare_op(lhs, rhs, |a, b| a.gt(b)) } /// Perform `left > right` operation on an array and a scalar value. /// Non-null values are greater than null values. pub fn gt_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a > b) + compare_op_scalar(lhs, rhs, |a, b| a.gt(b)) } /// Perform `left >= right` operation on two arrays. Non-null values are greater than null /// values. pub fn gt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op(lhs, rhs, |a, b| a >= b) + compare_op(lhs, rhs, |a, b| a.gt_eq(b)) } /// Perform `left >= right` operation on an array and a scalar value. /// Non-null values are greater than null values. pub fn gt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> Result where - T: NativeType + std::cmp::PartialOrd, + T: NativeType + Simd8, { - compare_op_scalar(lhs, rhs, |a, b| a >= b) + compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b)) } -pub fn compare( +pub fn compare( lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: Operator, @@ -239,7 +221,7 @@ pub fn compare( } } -pub fn compare_scalar( +pub fn compare_scalar( lhs: &PrimitiveArray, rhs: T, op: Operator, diff --git a/src/compute/comparison/simd/mod.rs b/src/compute/comparison/simd/mod.rs new file mode 100644 index 00000000000..3fe82b81dbb --- /dev/null +++ b/src/compute/comparison/simd/mod.rs @@ -0,0 +1,91 @@ +use crate::types::NativeType; + +/// [`NativeType`] that supports a representation of 8 lanes +pub trait Simd8: NativeType { + type Simd: Simd8Lanes; +} + +pub trait Simd8Lanes: Copy { + fn from_chunk(v: &[T]) -> Self; + fn from_incomplete_chunk(v: &[T], remaining: T) -> Self; + fn eq(self, other: Self) -> u8; + fn neq(self, other: Self) -> u8; + fn lt_eq(self, other: Self) -> u8; + fn lt(self, other: Self) -> u8; + fn gt(self, other: Self) -> u8; + fn gt_eq(self, other: Self) -> u8; +} + +#[inline] +pub(super) fn set bool>(lhs: [T; 8], rhs: [T; 8], op: F) -> u8 { + let mut byte = 0u8; + lhs.iter() + .zip(rhs.iter()) + .enumerate() + .for_each(|(i, (lhs, rhs))| { + byte |= if op(*lhs, *rhs) { 1 << i } else { 0 }; + }); + byte +} + +macro_rules! simd8_native { + ($type:ty) => { + impl Simd8 for $type { + type Simd = [$type; 8]; + } + + impl Simd8Lanes<$type> for [$type; 8] { + #[inline] + fn from_chunk(v: &[$type]) -> Self { + v.try_into().unwrap() + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; 8]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + a + } + + #[inline] + fn eq(self, other: Self) -> u8 { + set(self, other, |x, y| x == y) + } + + #[inline] + fn neq(self, other: Self) -> u8 { + #[allow(clippy::float_cmp)] + set(self, other, |x, y| x != y) + } + + #[inline] + fn lt_eq(self, other: Self) -> u8 { + set(self, other, |x, y| x <= y) + } + + #[inline] + fn lt(self, other: Self) -> u8 { + set(self, other, |x, y| x < y) + } + + #[inline] + fn gt_eq(self, other: Self) -> u8 { + set(self, other, |x, y| x >= y) + } + + #[inline] + fn gt(self, other: Self) -> u8 { + set(self, other, |x, y| x > y) + } + } + }; +} + +#[cfg(not(feature = "simd"))] +mod native; +#[cfg(not(feature = "simd"))] +pub use native::*; +#[cfg(feature = "simd")] +mod packed; +#[cfg(feature = "simd")] +pub use packed::*; diff --git a/src/compute/comparison/simd/native.rs b/src/compute/comparison/simd/native.rs new file mode 100644 index 00000000000..bc3e820e60e --- /dev/null +++ b/src/compute/comparison/simd/native.rs @@ -0,0 +1,15 @@ +use std::convert::TryInto; + +use super::{set, Simd8, Simd8Lanes}; + +simd8_native!(u8); +simd8_native!(u16); +simd8_native!(u32); +simd8_native!(u64); +simd8_native!(i8); +simd8_native!(i16); +simd8_native!(i32); +simd8_native!(i128); +simd8_native!(i64); +simd8_native!(f32); +simd8_native!(f64); diff --git a/src/compute/comparison/simd/packed.rs b/src/compute/comparison/simd/packed.rs new file mode 100644 index 00000000000..1e9b250c73b --- /dev/null +++ b/src/compute/comparison/simd/packed.rs @@ -0,0 +1,69 @@ +use std::convert::TryInto; + +use super::{set, Simd8, Simd8Lanes}; + +use packed_simd::*; + +macro_rules! simd8 { + ($type:ty, $md:ty) => { + impl Simd8 for $type { + type Simd = $md; + } + + impl Simd8Lanes<$type> for $md { + #[inline] + fn from_chunk(v: &[$type]) -> Self { + <$md>::from_slice_aligned(v) + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; 8]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + Self::from_chunk(a.as_ref()) + } + + #[inline] + fn eq(self, other: Self) -> u8 { + self.eq(other).bitmask() + } + + #[inline] + fn neq(self, other: Self) -> u8 { + self.ne(other).bitmask() + } + + #[inline] + fn lt_eq(self, other: Self) -> u8 { + self.le(other).bitmask() + } + + #[inline] + fn lt(self, other: Self) -> u8 { + self.lt(other).bitmask() + } + + #[inline] + fn gt_eq(self, other: Self) -> u8 { + self.ge(other).bitmask() + } + + #[inline] + fn gt(self, other: Self) -> u8 { + self.gt(other).bitmask() + } + } + }; +} + +simd8!(u8, u8x8); +simd8!(u16, u16x8); +simd8!(u32, u32x8); +simd8!(u64, u64x8); +simd8!(i8, i8x8); +simd8!(i16, i16x8); +simd8!(i32, i32x8); +simd8!(i64, i64x8); +simd8_native!(i128); +simd8!(f32, f32x8); +simd8!(f64, f64x8); diff --git a/src/compute/nullif.rs b/src/compute/nullif.rs index 8d2022e85ea..1b41b9e4782 100644 --- a/src/compute/nullif.rs +++ b/src/compute/nullif.rs @@ -1,5 +1,5 @@ use crate::array::PrimitiveArray; -use crate::compute::comparison::primitive_compare_values_op; +use crate::compute::comparison::{primitive_compare_values_op, Simd8, Simd8Lanes}; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; use crate::{array::Array, types::NativeType}; @@ -29,7 +29,7 @@ use super::utils::combine_validities; /// This function errors iff /// * The arguments do not have the same logical type /// * The arguments do not have the same length -pub fn nullif_primitive( +pub fn nullif_primitive( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> Result> { @@ -39,7 +39,7 @@ pub fn nullif_primitive( )); } - let equal = primitive_compare_values_op(lhs.values(), rhs.values(), |lhs, rhs| lhs != rhs); + let equal = primitive_compare_values_op(lhs.values(), rhs.values(), |lhs, rhs| lhs.neq(rhs)); let equal = equal.into(); let validity = combine_validities(lhs.validity(), &equal); diff --git a/src/trusted_len.rs b/src/trusted_len.rs index 877a180c9a7..8c52c0bf861 100644 --- a/src/trusted_len.rs +++ b/src/trusted_len.rs @@ -26,6 +26,8 @@ where { } +unsafe impl TrustedLen for std::slice::ChunksExact<'_, T> {} + unsafe impl TrustedLen for std::slice::Windows<'_, T> {} unsafe impl TrustedLen for std::iter::Chain