Add support for masked loads & stores #374

Closed · wants to merge 8 commits
Changes from 4 commits
9 changes: 9 additions & 0 deletions crates/core_simd/src/intrinsics.rs
@@ -107,6 +107,15 @@ extern "platform-intrinsic" {
/// like gather, but more spicy, as it writes instead of reads
pub(crate) fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);

/// like a loop of reads offset from the same pointer
/// val: vector of fallback values, used for lanes whose mask is off
/// ptr: pointer to the start of the memory to read from; lane `i` reads `ptr.add(i)`
/// mask: a "wide" mask of integers, selects as if simd_select(mask, read(ptr), val)
/// note, the LLVM intrinsic accepts a mask vector of `<N x i1>`
pub(crate) fn simd_masked_load<T, U, V>(val: T, ptr: U, mask: V) -> T;
/// like masked_load, but more spicy, as it writes instead of reads
pub(crate) fn simd_masked_store<T, U, V>(val: T, ptr: U, mask: V);

// {s,u}add.sat
pub(crate) fn simd_saturating_add<T>(x: T, y: T) -> T;

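For reference, a rough scalar model of the semantics the doc comments above describe (the `*_model` helper names are illustrative only, not part of the PR or of core):

// Scalar model of `simd_masked_load`: enabled lanes read `ptr.add(i)`,
// disabled lanes keep the fallback value from `or`.
unsafe fn masked_load_model<T: Copy, const N: usize>(
    or: [T; N],
    ptr: *const T,
    mask: [bool; N],
) -> [T; N] {
    let mut out = or;
    for i in 0..N {
        if mask[i] {
            out[i] = unsafe { *ptr.add(i) };
        }
    }
    out
}

// Scalar model of `simd_masked_store`: only enabled lanes write to memory.
unsafe fn masked_store_model<T: Copy, const N: usize>(val: [T; N], ptr: *mut T, mask: [bool; N]) {
    for i in 0..N {
        if mask[i] {
            unsafe { *ptr.add(i) = val[i] };
        }
    }
}
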
1 change: 1 addition & 0 deletions crates/core_simd/src/lib.rs
@@ -4,6 +4,7 @@
const_maybe_uninit_as_mut_ptr,
const_mut_refs,
convert_float_to_int,
core_intrinsics,
decl_macro,
inline_const,
intra_doc_pointers,
190 changes: 190 additions & 0 deletions crates/core_simd/src/vector.rs
@@ -1,6 +1,8 @@
use super::masks::{ToBitMask, ToBitMaskArray};
use crate::simd::{
cmp::SimdPartialOrd,
intrinsics,
prelude::SimdPartialEq,
ptr::{SimdConstPtr, SimdMutPtr},
LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
};
@@ -311,6 +313,104 @@ where
unsafe { self.store(slice.as_mut_ptr().cast()) }
}

#[must_use]
#[inline]
pub fn load_or_default(slice: &[T]) -> Self
where
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
T: Default,
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
Self::load_or(slice, Default::default())
}

Member:
I don't think this function does anything?

Member:
I suppose it's masking the slice length, but I'm not sure I would even call this a masked load. Also, I'm not sure this is the best way to load from slices on most architectures.

Contributor Author:
It's supposed to enable us to write extremely nice loops like this:

let mut accum = Simd::<u8, N>::default();
for i in (0..data.len()).step_by(N) {
    accum ^= Simd::masked_load_or(&data[i..], Simd::default());
}

No epilogues or scalar fallbacks needed. This could be even shorter since SimdElement implies Default?

But it's only a reality on AVX-512 for now on account of the codegen
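
For concreteness, here is one fully spelled-out form of that loop (a sketch only: the `xor_all` name, the lane count, and the final `reduce_xor` are mine; it assumes the `load_or` name from this revision of the PR):

use core_simd::simd::prelude::*;

fn xor_all(data: &[u8]) -> u8 {
    const N: usize = 64;
    let mut accum = Simd::<u8, N>::default();
    let mut i = 0;
    while i < data.len() {
        // Lanes past the end of `data` keep the `or` value (zero), the XOR
        // identity, so no scalar tail loop or epilogue is needed.
        accum ^= Simd::<u8, N>::load_or(&data[i..], Simd::default());
        i += N;
    }
    // Horizontal XOR of all lanes into a single byte.
    accum.reduce_xor()
}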

Member:
I would at least rename it--it might be implemented via masked loads, but API-wise I would consider it something else entirely. This is perhaps Simd::from_slice.

Contributor Author:
After the rename -- I don't like that this function seems more discoverable than from_slice. Users should prefer from_slice if it fits their use-case. This isn't a problem with gather/scatter because that's a specific name with a more complex signature, it's hard to use it by accident.

Member:
from_slice already has a branch on the length anyway; what I'm suggesting is that perhaps from_slice should instead do a masked load for len < N:

assert!(
    slice.len() >= Self::LEN,
    "slice length must be at least the number of elements"
);

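One possible reading of that suggestion, sketched for a concrete element type (the `from_slice_or_zero` name and the zero fallback are illustrative, not something proposed verbatim in this PR):

use core_simd::simd::prelude::*;

fn from_slice_or_zero(slice: &[u8]) -> u8x16 {
    if slice.len() >= u8x16::LEN {
        // Full vector available: plain contiguous load, as from_slice does today.
        u8x16::from_slice(slice)
    } else {
        // Short slice: masked load via the `load_or` added in this PR, with
        // zeros in the out-of-bounds lanes instead of a panic.
        u8x16::load_or(slice, u8x16::splat(0))
    }
}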

#[must_use]
#[inline]
pub fn load_or(slice: &[T], or: Self) -> Self

Contributor Author:
TODO: I need to test if this works with T=*const u8 and other pointers. These trait bounds are involved, at least for now, but they shouldn't require T: Default for this and other functions that accept an or: Self

where
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
Self::load_select(slice, Mask::splat(true), or)
}

#[must_use]
#[inline]
pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
where
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
T: Default,
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
Self::load_select(slice, enable, Default::default())
}

#[must_use]
#[inline]
pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
where
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
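// Fast path (currently disabled via USE_BRANCH): when every lane is enabled
// and the slice is longer than the vector, this is just a contiguous load.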
if USE_BRANCH {
if core::intrinsics::likely(enable.all() && slice.len() > N) {
return Self::from_slice(slice);
}
}
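// Otherwise, clear the lanes that would read past `slice.len()`: either build
// a bitmask with `bzhi` (USE_BITMASK) or compare a lane-index vector against
// the length (`mask_up_to` below).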
enable &= if USE_BITMASK {
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
let len = in_bounds_arr.as_ref().len();
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
Mask::from_bitmask_array(in_bounds_arr)
} else {
mask_up_to(enable, slice.len())
};
unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
}

#[must_use]
#[inline]
pub unsafe fn load_select_unchecked(
slice: &[T],
enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
) -> Self {
let ptr = slice.as_ptr();
unsafe { Self::load_select_ptr(ptr, enable, or) }
}

#[must_use]
#[inline]
pub unsafe fn load_select_ptr(
ptr: *const T,
enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
) -> Self {
calebzulawski marked this conversation as resolved.
unsafe { intrinsics::simd_masked_load(or, ptr, enable.to_int()) }
}

/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
/// If an index is out-of-bounds, the element is instead selected from the `or` vector.
///
@@ -489,6 +589,51 @@ where
unsafe { intrinsics::simd_gather(or, source, enable.to_int()) }
}

#[inline]
pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>)
where
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
<T as SimdElement>::Mask: Default
+ core::convert::From<i8>
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
{
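// Same strategy as `load_select` above: contiguous store on the fast path,
// otherwise mask off the lanes that would write past the end of the slice.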
if USE_BRANCH {
if core::intrinsics::likely(enable.all() && slice.len() > N) {
return self.copy_to_slice(slice);
}
}
enable &= if USE_BITMASK {
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
let len = in_bounds_arr.as_ref().len();
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
Mask::from_bitmask_array(in_bounds_arr)
} else {
mask_up_to(enable, slice.len())
};
unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
}

#[inline]
pub unsafe fn masked_store_unchecked(
self,
slice: &mut [T],
enable: Mask<<T as SimdElement>::Mask, N>,
) {
let ptr = slice.as_mut_ptr();
unsafe { self.masked_store_ptr(ptr, enable) }
}

#[inline]
pub unsafe fn masked_store_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
unsafe { intrinsics::simd_masked_store(self, ptr, enable.to_int()) }
}

/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
/// If an index is out-of-bounds, the write is suppressed without panicking.
/// If two elements in the scattered vector would write to the same index
@@ -974,3 +1119,48 @@
{
type Mask = isize;
}

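// Toggles for the alternative strategies experimented with in `load_select`
// and `masked_store`; both are off for now.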
const USE_BRANCH: bool = false;
const USE_BITMASK: bool = false;

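// Builds the vector `[0, 1, 2, ..., N-1]` in the mask element type.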
#[inline]
fn index<T, const N: usize>() -> Simd<T, N>
where
T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
LaneCount<N>: SupportedLaneCount,
{
let mut index = [T::default(); N];
for i in 1..N {
index[i] = index[i - 1] + T::from(1);
}
Simd::from_array(index)
}

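// Clears the lanes of `enable` whose index is at or past `len`, so a masked
// load/store never touches memory beyond the end of the slice.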
#[inline]
fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
where
LaneCount<N>: SupportedLaneCount,
M: MaskElement + Default + core::convert::From<i8> + core::ops::Add<M, Output = M>,
Simd<M, N>: SimdPartialOrd,
// <Simd<M, N> as SimdPartialEq>::Mask: Mask<M, N>,
Mask<M, N>: core::ops::BitAnd<Output = Mask<M, N>>
+ core::convert::From<<Simd<M, N> as SimdPartialEq>::Mask>,
{
let index = index::<M, N>();
enable
& Mask::<M, N>::from(
index.simd_lt(Simd::splat(M::from(i8::try_from(len).unwrap_or(i8::MAX)))),
)
}

// This function matches the semantics of the `bzhi` instruction on x86 BMI2
// TODO: optimize it further if possible
// https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
#[inline(always)]
fn bzhi_u64(a: u64, ix: u32) -> u64 {

Contributor Author:
If we made ix a u64, this would remove the need for the `core::cmp::min(N, slice.len()) as u32` at the call-sites above.

Which is doubly important if we take it further:

extern "C" {
    #[link_name = "llvm.x86.bmi.bzhi.64"]
    fn x86_bmi2_bzhi_64(x: u64, y: u64) -> u64;
}

#[inline(always)]
fn bzhi_u64(a: u64, ix: u64) -> u64 {
    #[cfg(target_feature = "bmi2")]
    unsafe {
        return x86_bmi2_bzhi_64(a, ix);
    }

    if ix > 63 {
        a
    } else {
        a & (1u64 << ix) - 1
    }
}

For whatever reason the Intel intrinsic _bzhi_u64 takes ix: u32, in the C/C++ versions too. Even a `1u64 as u32` generates a redundant register mov to a 32-bit register, clearing the high bits. All of this is unnecessary because the instruction operates on 64-bit registers.

It's the difference between these three. The current pure-Rust version (with the `min` clamp at the call-site):

        cmp rsi, 64
        mov r9d, 64
        cmovb r9, rsi
        bzhi r9, r8, r9

Calling the core::arch::x86_64::_bzhi_u64 intrinsic with slice.len() as u32 (which is incorrect for slices longer than u32::MAX):

        mov r9d, esi
        bzhi r9, rdx, r9

And using the code above with #![feature(link_llvm_intrinsics)]:

        bzhi r9, rdx, rsi

The layers of mismatched semantics are working against us.

Contributor Author:
Nevermind, the version without min(N, slice.len()) is wrong because the instruction only looks at the lowest 8 bits. So it effectively does a modulo 256 on the second operand.

So x86_bmi2_bzhi_64(u64::MAX, 525) is equal to 0x1FFF, 13 low bits set, because 525 % 256 = 13

if ix > 63 {
a
} else {
a & (1u64 << ix) - 1
}
}
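
Tying the discussion above back to the fallback branch: a quick sanity check of the pure-Rust semantics (a sketch, not part of the PR; the hardware's modulo-256 index handling is only noted in a comment):

#[cfg(test)]
mod bzhi_u64_sketch_tests {
    use super::bzhi_u64;

    #[test]
    fn matches_the_definition_above() {
        assert_eq!(bzhi_u64(u64::MAX, 0), 0);
        assert_eq!(bzhi_u64(u64::MAX, 13), 0x1FFF); // 13 low bits set
        assert_eq!(bzhi_u64(u64::MAX, 64), u64::MAX); // ix > 63 keeps every bit
        // The hardware `bzhi` instruction instead reduces the index modulo 256
        // (it reads only the low 8 bits of the operand), which is why the call
        // sites above clamp with `min(N, slice.len())` before reaching here.
    }
}
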
35 changes: 35 additions & 0 deletions crates/core_simd/tests/masked_load_store.rs
@@ -0,0 +1,35 @@
#![feature(portable_simd)]
use core_simd::simd::prelude::*;

#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;

#[cfg(target_arch = "wasm32")]
wasm_bindgen_test_configure!(run_in_browser);

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn masked_load_store() {
let mut arr = [u8::MAX; 7];

u8x4::splat(0).masked_store(&mut arr[5..], Mask::from_array([false, true, false, true]));
// write to index 8 is OOB and dropped

Contributor Author:
I'll make the stack array larger to prove that the store doesn't spill beyond a smaller slice

assert_eq!(arr, [255u8, 255, 255, 255, 255, 255, 0]);

u8x4::from_array([0, 1, 2, 3]).masked_store(&mut arr[1..], Mask::splat(true));
assert_eq!(arr, [255u8, 0, 1, 2, 3, 255, 0]);

// read from index 7 is OOB and dropped
assert_eq!(
u8x4::load_or(&arr[4..], u8x4::splat(42)),
u8x4::from_array([3, 255, 0, 42])
);
assert_eq!(
u8x4::load_select(
&arr[4..],
Mask::from_array([true, false, true, true]),
u8x4::splat(42)
),
u8x4::from_array([3, 42, 0, 42])
);
}