diff --git a/crates/core_simd/src/lib.rs b/crates/core_simd/src/lib.rs
index 64ba9705ef5..e974e7aa25a 100644
--- a/crates/core_simd/src/lib.rs
+++ b/crates/core_simd/src/lib.rs
@@ -4,6 +4,7 @@
     const_maybe_uninit_as_mut_ptr,
     const_mut_refs,
     convert_float_to_int,
+    core_intrinsics,
     decl_macro,
     inline_const,
     intra_doc_pointers,
diff --git a/crates/core_simd/src/vector.rs b/crates/core_simd/src/vector.rs
index bcd4ddcf69d..e48b8931db6 100644
--- a/crates/core_simd/src/vector.rs
+++ b/crates/core_simd/src/vector.rs
@@ -2,6 +2,7 @@ use super::masks::{ToBitMask, ToBitMaskArray};
 use crate::simd::{
     cmp::SimdPartialOrd,
     intrinsics,
+    prelude::SimdPartialEq,
     ptr::{SimdConstPtr, SimdMutPtr},
     LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
 };
@@ -314,48 +315,95 @@ where
 
     #[must_use]
     #[inline]
-    pub fn masked_load_or(slice: &[T], or: Self) -> Self
+    pub fn load_or_default(slice: &[T]) -> Self
     where
         Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        T: Default,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
     {
-        Self::masked_load_select(slice, Mask::splat(true), or)
+        Self::load_or(slice, Default::default())
     }
 
     #[must_use]
     #[inline]
-    pub fn masked_load_select(
-        slice: &[T],
-        mut enable: Mask<<T as SimdElement>::Mask, N>,
-        or: Self,
-    ) -> Self
+    pub fn load_or(slice: &[T], or: Self) -> Self
     where
         Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
     {
-        enable &= {
+        Self::load_select(slice, Mask::splat(true), or)
+    }
+
+    #[must_use]
+    #[inline]
+    pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        T: Default,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        Self::load_select(slice, enable, Default::default())
+    }
+
+    #[must_use]
+    #[inline]
+    pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
    {
+        if USE_BRANCH {
+            if core::intrinsics::likely(enable.all() && slice.len() > N) {
+                return Self::from_slice(slice);
+            }
+        }
+        enable &= if USE_BITMASK {
             let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
             let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
             let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
             let len = in_bounds_arr.as_ref().len();
             in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
             Mask::from_bitmask_array(in_bounds_arr)
+        } else {
+            mask_up_to(enable, slice.len())
         };
-        unsafe { Self::masked_load_select_ptr(slice.as_ptr(), enable, or) }
+        unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
     }
 
     #[must_use]
     #[inline]
-    pub unsafe fn masked_load_select_unchecked(
+    pub unsafe fn load_select_unchecked(
         slice: &[T],
         enable: Mask<<T as SimdElement>::Mask, N>,
         or: Self,
     ) -> Self {
         let ptr = slice.as_ptr();
-        unsafe { Self::masked_load_select_ptr(ptr, enable, or) }
+        unsafe { Self::load_select_ptr(ptr, enable, or) }
     }
 
     #[must_use]
     #[inline]
-    pub unsafe fn masked_load_select_ptr(
+    pub unsafe fn load_select_ptr(
         ptr: *const T,
         enable: Mask<<T as SimdElement>::Mask, N>,
         or: Self,
@@ -545,14 +593,28 @@ where
     pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>)
     where
         Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
     {
-        enable &= {
+        if USE_BRANCH {
+            if core::intrinsics::likely(enable.all() && slice.len() > N) {
+                return self.copy_to_slice(slice);
+            }
+        }
+        enable &= if USE_BITMASK {
             let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
             let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
             let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
             let len = in_bounds_arr.as_ref().len();
             in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
             Mask::from_bitmask_array(in_bounds_arr)
+        } else {
+            mask_up_to(enable, slice.len())
         };
         unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
     }
@@ -1058,9 +1120,43 @@ where
     type Mask = isize;
 }
 
+const USE_BRANCH: bool = false;
+const USE_BITMASK: bool = false;
+
+#[inline]
+fn index<T, const N: usize>() -> Simd<T, N>
+where
+    T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<Output = T>,
+    LaneCount<N>: SupportedLaneCount,
+{
+    let mut index = [T::default(); N];
+    for i in 1..N {
+        index[i] = index[i - 1] + T::from(1);
+    }
+    Simd::from_array(index)
+}
+
+#[inline]
+fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
+where
+    LaneCount<N>: SupportedLaneCount,
+    M: MaskElement + Default + core::convert::From<i8> + core::ops::Add<Output = M>,
+    Simd<M, N>: SimdPartialOrd,
+    // <Simd<M, N> as SimdPartialEq>::Mask: Mask<M, N>,
+    Mask<M, N>: core::ops::BitAnd<Output = Mask<M, N>>
+        + core::convert::From<<Simd<M, N> as SimdPartialEq>::Mask>,
+{
+    let index = index::<M, N>();
+    enable
+        & Mask::<M, N>::from(
+            index.simd_lt(Simd::splat(M::from(i8::try_from(len).unwrap_or(i8::MAX)))),
+        )
+}
+
 // This function matches the semantics of the `bzhi` instruction on x86 BMI2
 // TODO: optimize it further if possible
 // https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
+#[inline(always)]
 fn bzhi_u64(a: u64, ix: u32) -> u64 {
     if ix > 63 {
         a
diff --git a/crates/core_simd/tests/masked_load_store.rs b/crates/core_simd/tests/masked_load_store.rs
index 374b5c3b728..e830330249c 100644
--- a/crates/core_simd/tests/masked_load_store.rs
+++ b/crates/core_simd/tests/masked_load_store.rs
@@ -21,11 +21,11 @@ fn masked_load_store() {
 
     // read from index 8 is OOB and dropped
     assert_eq!(
-        u8x4::masked_load_or(&arr[4..], u8x4::splat(42)),
+        u8x4::load_or(&arr[4..], u8x4::splat(42)),
         u8x4::from_array([3, 255, 0, 42])
     );
     assert_eq!(
-        u8x4::masked_load_select(
+        u8x4::load_select(
             &arr[4..],
             Mask::from_array([true, false, true, true]),
             u8x4::splat(42)
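For reference, a minimal usage sketch of the renamed API (not part of the patch). It assumes a nightly toolchain building against this branch of core_simd; the `data` array, the import paths, and the expected result of the second assertion are illustrative assumptions, not values taken from the test file above.

#![feature(portable_simd)]
use core_simd::simd::{Mask, u8x4};

fn main() {
    // Hypothetical 3-element slice: lane 3 has no backing element,
    // so it falls back to the corresponding lane of the `or` vector.
    let data = [3u8, 255, 0];
    assert_eq!(
        u8x4::load_or(&data, u8x4::splat(42)),
        u8x4::from_array([3, 255, 0, 42])
    );

    // A lane disabled in `enable` also takes its value from `or`,
    // even when the slice element exists (lane 1 here).
    assert_eq!(
        u8x4::load_select(
            &data,
            Mask::from_array([true, false, true, true]),
            u8x4::splat(42),
        ),
        u8x4::from_array([3, 42, 0, 42])
    );
}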