From 842ac87747c4a6f8002ada6bab04d97320d206fc Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Thu, 13 Jan 2022 21:20:17 -0500 Subject: [PATCH 1/3] Use bitmask trait --- crates/core_simd/src/masks.rs | 22 ++----- crates/core_simd/src/masks/bitmask.rs | 12 +--- crates/core_simd/src/masks/full_masks.rs | 35 ++--------- crates/core_simd/src/masks/to_bitmask.rs | 78 ++++++++++++++++++++++++ crates/core_simd/tests/masks.rs | 6 +- 5 files changed, 93 insertions(+), 60 deletions(-) create mode 100644 crates/core_simd/src/masks/to_bitmask.rs diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index ae1fef53da8..22514728ffa 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -12,8 +12,10 @@ )] mod mask_impl; -use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount}; +mod to_bitmask; +pub use to_bitmask::ToBitMask; + +use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount}; use core::cmp::Ordering; use core::{fmt, mem}; @@ -216,22 +218,6 @@ where } } - /// Convert this mask to a bitmask, with one bit set per lane. - #[cfg(feature = "generic_const_exprs")] - #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask(self) -> [u8; LaneCount::::BITMASK_LEN] { - self.0.to_bitmask() - } - - /// Convert a bitmask to a mask. - #[cfg(feature = "generic_const_exprs")] - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask(bitmask: [u8; LaneCount::::BITMASK_LEN]) -> Self { - Self(mask_impl::Mask::from_bitmask(bitmask)) - } - /// Returns true if any lane is set, or false otherwise. #[inline] #[must_use = "method returns a new bool and does not mutate the original value"] diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index b4217dc87ba..f20f83ecb38 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -115,20 +115,14 @@ where unsafe { Self(intrinsics::simd_bitmask(value), PhantomData) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask(self) -> [u8; LaneCount::::BITMASK_LEN] { - // Safety: these are the same type and we are laundering the generic + pub unsafe fn to_bitmask_intrinsic(self) -> U { unsafe { core::mem::transmute_copy(&self.0) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask(bitmask: [u8; LaneCount::::BITMASK_LEN]) -> Self { - // Safety: these are the same type and we are laundering the generic - Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData) + pub unsafe fn from_bitmask_intrinsic(bitmask: U) -> Self { + unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) } } #[inline] diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index e5bb784bb91..b20b0a4b708 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -109,41 +109,16 @@ where unsafe { Mask(intrinsics::simd_cast(self.0)) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask(self) -> [u8; LaneCount::::BITMASK_LEN] { - unsafe { - let mut bitmask: [u8; LaneCount::::BITMASK_LEN] = - intrinsics::simd_bitmask(self.0); - - // There is a bug where LLVM appears to implement this operation with the wrong - // bit order. - // TODO fix this in a better way - if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { - *x = x.reverse_bits(); - } - } - - bitmask - } + pub unsafe fn to_bitmask_intrinsic(self) -> U { + // Safety: caller must only return bitmask types + unsafe { intrinsics::simd_bitmask(self.0) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask(mut bitmask: [u8; LaneCount::::BITMASK_LEN]) -> Self { + pub unsafe fn from_bitmask_intrinsic(bitmask: U) -> Self { + // Safety: caller must only pass bitmask types unsafe { - // There is a bug where LLVM appears to implement this operation with the wrong - // bit order. - // TODO fix this in a better way - if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { - *x = x.reverse_bits(); - } - } - Self::from_int_unchecked(intrinsics::simd_select_bitmask( bitmask, Self::splat(true).to_int(), diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs new file mode 100644 index 00000000000..3a9f89f19eb --- /dev/null +++ b/crates/core_simd/src/masks/to_bitmask.rs @@ -0,0 +1,78 @@ +use super::{mask_impl, Mask, MaskElement}; + +/// Converts masks to and from bitmasks. +/// +/// In a bitmask, each bit represents if the corresponding lane in the mask is set. +pub trait ToBitMask { + /// Converts a mask to a bitmask. + fn to_bitmask(self) -> BitMask; + + /// Converts a bitmask to a mask. + fn from_bitmask(bitmask: BitMask) -> Self; +} + +macro_rules! impl_integer_intrinsic { + { $(unsafe impl ToBitMask<$int:ty> for Mask<_, $lanes:literal>)* } => { + $( + impl ToBitMask<$int> for Mask { + fn to_bitmask(self) -> $int { + unsafe { self.0.to_bitmask_intrinsic() } + } + + fn from_bitmask(bitmask: $int) -> Self { + unsafe { Self(mask_impl::Mask::from_bitmask_intrinsic(bitmask)) } + } + } + )* + } +} + +impl_integer_intrinsic! { + unsafe impl ToBitMask for Mask<_, 8> + unsafe impl ToBitMask for Mask<_, 16> + unsafe impl ToBitMask for Mask<_, 32> + unsafe impl ToBitMask for Mask<_, 64> +} + +macro_rules! impl_integer_via { + { $(impl ToBitMask<$int:ty, via $via:ty> for Mask<_, $lanes:literal>)* } => { + $( + impl ToBitMask<$int> for Mask { + fn to_bitmask(self) -> $int { + let bitmask: $via = self.to_bitmask(); + bitmask as _ + } + + fn from_bitmask(bitmask: $int) -> Self { + Self::from_bitmask(bitmask as $via) + } + } + )* + } +} + +impl_integer_via! { + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 8> + + impl ToBitMask for Mask<_, 16> + impl ToBitMask for Mask<_, 16> + + impl ToBitMask for Mask<_, 32> +} + +#[cfg(target_pointer_width = "32")] +impl_integer_via! { + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 16> + impl ToBitMask for Mask<_, 32> +} + +#[cfg(target_pointer_width = "64")] +impl_integer_via! { + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 16> + impl ToBitMask for Mask<_, 32> + impl ToBitMask for Mask<_, 64> +} diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 6a8ecd33a73..965c0fa2635 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -68,16 +68,16 @@ macro_rules! test_mask_api { assert_eq!(core_simd::Mask::<$type, 8>::from_int(int), mask); } - #[cfg(feature = "generic_const_exprs")] #[test] fn roundtrip_bitmask_conversion() { + use core_simd::ToBitMask; let values = [ true, false, false, true, false, false, true, false, true, true, false, false, false, false, false, true, ]; let mask = core_simd::Mask::<$type, 16>::from_array(values); - let bitmask = mask.to_bitmask(); - assert_eq!(bitmask, [0b01001001, 0b10000011]); + let bitmask: u16 = mask.to_bitmask(); + assert_eq!(bitmask, 0b1000001101001001); assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask(bitmask), mask); } } From 11c3eefa3594055192612d0d6f844e764dcbda15 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Sun, 6 Feb 2022 03:25:27 +0000 Subject: [PATCH 2/3] Manually implement for supported lanes --- crates/core_simd/src/masks/bitmask.rs | 5 +- crates/core_simd/src/masks/full_masks.rs | 6 +- crates/core_simd/src/masks/to_bitmask.rs | 84 +++++++++--------------- crates/core_simd/tests/masks.rs | 2 +- 4 files changed, 38 insertions(+), 59 deletions(-) diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index f20f83ecb38..e27b2689606 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -116,12 +116,13 @@ where } #[inline] - pub unsafe fn to_bitmask_intrinsic(self) -> U { + pub unsafe fn to_bitmask_integer(self) -> U { unsafe { core::mem::transmute_copy(&self.0) } } + // Safety: U must be the integer with the exact number of bits required to hold the bitmask for #[inline] - pub unsafe fn from_bitmask_intrinsic(bitmask: U) -> Self { + pub unsafe fn from_bitmask_integer(bitmask: U) -> Self { unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) } } diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index b20b0a4b708..90af486a887 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -110,13 +110,15 @@ where } #[inline] - pub unsafe fn to_bitmask_intrinsic(self) -> U { + pub unsafe fn to_bitmask_integer(self) -> U { // Safety: caller must only return bitmask types unsafe { intrinsics::simd_bitmask(self.0) } } + // Safety: U must be the integer with the exact number of bits required to hold the bitmask for + // this mask #[inline] - pub unsafe fn from_bitmask_intrinsic(bitmask: U) -> Self { + pub unsafe fn from_bitmask_integer(bitmask: U) -> Self { // Safety: caller must only pass bitmask types unsafe { Self::from_int_unchecked(intrinsics::simd_select_bitmask( diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs index 3a9f89f19eb..86143f2331f 100644 --- a/crates/core_simd/src/masks/to_bitmask.rs +++ b/crates/core_simd/src/masks/to_bitmask.rs @@ -1,78 +1,54 @@ use super::{mask_impl, Mask, MaskElement}; -/// Converts masks to and from bitmasks. +/// Converts masks to and from integer bitmasks. /// -/// In a bitmask, each bit represents if the corresponding lane in the mask is set. -pub trait ToBitMask { +/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB. +pub trait ToBitMask { + /// The integer bitmask type. + type BitMask; + /// Converts a mask to a bitmask. - fn to_bitmask(self) -> BitMask; + fn to_bitmask(self) -> Self::BitMask; /// Converts a bitmask to a mask. - fn from_bitmask(bitmask: BitMask) -> Self; + fn from_bitmask(bitmask: Self::BitMask) -> Self; } -macro_rules! impl_integer_intrinsic { - { $(unsafe impl ToBitMask<$int:ty> for Mask<_, $lanes:literal>)* } => { - $( - impl ToBitMask<$int> for Mask { - fn to_bitmask(self) -> $int { - unsafe { self.0.to_bitmask_intrinsic() } - } +/// Converts masks to and from byte array bitmasks. +/// +/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte. +pub trait ToBitMaskArray { + /// The length of the bitmask array. + const BYTES: usize; - fn from_bitmask(bitmask: $int) -> Self { - unsafe { Self(mask_impl::Mask::from_bitmask_intrinsic(bitmask)) } - } - } - )* - } -} + /// Converts a mask to a bitmask. + fn to_bitmask_array(self) -> [u8; Self::BYTES]; -impl_integer_intrinsic! { - unsafe impl ToBitMask for Mask<_, 8> - unsafe impl ToBitMask for Mask<_, 16> - unsafe impl ToBitMask for Mask<_, 32> - unsafe impl ToBitMask for Mask<_, 64> + /// Converts a bitmask to a mask. + fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self; } -macro_rules! impl_integer_via { - { $(impl ToBitMask<$int:ty, via $via:ty> for Mask<_, $lanes:literal>)* } => { +macro_rules! impl_integer_intrinsic { + { $(unsafe impl ToBitMask for Mask<_, $lanes:literal>)* } => { $( - impl ToBitMask<$int> for Mask { + impl ToBitMask for Mask { + type BitMask = $int; + fn to_bitmask(self) -> $int { - let bitmask: $via = self.to_bitmask(); - bitmask as _ + unsafe { self.0.to_bitmask_integer() } } fn from_bitmask(bitmask: $int) -> Self { - Self::from_bitmask(bitmask as $via) + unsafe { Self(mask_impl::Mask::from_bitmask_integer(bitmask)) } } } )* } } -impl_integer_via! { - impl ToBitMask for Mask<_, 8> - impl ToBitMask for Mask<_, 8> - impl ToBitMask for Mask<_, 8> - - impl ToBitMask for Mask<_, 16> - impl ToBitMask for Mask<_, 16> - - impl ToBitMask for Mask<_, 32> -} - -#[cfg(target_pointer_width = "32")] -impl_integer_via! { - impl ToBitMask for Mask<_, 8> - impl ToBitMask for Mask<_, 16> - impl ToBitMask for Mask<_, 32> -} - -#[cfg(target_pointer_width = "64")] -impl_integer_via! { - impl ToBitMask for Mask<_, 8> - impl ToBitMask for Mask<_, 16> - impl ToBitMask for Mask<_, 32> - impl ToBitMask for Mask<_, 64> +impl_integer_intrinsic! { + unsafe impl ToBitMask for Mask<_, 8> + unsafe impl ToBitMask for Mask<_, 16> + unsafe impl ToBitMask for Mask<_, 32> + unsafe impl ToBitMask for Mask<_, 64> } diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 965c0fa2635..3aec36ca7b7 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -76,7 +76,7 @@ macro_rules! test_mask_api { true, true, false, false, false, false, false, true, ]; let mask = core_simd::Mask::<$type, 16>::from_array(values); - let bitmask: u16 = mask.to_bitmask(); + let bitmask = mask.to_bitmask(); assert_eq!(bitmask, 0b1000001101001001); assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask(bitmask), mask); } From 20fa4b76235afb6a2ad543781a10a14e8013b143 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Wed, 9 Feb 2022 04:54:05 +0000 Subject: [PATCH 3/3] Make internal mask implementation safe --- crates/core_simd/src/masks/bitmask.rs | 15 +++++-- crates/core_simd/src/masks/full_masks.rs | 51 ++++++++++++++++++++---- crates/core_simd/src/masks/to_bitmask.rs | 39 +++++++++--------- 3 files changed, 75 insertions(+), 30 deletions(-) diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index e27b2689606..7bf2add2036 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -1,7 +1,7 @@ #![allow(unused_imports)] use super::MaskElement; use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount}; +use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; use core::marker::PhantomData; /// A mask where each lane is represented by a single bit. @@ -116,13 +116,20 @@ where } #[inline] - pub unsafe fn to_bitmask_integer(self) -> U { + pub fn to_bitmask_integer(self) -> U + where + super::Mask: ToBitMask, + { + // Safety: these are the same types unsafe { core::mem::transmute_copy(&self.0) } } - // Safety: U must be the integer with the exact number of bits required to hold the bitmask for #[inline] - pub unsafe fn from_bitmask_integer(bitmask: U) -> Self { + pub fn from_bitmask_integer(bitmask: U) -> Self + where + super::Mask: ToBitMask, + { + // Safety: these are the same types unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) } } diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index 90af486a887..848997a0792 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -2,7 +2,7 @@ use super::MaskElement; use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount}; +use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; #[repr(transparent)] pub struct Mask(Simd) @@ -66,6 +66,23 @@ where } } +// Used for bitmask bit order workaround +pub(crate) trait ReverseBits { + fn reverse_bits(self) -> Self; +} + +macro_rules! impl_reverse_bits { + { $($int:ty),* } => { + $( + impl ReverseBits for $int { + fn reverse_bits(self) -> Self { <$int>::reverse_bits(self) } + } + )* + } +} + +impl_reverse_bits! { u8, u16, u32, u64 } + impl Mask where T: MaskElement, @@ -110,16 +127,34 @@ where } #[inline] - pub unsafe fn to_bitmask_integer(self) -> U { - // Safety: caller must only return bitmask types - unsafe { intrinsics::simd_bitmask(self.0) } + pub(crate) fn to_bitmask_integer(self) -> U + where + super::Mask: ToBitMask, + { + // Safety: U is required to be the appropriate bitmask type + let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) }; + + // LLVM assumes bit order should match endianness + if cfg!(target_endian = "big") { + bitmask.reverse_bits() + } else { + bitmask + } } - // Safety: U must be the integer with the exact number of bits required to hold the bitmask for - // this mask #[inline] - pub unsafe fn from_bitmask_integer(bitmask: U) -> Self { - // Safety: caller must only pass bitmask types + pub(crate) fn from_bitmask_integer(bitmask: U) -> Self + where + super::Mask: ToBitMask, + { + // LLVM assumes bit order should match endianness + let bitmask = if cfg!(target_endian = "big") { + bitmask.reverse_bits() + } else { + bitmask + }; + + // Safety: U is required to be the appropriate bitmask type unsafe { Self::from_int_unchecked(intrinsics::simd_select_bitmask( bitmask, diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs index 86143f2331f..1c2037764c1 100644 --- a/crates/core_simd/src/masks/to_bitmask.rs +++ b/crates/core_simd/src/masks/to_bitmask.rs @@ -1,9 +1,26 @@ use super::{mask_impl, Mask, MaskElement}; +use crate::simd::{LaneCount, SupportedLaneCount}; + +mod sealed { + pub trait Sealed {} +} +pub use sealed::Sealed; + +impl Sealed for Mask +where + T: MaskElement, + LaneCount: SupportedLaneCount, +{ +} /// Converts masks to and from integer bitmasks. /// /// Each bit of the bitmask corresponds to a mask lane, starting with the LSB. -pub trait ToBitMask { +/// +/// # Safety +/// This trait is `unsafe` and sealed, since the `BitMask` type must match the number of lanes in +/// the mask. +pub unsafe trait ToBitMask: Sealed { /// The integer bitmask type. type BitMask; @@ -14,32 +31,18 @@ pub trait ToBitMask { fn from_bitmask(bitmask: Self::BitMask) -> Self; } -/// Converts masks to and from byte array bitmasks. -/// -/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte. -pub trait ToBitMaskArray { - /// The length of the bitmask array. - const BYTES: usize; - - /// Converts a mask to a bitmask. - fn to_bitmask_array(self) -> [u8; Self::BYTES]; - - /// Converts a bitmask to a mask. - fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self; -} - macro_rules! impl_integer_intrinsic { { $(unsafe impl ToBitMask for Mask<_, $lanes:literal>)* } => { $( - impl ToBitMask for Mask { + unsafe impl ToBitMask for Mask { type BitMask = $int; fn to_bitmask(self) -> $int { - unsafe { self.0.to_bitmask_integer() } + self.0.to_bitmask_integer() } fn from_bitmask(bitmask: $int) -> Self { - unsafe { Self(mask_impl::Mask::from_bitmask_integer(bitmask)) } + Self(mask_impl::Mask::from_bitmask_integer(bitmask)) } } )*