From d2c1197c3d06f44f701d2c3bc482238421ff978e Mon Sep 17 00:00:00 2001 From: Zachary S Date: Sun, 28 Jul 2024 00:37:44 -0500 Subject: [PATCH 1/9] Add support for deriving Zeroable for fieldful enums if: 1. the enum is repr(Int), repr(C), or repr(C, Int), 2. the enum has a variant with discriminant 0, 3. and all fields of the variant with discriminant 0 are Zeroable. --- derive/src/traits.rs | 121 +++++++++++++++++++++++++++--------------- derive/tests/basic.rs | 120 ++++++++++++++++++++++++++++++----------- 2 files changed, 165 insertions(+), 76 deletions(-) diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 89fc6c4..8326f8d 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -76,8 +76,11 @@ impl Derivable for Pod { } else { None }; - let assert_fields_are_pod = - generate_fields_are_trait(input, Self::ident(input, crate_name)?)?; + let assert_fields_are_pod = generate_fields_are_trait( + input, + None, + Self::ident(input, crate_name)?, + )?; Ok(quote!( #assert_no_padding @@ -118,7 +121,7 @@ impl Derivable for AnyBitPattern { match &input.data { Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern` Data::Struct(_) => { - generate_fields_are_trait(input, Self::ident(input, crate_name)?) + generate_fields_are_trait(input, None, Self::ident(input, crate_name)?) } Data::Enum(_) => { bail!("Deriving AnyBitPattern is not supported for enums") @@ -139,24 +142,23 @@ impl Derivable for Zeroable { match ty { Data::Struct(_) => Ok(()), Data::Enum(DataEnum { variants, .. }) => { - if !repr.repr.is_integer() { - bail!("Zeroable requires the enum to be an explicit #[repr(Int)]") - } - - if variants.iter().any(|variant| !variant.fields.is_empty()) { - bail!("Only fieldless enums are supported for Zeroable") + if !matches!( + repr.repr, + Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_) + ) { + bail!("Zeroable requires the enum to be an explicit #[repr(Int)] or #[repr(C)]") } let iter = VariantDiscriminantIterator::new(variants.iter()); - let mut has_zero_variant = false; + let mut zero_variant = None; for res in iter { - let discriminant = res?; + let (discriminant, variant) = res?; if discriminant == 0 { - has_zero_variant = true; + zero_variant = Some(variant); break; } } - if !has_zero_variant { + if zero_variant.is_none() { bail!("No variant's discriminant is 0") } @@ -172,9 +174,28 @@ impl Derivable for Zeroable { match &input.data { Data::Union(_) => Ok(quote!()), // unions are always `Zeroable` Data::Struct(_) => { - generate_fields_are_trait(input, Self::ident(input, crate_name)?) + generate_fields_are_trait(input, None, Self::ident(input, crate_name)?) + } + Data::Enum(DataEnum { variants, .. }) => { + let iter = VariantDiscriminantIterator::new(variants.iter()); + let mut zero_variant = None; + for res in iter { + let (discriminant, variant) = res?; + if discriminant == 0 { + zero_variant = Some(variant); + break; + } + } + if zero_variant.is_none() { + bail!("No variant's discriminant is 0") + }; + + generate_fields_are_trait( + input, + zero_variant, + Self::ident(input, crate_name)?, + ) } - Data::Enum(_) => Ok(quote!()), } } @@ -216,8 +237,11 @@ impl Derivable for NoUninit { match &input.data { Data::Struct(DataStruct { .. }) => { let assert_no_padding = generate_assert_no_padding(&input)?; - let assert_fields_are_no_padding = - generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?; + let assert_fields_are_no_padding = generate_fields_are_trait( + &input, + None, + Self::ident(input, crate_name)?, + )?; Ok(quote!( #assert_no_padding @@ -282,13 +306,16 @@ impl Derivable for CheckedBitPattern { match &input.data { Data::Struct(DataStruct { .. }) => { - let assert_fields_are_maybe_pod = - generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?; + let assert_fields_are_maybe_pod = generate_fields_are_trait( + &input, + None, + Self::ident(input, crate_name)?, + )?; Ok(assert_fields_are_maybe_pod) } - Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed - * OK by NoUninit */ + // nothing needed, already guaranteed OK by NoUninit. + Data::Enum(_) => Ok(quote!()), Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ } } @@ -439,16 +466,16 @@ impl Derivable for Contiguous { )); } - let mut variants_with_discriminator = + let mut variants_with_discriminant = VariantDiscriminantIterator::new(variants); - let (min, max, count) = variants_with_discriminator.try_fold( + let (min, max, count) = variants_with_discriminant.try_fold( (i64::max_value(), i64::min_value(), 0), |(min, max, count), res| { - let discriminator = res?; + let (discriminant, _variant) = res?; Ok::<_, Error>(( - i64::min(min, discriminator), - i64::max(max, discriminator), + i64::min(min, discriminant), + i64::max(max, discriminant), count + 1, )) }, @@ -505,11 +532,16 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> { } } -fn get_fields(input: &DeriveInput) -> Result { +fn get_fields( + input: &DeriveInput, enum_variant: Option<&Variant>, +) -> Result { match &input.data { Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()), Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())), - Data::Enum(_) => bail!("deriving this trait is not supported for enums"), + Data::Enum(_) => match enum_variant { + Some(variant) => Ok(variant.fields.clone()), + None => bail!("deriving this trait is not supported for enums"), + }, } } @@ -598,7 +630,7 @@ fn generate_checked_bit_pattern_enum_without_fields( let (min, max, count) = variants_with_discriminant.try_fold( (i64::max_value(), i64::min_value(), 0), |(min, max, count), res| { - let discriminant = res?; + let (discriminant, _variant) = res?; Ok::<_, Error>(( i64::min(min, discriminant), i64::max(max, discriminant), @@ -617,16 +649,17 @@ fn generate_checked_bit_pattern_enum_without_fields( quote!(*bits >= #min_lit && *bits <= #max_lit) } else { // not contiguous range, check for each - let variant_lits = VariantDiscriminantIterator::new(variants.iter()) - .map(|res| { - let variant = res?; - Ok(LitInt::new(&format!("{}", variant), span)) - }) - .collect::>>()?; + let variant_discriminant_lits = + VariantDiscriminantIterator::new(variants.iter()) + .map(|res| { + let (discriminant, _variant) = res?; + Ok(LitInt::new(&format!("{}", discriminant), span)) + }) + .collect::>>()?; // count is at least 1 - let first = &variant_lits[0]; - let rest = &variant_lits[1..]; + let first = &variant_discriminant_lits[0]; + let rest = &variant_discriminant_lits[1..]; quote!(matches!(*bits, #first #(| #rest )*)) }; @@ -720,7 +753,7 @@ fn generate_checked_bit_pattern_enum_with_fields( .zip(VariantDiscriminantIterator::new(variants.iter())) .zip(variants.iter()) .map(|((variant_struct_ident, discriminant), v)| -> Result<_> { - let discriminant = discriminant?; + let (discriminant, _variant) = discriminant?; let discriminant = LitInt::new(&discriminant.to_string(), v.span()); let ident = &v.ident; Ok(quote! { @@ -850,7 +883,7 @@ fn generate_checked_bit_pattern_enum_with_fields( .zip(VariantDiscriminantIterator::new(variants.iter())) .zip(variants.iter()) .map(|((variant_struct_ident, discriminant), v)| -> Result<_> { - let discriminant = discriminant?; + let (discriminant, _variant) = discriminant?; let discriminant = LitInt::new(&discriminant.to_string(), v.span()); let ident = &v.ident; Ok(quote! { @@ -906,7 +939,7 @@ fn generate_checked_bit_pattern_enum_with_fields( fn generate_assert_no_padding(input: &DeriveInput) -> Result { let struct_type = &input.ident; let span = input.ident.span(); - let fields = get_fields(input)?; + let fields = get_fields(input, None)?; let mut field_types = get_field_types(&fields); let size_sum = if let Some(first) = field_types.next() { @@ -928,11 +961,11 @@ fn generate_assert_no_padding(input: &DeriveInput) -> Result { /// Check that all fields implement a given trait fn generate_fields_are_trait( - input: &DeriveInput, trait_: syn::Path, + input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path, ) -> Result { let (impl_generics, _ty_generics, where_clause) = input.generics.split_for_impl(); - let fields = get_fields(input)?; + let fields = get_fields(input, enum_variant)?; let span = input.span(); let field_types = get_field_types(&fields); Ok(quote_spanned! {span => #(const _: fn() = || { @@ -1200,7 +1233,7 @@ impl<'a, I: Iterator + 'a> impl<'a, I: Iterator + 'a> Iterator for VariantDiscriminantIterator<'a, I> { - type Item = Result; + type Item = Result<(i64, &'a Variant)>; fn next(&mut self) -> Option { let variant = self.inner.next()?; @@ -1215,7 +1248,7 @@ impl<'a, I: Iterator + 'a> Iterator self.last_value += 1; } - Some(Ok(self.last_value)) + Some(Ok((self.last_value, variant))) } } diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs index dfb1ff6..8a3dbe1 100644 --- a/derive/tests/basic.rs +++ b/derive/tests/basic.rs @@ -1,8 +1,8 @@ #![allow(dead_code)] use bytemuck::{ - AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod, - TransparentWrapper, Zeroable, checked::CheckedCastError, + checked::CheckedCastError, AnyBitPattern, CheckedBitPattern, Contiguous, + NoUninit, Pod, TransparentWrapper, Zeroable, }; use std::marker::{PhantomData, PhantomPinned}; @@ -58,6 +58,45 @@ enum ZeroEnum { C = 2, } +#[derive(Zeroable)] +#[repr(u8)] +enum BasicFieldfulZeroEnum { + A(u8) = 0, + B = 1, + C(String) = 2, +} + +#[derive(Zeroable)] +#[repr(C)] +enum ReprCFieldfulZeroEnum { + A(u8), + B(Box<[u8]>), + C, +} + +#[derive(Zeroable)] +#[repr(C, i32)] +enum ReprCIntFieldfulZeroEnum { + B(String) = 1, + A(u8, bool, char) = 0, + C = 2, +} + +#[derive(Zeroable)] +#[repr(i32)] +enum GenericFieldfulZeroEnum { + A(Box) = 1, + B(T, T) = 0, +} + +#[derive(Zeroable)] +#[repr(i32)] +#[zeroable(bound = "")] +enum GenericCustomBoundFieldfulZeroEnum { + A(Option>), + B(String), +} + #[derive(TransparentWrapper)] #[repr(transparent)] struct TransparentSingle { @@ -202,8 +241,10 @@ enum CheckedBitPatternTransparentEnumWithFields { } // size 24, align 8. -// first byte always the u8 discriminant, then 7 bytes of padding until the payload union since the align of the payload -// is the greatest of the align of all the variants, which is 8 (from CheckedBitPatternCDefaultDiscriminantEnumWithFields) +// first byte always the u8 discriminant, then 7 bytes of padding until the +// payload union since the align of the payload is the greatest of the align of +// all the variants, which is 8 (from +// CheckedBitPatternCDefaultDiscriminantEnumWithFields) #[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)] #[repr(C, u8)] enum CheckedBitPatternEnumNested { @@ -388,52 +429,68 @@ fn checkedbitpattern_nested_enum_with_fields() { // first we'll check variantA, nested variant A let pod = Align8Bytes([ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 0 = variant A, bytes 1-7 irrelevant padding. - 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, // byte 0 discriminant = 0 = variant A, bytes 1-7 irrelevant padding. + 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, + 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding ]); - let value = bytemuck::checked::from_bytes::< - CheckedBitPatternEnumNested, - >(&pod.0); - assert_eq!(value, &CheckedBitPatternEnumNested::A(CheckedBitPatternCEnumWithFields::A(0xcc5555cc))); + let value = + bytemuck::checked::from_bytes::(&pod.0); + assert_eq!( + value, + &CheckedBitPatternEnumNested::A(CheckedBitPatternCEnumWithFields::A( + 0xcc5555cc + )) + ); // next we'll check invalid first discriminant fails let pod = Align8Bytes([ - 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 2 = invalid, bytes 1-7 padding - 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields = A, + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, // byte 0 discriminant = 2 = invalid, bytes 1-7 padding + 0x00, 0x00, 0x00, 0x00, 0xcc, 0x55, 0x55, + 0xcc, // bytes 8-15 are the nested CheckedBitPatternCEnumWithFields = A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 16-23 padding ]); - let result = bytemuck::checked::try_from_bytes::< - CheckedBitPatternEnumNested, - >(&pod.0); + let result = + bytemuck::checked::try_from_bytes::(&pod.0); assert_eq!(result, Err(CheckedCastError::InvalidBitPattern)); - // next we'll check variant B, nested variant B let pod = Align8Bytes([ - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // byte 0 discriminant = 1 = variant B, bytes 1-7 padding - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 1 (LE byte order) = variant B - 0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, // byte 0 discriminant = 1 = variant B, bytes 1-7 padding + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, /* bytes 8-15 is C int size discriminant of + * CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 1 (LE byte + * order) = variant B */ + 0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, + 0xcc, // bytes 16-13 is the data contained in nested variant B ]); - let value = bytemuck::checked::from_bytes::< - CheckedBitPatternEnumNested, - >(&pod.0); + let value = + bytemuck::checked::from_bytes::(&pod.0); assert_eq!( value, - &CheckedBitPatternEnumNested::B(CheckedBitPatternCDefaultDiscriminantEnumWithFields::B { - c: 0xcc555555555555cc - }) + &CheckedBitPatternEnumNested::B( + CheckedBitPatternCDefaultDiscriminantEnumWithFields::B { + c: 0xcc555555555555cc + } + ) ); // finally we'll check variant B, nested invalid discriminant let pod = Align8Bytes([ - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 1 discriminant = variant B, bytes 1-7 padding - 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 8-15 is C int size discriminant of CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 0x08 is invalid - 0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xcc, // bytes 16-13 is the data contained in nested variant B + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, // 1 discriminant = variant B, bytes 1-7 padding + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, /* bytes 8-15 is C int size discriminant of + * CheckedBitPatternCDefaultDiscrimimantEnumWithFields, 0x08 is + * invalid */ + 0xcc, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, + 0xcc, // bytes 16-13 is the data contained in nested variant B ]); - let result = bytemuck::checked::try_from_bytes::< - CheckedBitPatternEnumNested, - >(&pod.0); + let result = + bytemuck::checked::try_from_bytes::(&pod.0); assert_eq!(result, Err(CheckedCastError::InvalidBitPattern)); } #[test] @@ -457,4 +514,3 @@ use bytemuck as reexport_name; #[bytemuck(crate = "reexport_name")] #[repr(C)] struct Issue93 {} - From ff316f6d453739e7b119fe24f8d39b5811b1b848 Mon Sep 17 00:00:00 2001 From: Zachary S Date: Sun, 28 Jul 2024 00:58:07 -0500 Subject: [PATCH 2/9] Allow using derive(Zeroable) with explicit bounds. Update documentation and doctests. --- derive/src/lib.rs | 74 +++++++++++++++++++++++++++++++++++--------- derive/src/traits.rs | 21 +++++++++++++ 2 files changed, 81 insertions(+), 14 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 806a72f..0e30756 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -6,7 +6,7 @@ mod traits; use proc_macro2::TokenStream; use quote::quote; -use syn::{parse_macro_input, DeriveInput, Result}; +use syn::{parse_macro_input, DeriveInput, Result, Variant}; use crate::traits::{ bytemuck_crate_name, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, @@ -114,14 +114,26 @@ pub fn derive_anybitpattern( proc_macro::TokenStream::from(expanded) } -/// Derive the `Zeroable` trait for a struct +/// Derive the `Zeroable` trait for a type. /// /// The macro ensures that the struct follows all the the safety requirements /// for the `Zeroable` trait. /// -/// The following constraints need to be satisfied for the macro to succeed +/// The following constraints need to be satisfied for the macro to succeed on a +/// struct: +/// +/// - All fields in the struct must implement `Zeroable` +/// +/// The following constraints need to be satisfied for the macro to succeed on +/// an enum: /// -/// - All fields in the struct must to implement `Zeroable` +/// - The enum has an explicit `#[repr(Int)]`, `#[repr(C)]`, or `#[repr(C, +/// Int)]`. +/// - The enum has a variant with discriminant 0 (explicitly or implicitly). +/// - All fields in the variant with discriminant 0 (if any) must implement +/// `Zeroable` +/// +/// The macro always succeeds on `union`s. /// /// ## Example /// @@ -134,6 +146,23 @@ pub fn derive_anybitpattern( /// b: u16, /// } /// ``` +/// ```rust +/// # use bytemuck_derive::{Zeroable}; +/// #[derive(Copy, Clone, Zeroable)] +/// #[repr(i32)] +/// enum Values { +/// A = 0, +/// B = 1, +/// C = 2, +/// } +/// #[derive(Clone, Zeroable)] +/// #[repr(C)] +/// enum Implicit { +/// A(bool, u8, char), +/// B(String), +/// C(std::num::NonZeroU8), +/// } +/// ``` /// /// # Custom bounds /// @@ -157,6 +186,18 @@ pub fn derive_anybitpattern( /// /// AlwaysZeroable::::zeroed(); /// ``` +/// ```rust +/// # use bytemuck::{Zeroable}; +/// #[derive(Copy, Clone, Zeroable)] +/// #[repr(u8)] +/// #[zeroable(bound = "")] +/// enum MyOption { +/// None, +/// Some(T), +/// } +/// +/// assert!(matches!(MyOption::::zeroed(), MyOption::None)); +/// ``` /// /// ```rust,compile_fail /// # use bytemuck::Zeroable; @@ -407,7 +448,8 @@ pub fn derive_byte_eq( let input = parse_macro_input!(input as DeriveInput); let crate_name = bytemuck_crate_name(&input); let ident = input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let (impl_generics, ty_generics, where_clause) = + input.generics.split_for_impl(); proc_macro::TokenStream::from(quote! { impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause { @@ -460,7 +502,8 @@ pub fn derive_byte_hash( let input = parse_macro_input!(input as DeriveInput); let crate_name = bytemuck_crate_name(&input); let ident = input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let (impl_generics, ty_generics, where_clause) = + input.generics.split_for_impl(); proc_macro::TokenStream::from(quote! { impl #impl_generics ::core::hash::Hash for #ident #ty_generics #where_clause { @@ -569,19 +612,18 @@ fn derive_marker_trait_inner( .flatten() .collect::>(); - let predicates = &mut input.generics.make_where_clause().predicates; - - predicates.extend(explicit_bounds); - - let fields = match &input.data { - syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(), - syn::Data::Union(_) => { + let fields = match (Trait::perfect_derive_fields(&input), &input.data) { + (Some(fields), _) => fields, + (None, syn::Data::Struct(syn::DataStruct { fields, .. })) => { + fields.clone() + } + (None, syn::Data::Union(_)) => { return Err(syn::Error::new_spanned( trait_, &"perfect derive is not supported for unions", )); } - syn::Data::Enum(_) => { + (None, syn::Data::Enum(_)) => { return Err(syn::Error::new_spanned( trait_, &"perfect derive is not supported for enums", @@ -589,6 +631,10 @@ fn derive_marker_trait_inner( } }; + let predicates = &mut input.generics.make_where_clause().predicates; + + predicates.extend(explicit_bounds); + for field in fields { let ty = field.ty; predicates.push(syn::parse_quote!( diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 8326f8d..10e583e 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -44,6 +44,9 @@ pub trait Derivable { fn explicit_bounds_attribute_name() -> Option<&'static str> { None } + fn perfect_derive_fields(input: &DeriveInput) -> Option { + None + } } pub struct Pod; @@ -202,6 +205,24 @@ impl Derivable for Zeroable { fn explicit_bounds_attribute_name() -> Option<&'static str> { Some("zeroable") } + + fn perfect_derive_fields(input: &DeriveInput) -> Option { + match &input.data { + Data::Struct(struct_) => Some(struct_.fields.clone()), + Data::Enum(DataEnum { variants, .. }) => { + let iter = VariantDiscriminantIterator::new(variants.iter()); + for res in iter { + match res { + Ok((0, variant)) => return Some(variant.fields.clone()), + Ok(_) => (), + Err(_) => return None, + } + } + None + } + Data::Union(_) => None, + } + } } pub struct NoUninit; From c8e6fca45ac5b8e9b43d0b41169f488610b2c692 Mon Sep 17 00:00:00 2001 From: Zachary S Date: Sun, 28 Jul 2024 01:10:40 -0500 Subject: [PATCH 3/9] doc update --- derive/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 0e30756..5f6b19c 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -116,7 +116,7 @@ pub fn derive_anybitpattern( /// Derive the `Zeroable` trait for a type. /// -/// The macro ensures that the struct follows all the the safety requirements +/// The macro ensures that the type follows all the the safety requirements /// for the `Zeroable` trait. /// /// The following constraints need to be satisfied for the macro to succeed on a @@ -133,7 +133,7 @@ pub fn derive_anybitpattern( /// - All fields in the variant with discriminant 0 (if any) must implement /// `Zeroable` /// -/// The macro always succeeds on `union`s. +/// The macro always succeeds on unions. /// /// ## Example /// From 92f77c8fb9c32f3ce2f7a5b8daa20bd901451aac Mon Sep 17 00:00:00 2001 From: Zachary S Date: Sun, 28 Jul 2024 01:18:50 -0500 Subject: [PATCH 4/9] doc update --- derive/src/traits.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 10e583e..c31ab43 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -44,6 +44,11 @@ pub trait Derivable { fn explicit_bounds_attribute_name() -> Option<&'static str> { None } + + /// If this trait has a custom meaning for "perfect derive", this function + /// should be overridden to return `Some`. + /// + /// The default is "the fields of a struct; unions and enums not supported". fn perfect_derive_fields(input: &DeriveInput) -> Option { None } From d107b2a0acf9c48f6bd9330a2af253189f9fea6b Mon Sep 17 00:00:00 2001 From: Zachary S Date: Sun, 28 Jul 2024 01:27:08 -0500 Subject: [PATCH 5/9] remove unused --- derive/src/lib.rs | 2 +- derive/src/traits.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5f6b19c..26ecfc1 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -6,7 +6,7 @@ mod traits; use proc_macro2::TokenStream; use quote::quote; -use syn::{parse_macro_input, DeriveInput, Result, Variant}; +use syn::{parse_macro_input, DeriveInput, Result}; use crate::traits::{ bytemuck_crate_name, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, diff --git a/derive/src/traits.rs b/derive/src/traits.rs index c31ab43..953787a 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -49,7 +49,7 @@ pub trait Derivable { /// should be overridden to return `Some`. /// /// The default is "the fields of a struct; unions and enums not supported". - fn perfect_derive_fields(input: &DeriveInput) -> Option { + fn perfect_derive_fields(_input: &DeriveInput) -> Option { None } } From be6ece8a56c37b3a8fa7288ac0fd83b0ce815ae7 Mon Sep 17 00:00:00 2001 From: Zachary S Date: Mon, 16 Sep 2024 14:18:48 -0500 Subject: [PATCH 6/9] Factor out get_zero_variant helper function. --- derive/src/traits.rs | 61 ++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 953787a..8cd9e33 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -140,6 +140,21 @@ impl Derivable for AnyBitPattern { pub struct Zeroable; +/// Helper function to get the variant with discriminant zero (implicit or +/// explicit). +fn get_zero_variant(enum_: &DataEnum) -> Result> { + let iter = VariantDiscriminantIterator::new(enum_.variants.iter()); + let mut zero_variant = None; + for res in iter { + let (discriminant, variant) = res?; + if discriminant == 0 { + zero_variant = Some(variant); + break; + } + } + Ok(zero_variant) +} + impl Derivable for Zeroable { fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result { Ok(syn::parse_quote!(#crate_name::Zeroable)) @@ -149,26 +164,16 @@ impl Derivable for Zeroable { let repr = get_repr(attributes)?; match ty { Data::Struct(_) => Ok(()), - Data::Enum(DataEnum { variants, .. }) => { + Data::Enum(_) => { if !matches!( repr.repr, Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_) ) { - bail!("Zeroable requires the enum to be an explicit #[repr(Int)] or #[repr(C)]") + bail!("Zeroable requires the enum to be an explicit #[repr(Int)] and/or #[repr(C)]") } - let iter = VariantDiscriminantIterator::new(variants.iter()); - let mut zero_variant = None; - for res in iter { - let (discriminant, variant) = res?; - if discriminant == 0 { - zero_variant = Some(variant); - break; - } - } - if zero_variant.is_none() { - bail!("No variant's discriminant is 0") - } + // We ensure there is a zero variant in `asserts`, since it is needed + // there anyway. Ok(()) } @@ -184,16 +189,9 @@ impl Derivable for Zeroable { Data::Struct(_) => { generate_fields_are_trait(input, None, Self::ident(input, crate_name)?) } - Data::Enum(DataEnum { variants, .. }) => { - let iter = VariantDiscriminantIterator::new(variants.iter()); - let mut zero_variant = None; - for res in iter { - let (discriminant, variant) = res?; - if discriminant == 0 { - zero_variant = Some(variant); - break; - } - } + Data::Enum(enum_) => { + let zero_variant = get_zero_variant(enum_)?; + if zero_variant.is_none() { bail!("No variant's discriminant is 0") }; @@ -214,16 +212,11 @@ impl Derivable for Zeroable { fn perfect_derive_fields(input: &DeriveInput) -> Option { match &input.data { Data::Struct(struct_) => Some(struct_.fields.clone()), - Data::Enum(DataEnum { variants, .. }) => { - let iter = VariantDiscriminantIterator::new(variants.iter()); - for res in iter { - match res { - Ok((0, variant)) => return Some(variant.fields.clone()), - Ok(_) => (), - Err(_) => return None, - } - } - None + Data::Enum(enum_) => { + // We handle `Err` returns from `get_zero_variant` in `asserts`, so it's + // fine to just ignore them here and return `None`. + // Otherwise, we clone the `fields` of the zero variant (if any). + Some(get_zero_variant(enum_).ok()??.fields.clone()) } Data::Union(_) => None, } From b5ac207f5b1d739c4d0a6e36c1b5d0909ebd2115 Mon Sep 17 00:00:00 2001 From: Zachary S Date: Mon, 16 Sep 2024 14:40:49 -0500 Subject: [PATCH 7/9] Use i128 to track disciminants instead of i64. --- derive/src/traits.rs | 44 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 8cd9e33..85c16e5 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -489,12 +489,12 @@ impl Derivable for Contiguous { VariantDiscriminantIterator::new(variants); let (min, max, count) = variants_with_discriminant.try_fold( - (i64::max_value(), i64::min_value(), 0), + (i128::MAX, i128::MIN, 0), |(min, max, count), res| { let (discriminant, _variant) = res?; Ok::<_, Error>(( - i64::min(min, discriminant), - i64::max(max, discriminant), + i128::min(min, discriminant), + i128::max(max, discriminant), count + 1, )) }, @@ -647,12 +647,12 @@ fn generate_checked_bit_pattern_enum_without_fields( VariantDiscriminantIterator::new(variants.iter()); let (min, max, count) = variants_with_discriminant.try_fold( - (i64::max_value(), i64::min_value(), 0), + (i128::MAX, i128::MIN, 0), |(min, max, count), res| { let (discriminant, _variant) = res?; Ok::<_, Error>(( - i64::min(min, discriminant), - i64::max(max, discriminant), + i128::min(min, discriminant), + i128::max(max, discriminant), count + 1, )) }, @@ -1238,7 +1238,7 @@ fn enum_has_fields<'a>( struct VariantDiscriminantIterator<'a, I: Iterator + 'a> { inner: I, - last_value: i64, + last_value: i128, } impl<'a, I: Iterator + 'a> @@ -1252,7 +1252,7 @@ impl<'a, I: Iterator + 'a> impl<'a, I: Iterator + 'a> Iterator for VariantDiscriminantIterator<'a, I> { - type Item = Result<(i64, &'a Variant)>; + type Item = Result<(i128, &'a Variant)>; fn next(&mut self) -> Option { let variant = self.inner.next()?; @@ -1264,14 +1264,38 @@ impl<'a, I: Iterator + 'a> Iterator }; self.last_value = discriminant_value; } else { - self.last_value += 1; + // If this wraps, then either: + // 1. the enum is using repr(u128), so wrapping is correct + // 2. the enum is using repr(i<=128 or u<128), so the compiler will + // already emit a "wrapping discriminant" E0370 error. + self.last_value = self.last_value.wrapping_add(1); + // Static assert that there is no integer repr > 128 bits. If that + // changes, the above comment is inaccurate and needs to be updated! + // FIXME(zachs18): maybe should also do something to ensure `isize::BITS + // <= 128`? + if let Some(repr) = None:: { + match repr { + IntegerRepr::U8 + | IntegerRepr::I8 + | IntegerRepr::U16 + | IntegerRepr::I16 + | IntegerRepr::U32 + | IntegerRepr::I32 + | IntegerRepr::U64 + | IntegerRepr::I64 + | IntegerRepr::I128 + | IntegerRepr::U128 + | IntegerRepr::Usize + | IntegerRepr::Isize => (), + } + } } Some(Ok((self.last_value, variant))) } } -fn parse_int_expr(expr: &Expr) -> Result { +fn parse_int_expr(expr: &Expr) -> Result { match expr { Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => { parse_int_expr(expr).map(|int| -int) From 4dc453a3eb6a90c5233109e86f15ca3365b14ce7 Mon Sep 17 00:00:00 2001 From: zachs18 <8355914+zachs18@users.noreply.github.com> Date: Mon, 16 Sep 2024 14:46:11 -0500 Subject: [PATCH 8/9] Add doc-comment for `get_fields` Co-authored-by: Daniel Henry-Mantilla --- derive/src/traits.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 85c16e5..c1409c2 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -551,6 +551,11 @@ fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> { } } +/// Extract the `Fields` off a `DeriveInput`, or, in the `enum` case, off +/// those of the `enum_variant`, when provided (e.g., for `Zeroable`). +/// +/// We purposely allow not providing an `enum_variant` for cases where +/// the caller wants to reject supporting `enum`s (e.g., `NoPadding`). fn get_fields( input: &DeriveInput, enum_variant: Option<&Variant>, ) -> Result { From 1b020248df402fc38b9a61244ea85dcfcb882e05 Mon Sep 17 00:00:00 2001 From: zachs18 <8355914+zachs18@users.noreply.github.com> Date: Mon, 16 Sep 2024 14:48:20 -0500 Subject: [PATCH 9/9] Update derive/src/traits.rs Co-authored-by: Daniel Henry-Mantilla --- derive/src/traits.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/derive/src/traits.rs b/derive/src/traits.rs index c1409c2..ed8db9f 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -963,7 +963,8 @@ fn generate_checked_bit_pattern_enum_with_fields( fn generate_assert_no_padding(input: &DeriveInput) -> Result { let struct_type = &input.ident; let span = input.ident.span(); - let fields = get_fields(input, None)?; + let enum_variant = None; // `no padding` check is not supported for `enum`s yet. + let fields = get_fields(input, enum_variant)?; let mut field_types = get_field_types(&fields); let size_sum = if let Some(first) = field_types.next() {