Add support for masked loads & stores #374
@@ -1,6 +1,8 @@
use super::masks::{ToBitMask, ToBitMaskArray};
use crate::simd::{
    cmp::SimdPartialOrd,
    intrinsics,
    prelude::SimdPartialEq,
    ptr::{SimdConstPtr, SimdMutPtr},
    LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
};

@@ -311,6 +313,104 @@ where
        unsafe { self.store(slice.as_mut_ptr().cast()) }
    }
    /// Loads a vector from a slice, using `T::default()` for any lane that
    /// falls past the end of the slice.
    #[must_use]
    #[inline]
    pub fn load_or_default(slice: &[T]) -> Self
    where
        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
        T: Default,
        <T as SimdElement>::Mask: Default
            + core::convert::From<i8>
            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
    {
        Self::load_or(slice, Default::default())
    }

    /// Loads a vector from a slice, using the corresponding lane of `or` for
    /// any lane that falls past the end of the slice.
    #[must_use]
    #[inline]
    pub fn load_or(slice: &[T], or: Self) -> Self
    where
        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
        <T as SimdElement>::Mask: Default
            + core::convert::From<i8>
            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
    {
        Self::load_select(slice, Mask::splat(true), or)
    }

Review comment (on load_or): TODO: I need to test if this works with […]

    /// Loads a vector from a slice, using `T::default()` for any lane that is
    /// masked off by `enable` or falls past the end of the slice.
    #[must_use]
    #[inline]
    pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
    where
        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
        T: Default,
        <T as SimdElement>::Mask: Default
            + core::convert::From<i8>
            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
    {
        Self::load_select(slice, enable, Default::default())
    }

    /// Loads a vector from a slice, using the corresponding lane of `or` for
    /// any lane that is masked off by `enable` or falls past the end of the slice.
    #[must_use]
    #[inline]
    pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
    where
        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
        <T as SimdElement>::Mask: Default
            + core::convert::From<i8>
            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
    {
        if USE_BRANCH {
            if core::intrinsics::likely(enable.all() && slice.len() > N) {
                return Self::from_slice(slice);
            }
        }
        // Clear any enabled lane that would read past the end of the slice.
        enable &= if USE_BITMASK {
            let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
            let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
            let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
            let len = in_bounds_arr.as_ref().len();
            in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
            Mask::from_bitmask_array(in_bounds_arr)
        } else {
            mask_up_to(enable, slice.len())
        };
        unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
    }
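For orientation, here is a minimal usage sketch of the slice-loading API added above, written against the `core_simd` prelude used by this PR's test file; the `demo_load` wrapper is illustrative, not part of the PR. Lanes that are masked off or fall past the end of the slice take their value from `or`.

#![feature(portable_simd)]
use core_simd::simd::prelude::*;

fn demo_load() {
    let slice = [1i32, 2, 3];
    // Lane 2 is disabled by the mask; lane 3 falls past the end of `slice`.
    let v = i32x4::load_select(
        &slice,
        Mask::from_array([true, true, false, true]),
        i32x4::splat(9),
    );
    assert_eq!(v.to_array(), [1, 2, 9, 9]);
    // With an all-true mask, only the slice length limits the load.
    assert_eq!(i32x4::load_or(&slice, i32x4::splat(9)).to_array(), [1, 2, 3, 9]);
}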
|
||
#[must_use] | ||
#[inline] | ||
pub unsafe fn load_select_unchecked( | ||
slice: &[T], | ||
enable: Mask<<T as SimdElement>::Mask, N>, | ||
or: Self, | ||
) -> Self { | ||
let ptr = slice.as_ptr(); | ||
unsafe { Self::load_select_ptr(ptr, enable, or) } | ||
} | ||
|
||
#[must_use] | ||
#[inline] | ||
pub unsafe fn load_select_ptr( | ||
ptr: *const T, | ||
enable: Mask<<T as SimdElement>::Mask, N>, | ||
or: Self, | ||
) -> Self { | ||
calebzulawski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
unsafe { intrinsics::simd_masked_load(or, ptr, enable.to_int()) } | ||
} | ||
    /// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
    /// If an index is out-of-bounds, the element is instead selected from the `or` vector.

@@ -489,6 +589,51 @@ where
        unsafe { intrinsics::simd_gather(or, source, enable.to_int()) }
    }
|
||
#[inline] | ||
pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>) | ||
where | ||
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray, | ||
Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray, | ||
<T as SimdElement>::Mask: Default | ||
+ core::convert::From<i8> | ||
+ core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>, | ||
Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd, | ||
Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>> | ||
+ core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>, | ||
{ | ||
if USE_BRANCH { | ||
if core::intrinsics::likely(enable.all() && slice.len() > N) { | ||
return self.copy_to_slice(slice); | ||
} | ||
} | ||
enable &= if USE_BITMASK { | ||
let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32); | ||
let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) }; | ||
let mut in_bounds_arr = Mask::splat(true).to_bitmask_array(); | ||
let len = in_bounds_arr.as_ref().len(); | ||
in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]); | ||
Mask::from_bitmask_array(in_bounds_arr) | ||
} else { | ||
mask_up_to(enable, slice.len()) | ||
}; | ||
unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) } | ||
} | ||
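The store side mirrors this. A matching sketch under the same assumptions as the load example above, where the out-of-bounds lane's write is silently suppressed:

fn demo_store() {
    let mut buf = [7i32; 3];
    // Lane 1 is disabled by the mask; lane 3 is out of bounds, so its write
    // is suppressed rather than spilling past the slice.
    i32x4::from_array([10, 20, 30, 40])
        .masked_store(&mut buf, Mask::from_array([true, false, true, true]));
    assert_eq!(buf, [10, 7, 30]);
}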
    /// Writes the lanes of `self` that are enabled in `enable` to the
    /// corresponding elements of `slice`, without checking bounds.
    ///
    /// # Safety
    /// Every lane enabled in `enable` must be in bounds of `slice`.
    #[inline]
    pub unsafe fn masked_store_unchecked(
        self,
        slice: &mut [T],
        enable: Mask<<T as SimdElement>::Mask, N>,
    ) {
        let ptr = slice.as_mut_ptr();
        unsafe { self.masked_store_ptr(ptr, enable) }
    }

    /// Writes the lanes of `self` that are enabled in `enable` through `ptr`.
    ///
    /// # Safety
    /// Every enabled lane must point to memory that is valid for writes.
    #[inline]
    pub unsafe fn masked_store_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
        unsafe { intrinsics::simd_masked_store(self, ptr, enable.to_int()) }
    }

    /// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
    /// If an index is out-of-bounds, the write is suppressed without panicking.
    /// If two elements in the scattered vector would write to the same index

@@ -974,3 +1119,48 @@ where
{
    type Mask = isize;
}
const USE_BRANCH: bool = false;
const USE_BITMASK: bool = false;

/// Returns the vector `[0, 1, 2, ..., N - 1]`.
#[inline]
fn index<T, const N: usize>() -> Simd<T, N>
where
    T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
    LaneCount<N>: SupportedLaneCount,
{
    let mut index = [T::default(); N];
    for i in 1..N {
        index[i] = index[i - 1] + T::from(1);
    }
    Simd::from_array(index)
}
/// Restricts `enable` to the first `len` lanes, i.e. clears any lane whose
/// index is `>= len`.
#[inline]
fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
where
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement + Default + core::convert::From<i8> + core::ops::Add<M, Output = M>,
    Simd<M, N>: SimdPartialOrd,
    Mask<M, N>: core::ops::BitAnd<Output = Mask<M, N>>
        + core::convert::From<<Simd<M, N> as SimdPartialEq>::Mask>,
{
    let index = index::<M, N>();
    // `len` saturates to `i8::MAX`, which is fine for any supported lane count.
    enable
        & Mask::<M, N>::from(
            index.simd_lt(Simd::splat(M::from(i8::try_from(len).unwrap_or(i8::MAX)))),
        )
}
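To make the helper concrete, here is a scalar trace of what `mask_up_to` computes for `N = 4` and `len = 2`, using plain arrays in place of the generic SIMD machinery:

fn demo_mask_up_to() {
    // index::<i8, 4>() produces [0, 1, 2, 3]; simd_lt(splat(2)) keeps lanes 0..2.
    let index = [0i8, 1, 2, 3];
    let in_bounds: [bool; 4] = index.map(|i| i < 2);
    assert_eq!(in_bounds, [true, true, false, false]);
    // `enable & in_bounds` then drops any enabled lane past the end of the slice.
}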
// This function matches the semantics of the `bzhi` instruction on x86 BMI2.
// TODO: optimize it further if possible
// https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
#[inline(always)]
fn bzhi_u64(a: u64, ix: u32) -> u64 {
    if ix > 63 {
        a
    } else {
        a & ((1u64 << ix) - 1)
    }
}

Review comment: If we made […]. Which is doubly important if we take it further:

extern "C" {
    #[link_name = "llvm.x86.bmi.bzhi.64"]
    fn x86_bmi2_bzhi_64(x: u64, y: u64) -> u64;
}

#[inline(always)]
fn bzhi_u64(a: u64, ix: u64) -> u64 {
    #[cfg(target_feature = "bmi2")]
    unsafe {
        return x86_bmi2_bzhi_64(a, ix);
    }
    if ix > 63 {
        a
    } else {
        a & ((1u64 << ix) - 1)
    }
}

For whatever reason the Intel intrinsic of […]. It's the difference between

    cmp rsi, 64
    mov r9d, 64
    cmovb r9, rsi
    bzhi r9, r8, r9

In the current pure-Rust version:

    mov r9d, esi
    bzhi r9, rdx, r9

When calling the […]:

    bzhi r9, rdx, rsi

When you use the above code with […]. The layers of mismatched semantics are working against us.

Follow-up: Nevermind, the version without […]. So […]
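For reference, a quick sanity check of the semantics implemented above (the `demo_bzhi` wrapper is illustrative):

fn demo_bzhi() {
    // With a = u64::MAX, the result is a mask of `ix` low set bits; the
    // USE_BITMASK path feeds this to Mask::from_bitmask_array as the
    // in-bounds lane mask.
    assert_eq!(bzhi_u64(u64::MAX, 3), 0b111);
    assert_eq!(bzhi_u64(u64::MAX, 0), 0);
    assert_eq!(bzhi_u64(u64::MAX, 64), u64::MAX); // ix > 63 returns `a` unchanged
}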
@@ -0,0 +1,35 @@
#![feature(portable_simd)]
use core_simd::simd::prelude::*;

#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;

#[cfg(target_arch = "wasm32")]
wasm_bindgen_test_configure!(run_in_browser);

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn masked_load_store() {
    let mut arr = [u8::MAX; 7];

    u8x4::splat(0).masked_store(&mut arr[5..], Mask::from_array([false, true, false, true]));
    // The write to index 8 is OOB and dropped.
    assert_eq!(arr, [255u8, 255, 255, 255, 255, 255, 0]);

    u8x4::from_array([0, 1, 2, 3]).masked_store(&mut arr[1..], Mask::splat(true));
    assert_eq!(arr, [255u8, 0, 1, 2, 3, 255, 0]);

    // The read from index 7 is OOB and dropped.
    assert_eq!(
        u8x4::load_or(&arr[4..], u8x4::splat(42)),
        u8x4::from_array([3, 255, 0, 42])
    );
    assert_eq!(
        u8x4::load_select(
            &arr[4..],
            Mask::from_array([true, false, true, true]),
            u8x4::splat(42)
        ),
        u8x4::from_array([3, 42, 0, 42])
    );
}

Review comment (on the first masked_store): I'll make the stack array larger to prove that the store doesn't spill beyond a smaller slice.
Review comment: I don't think this function does anything?

Follow-up: I suppose it's masking the slice length, but I'm not sure I would even call this a masked load. Also, I'm not sure this is the best way to load from slices on most architectures.

Reply: It's supposed to enable us to write extremely nice loops like this: […] No epilogues or scalar fallbacks needed. This could be even shorter, since SimdElement implies Default? But it's only a reality on AVX-512 for now, on account of the codegen.
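The code block in that reply did not survive extraction. As a stand-in, here is a hedged sketch of the kind of epilogue-free loop the reply appears to describe, using the API from this diff; the function name and 16-lane chunk width are illustrative choices, not the original snippet:

#![feature(portable_simd)]
use core_simd::simd::prelude::*;

// Increment every byte of a slice in 16-lane chunks. The final partial chunk
// is handled by the masked load and store themselves: out-of-bounds lanes
// load as zero and their writes are suppressed, so no scalar tail is needed.
fn increment_all(data: &mut [u8]) {
    let mut i = 0;
    while i < data.len() {
        let v = u8x16::load_or_default(&data[i..]);
        (v + u8x16::splat(1)).masked_store(&mut data[i..], Mask::splat(true));
        i += 16;
    }
}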
Review comment: I would at least rename it -- it might be implemented via masked loads, but API-wise I would consider it something else entirely. This is perhaps Simd::from_slice.

Follow-up: After the rename -- I don't like that this function seems more discoverable than from_slice. Users should prefer from_slice if it fits their use case. This isn't a problem with gather/scatter, because that's a specific name with a more complex signature; it's hard to use it by accident.

Follow-up: from_slice already has a branch on the length anyway; what I'm suggesting is that perhaps from_slice should instead do a masked load for len < N (see portable-simd/crates/core_simd/src/vector.rs, lines 280 to 283 at 64ea088).
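Written over a concrete type, that suggestion might look like the sketch below. This is hypothetical, not code from the PR, and zero-filling short slices instead of panicking would be a deliberate behavior change:

fn from_slice_or_zero(slice: &[u8]) -> u8x16 {
    if slice.len() >= 16 {
        // The existing fast path: a full, unmasked load.
        u8x16::from_slice(slice)
    } else {
        // The suggested fallback: a masked load instead of a panic.
        u8x16::load_or_default(slice)
    }
}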