Skip to content

Commit

Permalink
Cleanup custom bounds code.
Browse files Browse the repository at this point in the history
  • Loading branch information
zachs18 committed Jun 12, 2023
1 parent 044c7fe commit 60f5524
Showing 1 changed file with 37 additions and 44 deletions.
81 changes: 37 additions & 44 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,82 +366,75 @@ fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
.unwrap_or_else(|err| err.into_compile_error())
}

/// Find a `#[name(key = "value")]` attribute on the struct, and parse "value"
/// with `parser` and return it.
/// Find `#[name(key = "value")]` helper attributes on the struct, and return
/// their `"value"`s parsed with `parser`.
///
/// Returns an error if multiple attributes with `name` are found, or if the one
/// found does not match the expected format. Returns `Ok(None)` if no attribute
/// with `name` is found.
fn find_helper_attribute<P: syn::parse::Parser + Copy>(
/// Returns an error any matching attributes do not match the expected format.
/// Returns `Ok([])` if no attributes with `name` are found.
fn find_helper_attributes<P: syn::parse::Parser + Copy>(
attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
example_value: &str,
) -> Result<Vec<(syn::LitStr, P::Output)>> {
example_value: &str, invalid_value_msg: &str,
) -> Result<Vec<P::Output>> {
let invalid_format_msg =
format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
syn::Meta::Path(path) => path.is_ident(name).then(|| {
Err(syn::Error::new_spanned(
&path,
format!(
"{name} attribute must be `{name}({key} = \"{example_value}\")`",
),
))
}),
// If a `Path` matches our `name`, return an error, else ignore it.
// e.g. `#[zeroable]`
syn::Meta::Path(path) => path
.is_ident(name)
.then(|| Err(syn::Error::new_spanned(&path, &invalid_format_msg))),
// If a `NameValue` matches our `name`, return an error, else ignore it.
// e.g. `#[zeroable = "hello"]`
syn::Meta::NameValue(namevalue) => {
namevalue.path.is_ident(name).then(|| {
Err(syn::Error::new_spanned(
&namevalue.path,
format!(
"{name} attribute must be `{name}({key} = \"{example_value}\")`",
),
))
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
})
}
// If a `List` matches our `name`, match its contents to our format, else
// ignore it. If its contents match our format, return the value, else
// return an error.
syn::Meta::List(list) => list.path.is_ident(name).then(|| {
let namevalue: MetaNameValue =
syn::parse2(list.tokens.clone()).map_err(|_| {
syn::Error::new_spanned(
&list.tokens,
format!(
"{name} attribute must be `{name}({key} = \"{example_value}\")`",
),
)
syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
})?;
if namevalue.path.is_ident("bound") {
if namevalue.path.is_ident(key) {
match namevalue.value {
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(strlit), ..
}) => Ok(strlit),
_ => Err(syn::Error::new_spanned(
&namevalue.path,
format!(
"{name} attribute must be `{name}({key} = \"{example_value}\")`",
),
)),
_ => {
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
}
}
} else {
Err(syn::Error::new_spanned(
&namevalue.path,
format!(
"{name} attribute must be `{name}({key} = \"{example_value}\")`",
),
))
Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
}
}),
});
// Parse each value found with the given parser, and return them if no errors
// occur.
values_to_check
.map(|r| r.and_then(|lit| Ok((lit.clone(), lit.parse_with(parser)?))))
.map(|lit| {
let lit = lit?;
lit.parse_with(parser).map_err(|err| {
syn::Error::new_spanned(&lit, &format!("{invalid_value_msg}: {err}"))
})
})
.collect()
}

fn derive_marker_trait_inner<Trait: Derivable>(
mut input: DeriveInput,
) -> Result<TokenStream> {
let trait_ = Trait::ident(&input)?;
let explicit_bounds = find_helper_attribute(
let explicit_bounds = find_helper_attributes(
&input.attrs,
"zeroable",
"bound",
<Punctuated<WherePredicate, syn::Token![,]>>::parse_terminated,
"Type: Trait",
"invalid where predicate",
)?;
if explicit_bounds.is_empty() {
// Enforce bound on all generic fields.
Expand All @@ -451,7 +444,7 @@ fn derive_marker_trait_inner<Trait: Derivable>(
// soundness)
let explicit_bounds = explicit_bounds
.into_iter()
.flat_map(|a| (a.1))
.flatten()
.collect::<Vec<syn::WherePredicate>>();

input.generics.make_where_clause().predicates.extend(explicit_bounds);
Expand Down

0 comments on commit 60f5524

Please sign in to comment.