diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index dcec336cfaf..11d7288eccb 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -15,6 +15,9 @@ mod mask_impl; mod to_bitmask; pub use to_bitmask::ToBitMask; +#[cfg(feature = "generic_const_exprs")] +pub use to_bitmask::{bitmask_len, ToBitMaskArray}; + use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SimdPartialEq, SupportedLaneCount}; use core::cmp::Ordering; use core::{fmt, mem}; diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index ec4dd357ee9..365ecc0a325 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -115,6 +115,26 @@ where unsafe { Self(intrinsics::simd_bitmask(value), PhantomData) } } + #[cfg(feature = "generic_const_exprs")] + #[inline] + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn to_bitmask_array(self) -> [u8; N] { + assert!(core::mem::size_of::() == N); + + // Safety: converting an integer to an array of bytes of the same size is safe + unsafe { core::mem::transmute_copy(&self.0) } + } + + #[cfg(feature = "generic_const_exprs")] + #[inline] + #[must_use = "method returns a new mask and does not mutate the original value"] + pub fn from_bitmask_array(bitmask: [u8; N]) -> Self { + assert!(core::mem::size_of::() == N); + + // Safety: converting an array of bytes to an integer of the same size is safe + Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData) + } + #[inline] pub fn to_bitmask_integer(self) -> U where diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index efa688b128f..adf0fcbeae2 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -4,6 +4,9 @@ use super::MaskElement; use crate::simd::intrinsics; use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; +#[cfg(feature = "generic_const_exprs")] +use crate::simd::ToBitMaskArray; + #[repr(transparent)] pub struct Mask(Simd) where @@ -139,6 +142,68 @@ where unsafe { Mask(intrinsics::simd_cast(self.0)) } } + #[cfg(feature = "generic_const_exprs")] + #[inline] + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn to_bitmask_array(self) -> [u8; N] + where + super::Mask: ToBitMaskArray, + [(); as ToBitMaskArray>::BYTES]: Sized, + { + assert_eq!( as ToBitMaskArray>::BYTES, N); + + // Safety: N is the correct bitmask size + unsafe { + // Compute the bitmask + let bitmask: [u8; as ToBitMaskArray>::BYTES] = + intrinsics::simd_bitmask(self.0); + + // Transmute to the return type, previously asserted to be the same size + let mut bitmask: [u8; N] = core::mem::transmute_copy(&bitmask); + + // LLVM assumes bit order should match endianness + if cfg!(target_endian = "big") { + for x in bitmask.as_mut() { + *x = x.reverse_bits(); + } + }; + + bitmask + } + } + + #[cfg(feature = "generic_const_exprs")] + #[inline] + #[must_use = "method returns a new mask and does not mutate the original value"] + pub fn from_bitmask_array(mut bitmask: [u8; N]) -> Self + where + super::Mask: ToBitMaskArray, + [(); as ToBitMaskArray>::BYTES]: Sized, + { + assert_eq!( as ToBitMaskArray>::BYTES, N); + + // Safety: N is the correct bitmask size + unsafe { + // LLVM assumes bit order should match endianness + if cfg!(target_endian = "big") { + for x in bitmask.as_mut() { + *x = x.reverse_bits(); + } + } + + // Transmute to the bitmask type, previously asserted to be the same size + let bitmask: [u8; as ToBitMaskArray>::BYTES] = + core::mem::transmute_copy(&bitmask); + + // Compute the regular mask + Self::from_int_unchecked(intrinsics::simd_select_bitmask( + bitmask, + Self::splat(true).to_int(), + Self::splat(false).to_int(), + )) + } + } + #[inline] pub(crate) fn to_bitmask_integer(self) -> U where diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs index c263f6a4eec..954f88ea511 100644 --- a/crates/core_simd/src/masks/to_bitmask.rs +++ b/crates/core_simd/src/masks/to_bitmask.rs @@ -31,6 +31,25 @@ pub unsafe trait ToBitMask: Sealed { fn from_bitmask(bitmask: Self::BitMask) -> Self; } +/// Converts masks to and from byte array bitmasks. +/// +/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte. +/// +/// # Safety +/// This trait is `unsafe` and sealed, since the `BYTES` value must match the number of lanes in +/// the mask. +#[cfg(feature = "generic_const_exprs")] +pub unsafe trait ToBitMaskArray: Sealed { + /// The length of the bitmask array. + const BYTES: usize; + + /// Converts a mask to a bitmask. + fn to_bitmask_array(self) -> [u8; Self::BYTES]; + + /// Converts a bitmask to a mask. + fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self; +} + macro_rules! impl_integer_intrinsic { { $(unsafe impl ToBitMask for Mask<_, $lanes:literal>)* } => { $( @@ -58,3 +77,25 @@ impl_integer_intrinsic! { unsafe impl ToBitMask for Mask<_, 32> unsafe impl ToBitMask for Mask<_, 64> } + +/// Returns the minimum numnber of bytes in a bitmask with `lanes` lanes. +#[cfg(feature = "generic_const_exprs")] +pub const fn bitmask_len(lanes: usize) -> usize { + (lanes + 7) / 8 +} + +#[cfg(feature = "generic_const_exprs")] +unsafe impl ToBitMaskArray for Mask +where + LaneCount: SupportedLaneCount, +{ + const BYTES: usize = bitmask_len(LANES); + + fn to_bitmask_array(self) -> [u8; Self::BYTES] { + self.0.to_bitmask_array() + } + + fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self { + Mask(mask_impl::Mask::from_bitmask_array(bitmask)) + } +} diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 3a0493d4ee6..673d0db93fe 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -122,6 +122,20 @@ macro_rules! test_mask_api { cast_impl::(); cast_impl::(); } + + #[cfg(feature = "generic_const_exprs")] + #[test] + fn roundtrip_bitmask_array_conversion() { + use core_simd::ToBitMaskArray; + let values = [ + true, false, false, true, false, false, true, false, + true, true, false, false, false, false, false, true, + ]; + let mask = core_simd::Mask::<$type, 16>::from_array(values); + let bitmask = mask.to_bitmask_array(); + assert_eq!(bitmask, [0b01001001, 0b10000011]); + assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask_array(bitmask), mask); + } } } }