diff --git a/src/compute/aggregate/sum.rs b/src/compute/aggregate/sum.rs index bdc40f2f94c..d05493b8d30 100644 --- a/src/compute/aggregate/sum.rs +++ b/src/compute/aggregate/sum.rs @@ -26,10 +26,7 @@ where T: NativeType + Simd + Add + std::iter::Sum, T::Simd: Sum + Add, { - // Safety: - // T::Simd is the vector type T and the alignment is similar to aligning to [T; alignment] - // the alignment of T::Simd ensures that it fits T. - let (head, simd_vals, tail) = unsafe { values.align_to::() }; + let (head, simd_vals, tail) = T::Simd::align(values); let mut reduced = T::Simd::from_incomplete_chunk(&[], T::default()); for chunk in simd_vals { diff --git a/src/types/simd/mod.rs b/src/types/simd/mod.rs index 36fe065048c..eabbe00f640 100644 --- a/src/types/simd/mod.rs +++ b/src/types/simd/mod.rs @@ -10,7 +10,10 @@ pub trait FromMaskChunk { } /// A struct that lends itself well to be compiled leveraging SIMD -pub trait NativeSimd: Default + Copy { +/// # Safety +/// The `NativeType` and the `NativeSimd` must have a matching alignment, +/// e.g. splitting `&[NativeType]` at the alignment of `NativeSimd` must be properly aligned/safe. +pub unsafe trait NativeSimd: Default + Copy { /// Number of lanes const LANES: usize; /// The [`NativeType`] of this struct. E.g. `f32` for a `NativeSimd = f32x16`. @@ -32,6 +35,8 @@ pub trait NativeSimd: Default + Copy { /// Items from `v` at positions larger than the number of lanes are ignored; /// remaining items are populated with `remaining`. fn from_incomplete_chunk(v: &[Self::Native], remaining: Self::Native) -> Self; + + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]); } /// Trait implemented by some [`NativeType`] that have a SIMD representation. diff --git a/src/types/simd/native.rs b/src/types/simd/native.rs index 9f62845ecf0..1ab1f748c92 100644 --- a/src/types/simd/native.rs +++ b/src/types/simd/native.rs @@ -9,7 +9,7 @@ macro_rules!
simd { #[derive(Copy, Clone)] pub struct $name(pub [$type; $lanes]); - impl NativeSimd for $name { + unsafe impl NativeSimd for $name { const LANES: usize = $lanes; type Native = $type; type Chunk = $mask; @@ -36,6 +36,11 @@ macro_rules! simd { a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); Self(a) } + + #[inline] + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]) { + unsafe { values.align_to::() } + } } impl std::ops::Index for $name { diff --git a/src/types/simd/packed.rs b/src/types/simd/packed.rs index b8bb35d9806..160d47ecc35 100644 --- a/src/types/simd/packed.rs +++ b/src/types/simd/packed.rs @@ -7,7 +7,7 @@ use super::*; macro_rules! simd { ($name:tt, $type:ty, $lanes:expr, $chunk:ty, $mask:tt) => { - impl NativeSimd for $name { + unsafe impl NativeSimd for $name { const LANES: usize = $lanes; type Native = $type; type Chunk = $chunk; @@ -29,6 +29,11 @@ macro_rules! simd { a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); <$name>::from_chunk(a.as_ref()) } + + #[inline] + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]) { + unsafe { values.align_to::() } + } } }; }