Skip to content

Commit

Permalink
Add support for masked loads & stores
Browse files Browse the repository at this point in the history
  • Loading branch information
farnoy committed Feb 29, 2024
1 parent fbc9efa commit 3692383
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 0 deletions.
170 changes: 170 additions & 0 deletions crates/core_simd/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,131 @@ where
unsafe { self.store(slice.as_mut_ptr().cast()) }
}

/// Loads a vector from the start of `slice`, filling lanes that fall past the end of the
/// slice with the default value for the element type.
///
/// Lanes whose index is in-bounds for `slice` are read from it; every other lane receives
/// the default instead.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::Simd;
/// let vec: Vec<i32> = vec![10, 11];
///
/// let result = Simd::<i32, 4>::load_or_default(&vec);
/// assert_eq!(result, Simd::from_array([10, 11, 0, 0]));
/// ```
#[must_use]
#[inline]
pub fn load_or_default(slice: &[T]) -> Self
where
    T: Default,
{
    // The fallback vector supplies the value for every lane that cannot be read.
    let fallback = Self::default();
    Self::load_or(slice, fallback)
}

/// Loads a vector from the start of `slice`, taking lanes that fall past the end of the
/// slice from `or`.
///
/// Lanes whose index is in-bounds for `slice` are read from it; every other lane is passed
/// through from the corresponding lane of `or`.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::Simd;
/// let vec: Vec<i32> = vec![10, 11];
/// let or = Simd::from_array([-5, -4, -3, -2]);
///
/// let result = Simd::load_or(&vec, or);
/// assert_eq!(result, Simd::from_array([10, 11, -3, -2]));
/// ```
#[must_use]
#[inline]
pub fn load_or(slice: &[T], or: Self) -> Self {
    // Ask for every lane; `load_select` is responsible for disabling out-of-bounds ones.
    let every_lane = Mask::splat(true);
    Self::load_select(slice, every_lane, or)
}

/// Reads contiguous elements from `slice`. Each lane is read from memory if its
/// corresponding lane in `enable` is `true`.
///
/// When the lane is disabled or out of bounds for the slice, that memory location
/// is not accessed and the default value for the element type is returned in that lane.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::{Simd, Mask};
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
/// let enable = Mask::from_array([true, true, false, true]);
///
/// let result = Simd::<i32, 4>::load_select_or_default(&vec, enable);
/// assert_eq!(result, Simd::from_array([10, 11, 0, 13]));
/// ```
#[must_use]
#[inline]
pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
where
    T: Default,
{
    // The default vector plays the role of `or` for disabled and out-of-bounds lanes.
    Self::load_select(slice, enable, Default::default())
}

/// Reads contiguous elements from `slice`. Each lane is read from memory if its
/// corresponding lane in `enable` is `true`.
///
/// When the lane is disabled or out of bounds for the slice, that memory location
/// is not accessed and the corresponding value from `or` is passed through.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::{Simd, Mask};
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
/// let enable = Mask::from_array([true, true, false, true]);
/// let or = Simd::from_array([-5, -4, -3, -2]);
///
/// let result = Simd::load_select(&vec, enable, or);
/// assert_eq!(result, Simd::from_array([10, 11, -3, 13]));
/// ```
#[must_use]
#[inline]
pub fn load_select(
    slice: &[T],
    enable: Mask<<T as SimdElement>::Mask, N>,
    or: Self,
) -> Self {
    // `mask_up_to` already intersects `enable` with the lanes whose index is below
    // `slice.len()`, so no further `&=` with the original mask is needed.
    let enable = mask_up_to(enable, slice.len());
    // SAFETY: every lane that is still enabled has an index below `slice.len()`, so each
    // enabled read stays inside `slice`.
    unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
}

/// Reads contiguous elements from `slice` without checking the enabled lanes against the
/// slice length. Each lane is read from memory if its corresponding lane in `enable` is
/// `true`.
///
/// When the lane is disabled, that memory location is not accessed and the corresponding
/// value from `or` is passed through.
///
/// # Safety
/// Enabled loads must not exceed the length of `slice`: every lane that is `true` in
/// `enable` must have an index less than `slice.len()`. Unlike `load_select`, this
/// function does not disable out-of-bounds lanes.
#[must_use]
#[inline]
pub unsafe fn load_select_unchecked(
    slice: &[T],
    enable: Mask<<T as SimdElement>::Mask, N>,
    or: Self,
) -> Self {
    // SAFETY: the caller guarantees that every enabled lane is in-bounds for `slice`, so
    // each enabled read through the slice's base pointer stays inside the slice.
    unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
}

/// Reads contiguous elements starting at `ptr`. Each lane is read from memory if its
/// corresponding lane in `enable` is `true`.
///
/// When the lane is disabled, that memory location is not accessed and the corresponding
/// value from `or` is passed through.
///
/// # Safety
/// For every lane `i` that is `true` in `enable`, the element at `ptr.add(i)` must be safe
/// to read as if by `core::ptr::read`. Disabled lanes impose no requirement on the memory
/// they would have addressed.
#[must_use]
#[inline]
pub unsafe fn load_select_ptr(
    ptr: *const T,
    enable: Mask<<T as SimdElement>::Mask, N>,
    or: Self,
) -> Self {
    // SAFETY: the caller guarantees that every enabled element behind `ptr` is readable.
    unsafe { core::intrinsics::simd::simd_masked_load(enable.to_int(), ptr, or) }
}

/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
/// If an index is out-of-bounds, the element is instead selected from the `or` vector.
///
Expand Down Expand Up @@ -492,6 +617,27 @@ where
unsafe { core::intrinsics::simd::simd_gather(or, source, enable.to_int()) }
}

/// Conditionally writes contiguous elements to `slice`. Lane `i` of `self` is written to
/// `slice[i]` if the corresponding lane in `enable` is `true`.
///
/// When the lane is disabled or out of bounds for the slice, that memory location is not
/// written.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::{Simd, Mask};
/// let mut arr = [0i32; 4];
/// let write = Simd::from_array([-5, -4, -3, -2]);
/// let enable = Mask::from_array([false, true, true, true]);
///
/// write.masked_store(&mut arr[..3], enable);
/// assert_eq!(arr, [0, -4, -3, 0]);
/// ```
#[inline]
pub fn masked_store(self, slice: &mut [T], enable: Mask<<T as SimdElement>::Mask, N>) {
    // `mask_up_to` already intersects `enable` with the lanes whose index is below
    // `slice.len()`, so no further `&=` with the original mask is needed.
    let enable = mask_up_to(enable, slice.len());
    // SAFETY: every lane that is still enabled has an index below `slice.len()`, so each
    // enabled write stays inside `slice`.
    unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
}

/// Conditionally writes contiguous elements to `slice` without checking the enabled lanes
/// against the slice length. Lane `i` of `self` is written to `slice[i]` if the
/// corresponding lane in `enable` is `true`; disabled lanes are not written.
///
/// # Safety
/// Enabled stores must not exceed the length of `slice`: every lane that is `true` in
/// `enable` must have an index less than `slice.len()`. Unlike `masked_store`, this
/// function does not disable out-of-bounds lanes.
#[inline]
pub unsafe fn masked_store_unchecked(
    self,
    slice: &mut [T],
    enable: Mask<<T as SimdElement>::Mask, N>,
) {
    // SAFETY: the caller guarantees that every enabled lane is in-bounds for `slice`, so
    // each enabled write through the slice's base pointer stays inside the slice.
    unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
}

/// Conditionally writes contiguous elements starting at `ptr`. Lane `i` of `self` is
/// written to `ptr.add(i)` if the corresponding lane in `enable` is `true`; disabled lanes
/// are not written.
///
/// # Safety
/// For every lane `i` that is `true` in `enable`, the element at `ptr.add(i)` must be safe
/// to write as if by `core::ptr::write`. Disabled lanes impose no requirement on the memory
/// they would have addressed.
#[inline]
pub unsafe fn masked_store_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
    // SAFETY: the caller guarantees that every enabled element behind `ptr` is writable.
    unsafe { core::intrinsics::simd::simd_masked_store(enable.to_int(), ptr, self) }
}

/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
/// If an index is out-of-bounds, the write is suppressed without panicking.
/// If two elements in the scattered vector would write to the same index
Expand Down Expand Up @@ -979,3 +1125,27 @@ where
{
type Mask = isize;
}

/// Builds a vector whose first lane is `T::default()` and whose every following lane is
/// the previous lane plus `T::from(1)` — i.e. `[0, 1, 2, ..., N - 1]` for integer `T`.
#[inline]
fn lane_indices<T, const N: usize>() -> Simd<T, N>
where
    T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
    LaneCount<N>: SupportedLaneCount,
{
    let step = T::from(1);
    let mut current = T::default();
    let mut lanes = [current; N];
    // Lane 0 keeps the starting value; each later lane is one step past its predecessor.
    for lane in lanes.iter_mut().skip(1) {
        current = current + step;
        *lane = current;
    }
    Simd::from_array(lanes)
}

/// Returns `enable` with every lane whose index is not below `len` switched off.
///
/// Lane indices are compared as `i8`; a `len` that does not fit in `i8` is clamped to
/// `i8::MAX`.
// NOTE(review): the clamp is only correct if `i8::MAX` exceeds every supported lane index —
// confirm against `SupportedLaneCount`.
#[inline]
fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
where
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement,
{
    let bound = i8::try_from(len).unwrap_or(i8::MAX);
    let below_len = lane_indices::<i8, N>().simd_lt(Simd::splat(bound));
    // The comparison yields an `i8`-element mask; round-trip through the bitmask vector to
    // obtain the same lanes as a `Mask<M, N>`.
    let in_bounds = Mask::<M, N>::from_bitmask_vector(below_len.to_bitmask_vector());
    enable & in_bounds
}
35 changes: 35 additions & 0 deletions crates/core_simd/tests/masked_load_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#![feature(portable_simd)]
use core_simd::simd::prelude::*;

#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;

#[cfg(target_arch = "wasm32")]
wasm_bindgen_test_configure!(run_in_browser);

// Exercises the bounds-checked masked store/load API on sub-slices of a 7-element array,
// checking that enabled lanes past the end of the slice are neither written nor read.
#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn masked_load_store() {
let mut arr = [u8::MAX; 7];

u8x4::splat(0).masked_store(&mut arr[5..], Mask::from_array([false, true, false, true]));
// lanes map to arr[5..=8]: lane 1 writes arr[6]; lane 3 (index 8) is OOB and dropped
assert_eq!(arr, [255u8, 255, 255, 255, 255, 255, 0]);

// all four lanes (arr[1..=4]) are in-bounds, so every lane is written
u8x4::from_array([0, 1, 2, 3]).masked_store(&mut arr[1..], Mask::splat(true));
assert_eq!(arr, [255u8, 0, 1, 2, 3, 255, 0]);

// lanes map to arr[4..=7]: the read from index 7 is OOB, so that lane takes the `or` value
assert_eq!(
u8x4::load_or(&arr[4..], u8x4::splat(42)),
u8x4::from_array([3, 255, 0, 42])
);
// lane 1 is disabled and lane 3 is OOB; both take the `or` value
assert_eq!(
u8x4::load_select(
&arr[4..],
Mask::from_array([true, false, true, true]),
u8x4::splat(42)
),
u8x4::from_array([3, 42, 0, 42])
);
}

0 comments on commit 3692383

Please sign in to comment.