Skip to content

Commit

Permalink
Workaround simd_bitmask limitations
Browse files Browse the repository at this point in the history
  • Loading branch information
calebzulawski committed Nov 17, 2023
1 parent 4ca9f04 commit 082e3c8
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 25 deletions.
90 changes: 79 additions & 11 deletions crates/core_simd/src/masks/full_masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,40 +207,108 @@ where
}

#[inline]
pub(crate) fn to_bitmask_integer(self) -> u64 {
let resized = self.to_int().extend::<64>(T::FALSE);
unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
where
LaneCount<M>: SupportedLaneCount,
{
let resized = self.to_int().resize::<M>(T::FALSE);

// SAFETY: `resized` is an integer vector with length 64
let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) };
// Safety: `resized` is an integer vector with length M, which must match T
let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
bitmask.reverse_bits()
bitmask.reverse_bits(M)
} else {
bitmask
}
}

#[inline]
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
where
LaneCount<M>: SupportedLaneCount,
{
// LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits()
bitmask.reverse_bits(M)
} else {
bitmask
};

// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
let mask: Simd<T, 64> = unsafe {
let mask: Simd<T, M> = unsafe {
intrinsics::simd_select_bitmask(
bitmask,
Simd::<T, 64>::splat(T::TRUE),
Simd::<T, 64>::splat(T::FALSE),
Simd::<T, M>::splat(T::TRUE),
Simd::<T, M>::splat(T::FALSE),
)
};

// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
unsafe { Self::from_int_unchecked(mask.extend::<N>(T::FALSE)) }
unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
}

#[inline]
pub(crate) fn to_bitmask_integer(self) -> u64 {
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
)*)*
// Safety: bitmask matches length
_ => unsafe { self.to_bitmask_impl::<u64, 64>() },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}

#[inline]
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
)*)*
// Safety: bitmask matches length
_ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}

#[inline]
Expand Down
16 changes: 8 additions & 8 deletions crates/core_simd/src/swizzle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ where
)
}

/// Extend a vector.
/// Resize a vector.
///
/// Extends the length of a vector, setting the new elements to `value`.
/// If `M` > `N`, extends the length of a vector, setting the new elements to `value`.
/// If `M` < `N`, truncates the vector to the first `M` elements.
///
/// ```
Expand All @@ -361,17 +361,17 @@ where
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::u32x4;
/// let x = u32x4::from_array([0, 1, 2, 3]);
/// assert_eq!(x.extend::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
/// assert_eq!(x.extend::<2>(9).to_array(), [0, 1]);
/// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
/// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]);
/// ```
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn extend<const M: usize>(self, value: T) -> Simd<T, M>
pub fn resize<const M: usize>(self, value: T) -> Simd<T, M>
where
LaneCount<M>: SupportedLaneCount,
{
struct Extend<const N: usize>;
impl<const N: usize, const M: usize> Swizzle<M> for Extend<N> {
struct Resize<const N: usize>;
impl<const N: usize, const M: usize> Swizzle<M> for Resize<N> {
const INDEX: [usize; M] = const {
let mut index = [0; M];
let mut i = 0;
Expand All @@ -382,6 +382,6 @@ where
index
};
}
Extend::<N>::concat_swizzle(self, Simd::splat(value))
Resize::<N>::concat_swizzle(self, Simd::splat(value))
}
}
9 changes: 3 additions & 6 deletions crates/core_simd/tests/masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ macro_rules! test_mask_api {
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;

use core_simd::simd::{Mask, Simd};
use core_simd::simd::Mask;

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
Expand Down Expand Up @@ -124,17 +124,14 @@ macro_rules! test_mask_api {

#[test]
fn roundtrip_bitmask_vector_conversion() {
use core_simd::simd::ToBytes;
let values = [
true, false, false, true, false, false, true, false,
true, true, false, false, false, false, false, true,
];
let mask = Mask::<$type, 16>::from_array(values);
let bitmask = mask.to_bitmask_vector();
if core::mem::size_of::<$type>() == 1 {
assert_eq!(bitmask, Simd::from_array([0b01001001 as _, 0b10000011 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]));
} else {
assert_eq!(bitmask, Simd::from_array([0b1000001101001001 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]));
}
assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]);
assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask);
}
}
Expand Down

0 comments on commit 082e3c8

Please sign in to comment.