Skip to content

Commit

Permalink
Rework the masking logic, rename the functions
Browse files Browse the repository at this point in the history
  • Loading branch information
farnoy committed Nov 19, 2023
1 parent 66a5748 commit 34e54b4
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 15 deletions.
1 change: 1 addition & 0 deletions crates/core_simd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
const_maybe_uninit_as_mut_ptr,
const_mut_refs,
convert_float_to_int,
core_intrinsics,
decl_macro,
inline_const,
intra_doc_pointers,
Expand Down
122 changes: 109 additions & 13 deletions crates/core_simd/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::masks::{ToBitMask, ToBitMaskArray};
use crate::simd::{
cmp::SimdPartialOrd,
intrinsics,
prelude::SimdPartialEq,
ptr::{SimdConstPtr, SimdMutPtr},
LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
};
Expand Down Expand Up @@ -314,48 +315,95 @@ where

/// Loads a vector from `slice` via `Self::load_or`, passing `Default::default()` as the
/// fallback (`or`) vector.
#[must_use]
#[inline]
// NOTE(review): this hunk interleaves both sides of the diff. The `masked_load_or` signature
// on the next line appears to be the removed (pre-rename) side — confirm against the commit.
pub fn masked_load_or(slice: &[T], or: Self) -> Self
pub fn load_or_default(slice: &[T]) -> Self
where
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
// Needed so a default fallback vector can be produced for the `or` argument.
T: Default,
// The remaining bounds mirror the ones declared on `load_select` / `mask_up_to`.
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
// NOTE(review): the `masked_load_select` call below looks like the removed-side body line;
// the `load_or` call after it is the added replacement.
Self::masked_load_select(slice, Mask::splat(true), or)
Self::load_or(slice, Default::default())
}

/// Loads a vector from `slice` with every lane enabled, forwarding to `Self::load_select`
/// with an all-`true` mask and `or` as the fallback vector.
///
/// Per the test changed in this commit, a lane that would read past the end of the slice
/// comes back holding the corresponding lane of `or`.
#[must_use]
#[inline]
// NOTE(review): the five-line `masked_load_select(...)` signature below appears to be the
// removed side of the diff; `load_or` right after it is the added signature — confirm.
pub fn masked_load_select(
slice: &[T],
mut enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
) -> Self
pub fn load_or(slice: &[T], or: Self) -> Self
where
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
// These bounds mirror the ones declared on `load_select` / `mask_up_to`.
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
// NOTE(review): stray `enable &= {` — presumably a removed-side line left over from the old
// `masked_load_select` body; it is not part of `load_or`.
enable &= {
// Every lane is requested; `load_select` is what restricts the mask to the slice bounds.
Self::load_select(slice, Mask::splat(true), or)
}

/// Loads the lanes selected by `enable` from `slice`, supplying a default-constructed vector
/// as the fallback.
///
/// Equivalent to calling `Self::load_select` with `Default::default()` as its `or` argument.
#[must_use]
#[inline]
pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
where
    Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
    T: Default,
    <T as SimdElement>::Mask: Default
        + core::convert::From<i8>
        + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
    Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
    Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
        + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
    // Name the fallback explicitly so the delegation below reads as a plain forward.
    let fallback: Self = Default::default();
    Self::load_select(slice, enable, fallback)
}

/// Loads the lanes selected by `enable` from `slice`, first restricting `enable` to lanes
/// whose index is within the slice, then delegating to `Self::load_select_ptr` with `or` as
/// the fallback vector.
///
/// How the in-bounds restriction is computed is chosen by the file-level `USE_BRANCH` and
/// `USE_BITMASK` constants.
#[must_use]
#[inline]
pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
where
// Required by the `to_bitmask_array` / `from_bitmask_array` calls in the bitmask path.
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
// The remaining bounds are the ones `mask_up_to` declares for its mask element type.
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
// Optional fast path: with every lane enabled and a long enough slice, load the whole vector.
if USE_BRANCH {
// NOTE(review): `> N` skips the fast path when `slice.len() == N`, a case where a full
// N-lane load would presumably also be in bounds; `>=` may have been intended — confirm.
if core::intrinsics::likely(enable.all() && slice.len() > N) {
return Self::from_slice(slice);
}
}
enable &= if USE_BITMASK {
// Bitmask with the low `min(N, slice.len())` bits kept (`bzhi_u64` is documented in this
// file as matching the x86 BMI2 `bzhi` instruction).
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
// NOTE(review): `transmute` yields the bytes in native byte order, so which bits land in
// `mask_bytes[0]` depends on target endianness; `mask.to_le_bytes()` would be safe and
// endian-independent — confirm the byte order `from_bitmask_array` expects.
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
// The all-true mask is converted only to obtain a bitmask array of the right type and
// length; its contents are overwritten just below.
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
let len = in_bounds_arr.as_ref().len();
// NOTE(review): slicing `mask_bytes[..len]` panics if the bitmask array is longer than
// 8 bytes — presumably impossible for supported lane counts; confirm.
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
Mask::from_bitmask_array(in_bounds_arr)
} else {
// `mask_up_to` already ANDs with `enable`, so the outer `&=` is redundant but harmless here.
mask_up_to(enable, slice.len())
};
// NOTE(review): the line below calls the pre-rename `masked_load_select_ptr` and appears to be
// the removed side of the diff; the `load_select_ptr` call after it is the added line.
unsafe { Self::masked_load_select_ptr(slice.as_ptr(), enable, or) }
// SAFETY (as far as this block shows): `enable` has been ANDed with the in-bounds mask above.
// The full contract of `load_select_ptr` is not visible here — confirm.
unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
}

/// Loads the lanes selected by `enable` starting at `slice.as_ptr()`, with `or` as the
/// fallback vector, without restricting `enable` to the slice bounds.
///
/// # Safety
///
/// This function applies no bounds mask of its own; it forwards `enable` unchanged to
/// `Self::load_select_ptr`. NOTE(review): presumably the caller must guarantee that every
/// enabled lane indexes an element inside `slice` — confirm against `load_select_ptr`.
#[must_use]
#[inline]
// NOTE(review): `masked_load_select_unchecked` on the next line appears to be the removed
// (pre-rename) signature line from the diff; `load_select_unchecked` is the added one.
pub unsafe fn masked_load_select_unchecked(
pub unsafe fn load_select_unchecked(
slice: &[T],
enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
) -> Self {
let ptr = slice.as_ptr();
// NOTE(review): the `masked_load_select_ptr` call is the removed-side line; the
// `load_select_ptr` call after it is its replacement.
unsafe { Self::masked_load_select_ptr(ptr, enable, or) }
unsafe { Self::load_select_ptr(ptr, enable, or) }
}

#[must_use]
#[inline]
pub unsafe fn masked_load_select_ptr(
pub unsafe fn load_select_ptr(
ptr: *const T,
enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
Expand Down Expand Up @@ -545,14 +593,28 @@ where
/// Stores the lanes of `self` selected by `enable` into `slice`, first restricting `enable`
/// to lanes whose index is within the slice, then delegating to `masked_store_ptr`.
///
/// How the in-bounds restriction is computed is chosen by the file-level `USE_BRANCH` and
/// `USE_BITMASK` constants.
pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>)
where
// Required by the `to_bitmask_array` / `from_bitmask_array` calls in the bitmask path.
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
// NOTE(review): the next predicate duplicates the previous one; likely one of the two is the
// removed side of the diff — confirm.
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
// The remaining bounds are the ones `mask_up_to` declares for its mask element type.
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
// NOTE(review): stray `enable &= {` — presumably the removed-side opening of the old masking
// block; the added code starts at `if USE_BRANCH`.
enable &= {
// Optional fast path: with every lane enabled and a long enough slice, store the whole vector.
if USE_BRANCH {
// NOTE(review): `> N` skips the fast path when `slice.len() == N`, a case where a full
// N-lane store would presumably also be in bounds; `>=` may have been intended — confirm.
if core::intrinsics::likely(enable.all() && slice.len() > N) {
return self.copy_to_slice(slice);
}
}
enable &= if USE_BITMASK {
// Bitmask with the low `min(N, slice.len())` bits kept (`bzhi_u64` is documented in this
// file as matching the x86 BMI2 `bzhi` instruction).
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
// NOTE(review): `transmute` yields the bytes in native byte order, so which bits land in
// `mask_bytes[0]` depends on target endianness; `mask.to_le_bytes()` would be safe and
// endian-independent — confirm the byte order `from_bitmask_array` expects.
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
// The all-true mask is converted only to obtain a bitmask array of the right type and
// length; its contents are overwritten just below.
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
let len = in_bounds_arr.as_ref().len();
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
Mask::from_bitmask_array(in_bounds_arr)
} else {
// `mask_up_to` already ANDs with `enable`, so the outer `&=` is redundant but harmless here.
mask_up_to(enable, slice.len())
};
// SAFETY (as far as this block shows): `enable` has been ANDed with the in-bounds mask above.
// The full contract of `masked_store_ptr` is not visible here — confirm.
unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
}
Expand Down Expand Up @@ -1058,9 +1120,43 @@ where
type Mask = isize;
}

// Compile-time switches for how `load_select` and `masked_store` restrict the lane mask to
// the bounds of the slice.

// When `true`, those functions first try a whole-vector fast path (taken when every lane is
// enabled and the slice is longer than `N`).
const USE_BRANCH: bool = false;
// When `true`, the in-bounds mask is built from a `bzhi_u64` bitmask; when `false`, it is
// computed by `mask_up_to`.
const USE_BITMASK: bool = false;

/// Builds a vector whose lane `i` holds `T::default()` plus `i` increments of `T::from(1)`.
///
/// For integer mask element types whose default is zero this is the lane-index sequence
/// `0, 1, 2, …, N - 1`.
#[inline]
fn index<T, const N: usize>() -> Simd<T, N>
where
    T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
    LaneCount<N>: SupportedLaneCount,
{
    let step = T::from(1);
    let mut lanes = [T::default(); N];
    // Lane 0 keeps the default value; every later lane is its predecessor plus one step, so
    // exactly `N - 1` additions are performed.
    let mut running = T::default();
    for lane in lanes.iter_mut().skip(1) {
        running = running + step;
        *lane = running;
    }
    Simd::from_array(lanes)
}

/// Returns `enable` restricted to the lanes whose index is strictly less than `len`.
///
/// `len` is converted to the mask element type through `i8`, saturating at `i8::MAX`. Since
/// lane indices run from `0` to `N - 1`, the saturated limit still enables every lane as long
/// as `N` does not exceed 127 (presumably guaranteed by `SupportedLaneCount` — confirm).
#[inline]
fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
where
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement + Default + core::convert::From<i8> + core::ops::Add<M, Output = M>,
    Simd<M, N>: SimdPartialOrd,
    Mask<M, N>: core::ops::BitAnd<Output = Mask<M, N>>
        + core::convert::From<<Simd<M, N> as SimdPartialEq>::Mask>,
{
    // Lengths that do not fit in an `i8` are clamped rather than rejected.
    let limit = M::from(i8::try_from(len).unwrap_or(i8::MAX));
    // Lane-wise `index < limit` marks the lanes that fall inside the slice.
    let in_bounds = index::<M, N>().simd_lt(Simd::splat(limit));
    enable & Mask::<M, N>::from(in_bounds)
}

// This function matches the semantics of the `bzhi` instruction on x86 BMI2
// TODO: optimize it further if possible
// https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
#[inline(always)]
fn bzhi_u64(a: u64, ix: u32) -> u64 {
if ix > 63 {
a
Expand Down
4 changes: 2 additions & 2 deletions crates/core_simd/tests/masked_load_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ fn masked_load_store() {

// read from index 8 is OOB and dropped
assert_eq!(
u8x4::masked_load_or(&arr[4..], u8x4::splat(42)),
u8x4::load_or(&arr[4..], u8x4::splat(42)),
u8x4::from_array([3, 255, 0, 42])
);
assert_eq!(
u8x4::masked_load_select(
u8x4::load_select(
&arr[4..],
Mask::from_array([true, false, true, true]),
u8x4::splat(42)
Expand Down

0 comments on commit 34e54b4

Please sign in to comment.