Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add derive Decode for transparent enums #14

Merged
merged 2 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ssz/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub enum DecodeError {
BytesInvalid(String),
/// The given union selector is out of bounds.
UnionSelectorInvalid(u8),
/// The given bytes could not be successfully decoded into any variant of the transparent enum.
NoMatchingVariant,
}

/// Performs checks on the `offset` based upon the other parameters provided.
Expand Down
81 changes: 69 additions & 12 deletions ssz_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
//! - `#[ssz(enum_behaviour = "tag")]`: encodes and decodes an `enum` with 0 fields per variant
//! - `#[ssz(enum_behaviour = "union")]`: encodes and decodes an `enum` with a one-byte variant selector.
//! - `#[ssz(enum_behaviour = "transparent")]`: allows encoding an `enum` by serializing only the
//! value whilst ignoring outermost the `enum`.
//! value whilst ignoring outermost the `enum`. decodes by attempting to decode each variant
//! in order and the first one that is successful is returned.
//! - `#[ssz(struct_behaviour = "container")]`: encodes and decodes the `struct` as an SSZ
//! "container".
//! - `#[ssz(struct_behaviour = "transparent")]`: encodes and decodes a `struct` with exactly one
Expand Down Expand Up @@ -127,7 +128,7 @@
//! );
//!
//! /// Represented as only the value in the enum variant.
//! #[derive(Debug, PartialEq, Encode)]
//! #[derive(Debug, PartialEq, Encode, Decode)]
//! #[ssz(enum_behaviour = "transparent")]
//! enum TransparentEnum {
//! Foo(u8),
Expand All @@ -142,6 +143,10 @@
//! TransparentEnum::Bar(vec![42, 42]).as_ssz_bytes(),
//! vec![42, 42]
//! );
//! assert_eq!(
//! TransparentEnum::from_ssz_bytes(&[42, 42]).unwrap(),
//! TransparentEnum::Bar(vec![42, 42]),
//! );
//!
//! /// Representated as an SSZ "uint8"
//! #[derive(Debug, PartialEq, Encode, Decode)]
Expand Down Expand Up @@ -170,9 +175,6 @@ use syn::{parse_macro_input, DataEnum, DataStruct, DeriveInput, Ident, Index};
/// extensions).
const MAX_UNION_SELECTOR: u8 = 127;

const ENUM_TRANSPARENT: &str = "transparent";
const ENUM_UNION: &str = "union";
const ENUM_TAG: &str = "tag";
const NO_ENUM_BEHAVIOUR_ERROR: &str = "enums require an \"enum_behaviour\" attribute with \
a \"transparent\", \"union\", or \"tag\" value, e.g., #[ssz(enum_behaviour = \"transparent\")]";

Expand Down Expand Up @@ -524,8 +526,7 @@ fn ssz_encode_derive_struct_transparent(
///
/// The "transparent" method is distinct from the "union" method specified in the SSZ specification.
/// When using "transparent", the enum will be ignored and the contained field will be serialized as
/// if the enum does not exist. Since an union variant "selector" is not serialized, it is not
/// possible to reliably decode an enum that is serialized transparently.
/// if the enum does not exist.
///
/// ## Limitations
///
Expand Down Expand Up @@ -739,10 +740,7 @@ pub fn ssz_decode_derive(input: TokenStream) -> TokenStream {
Procedure::Enum { data, behaviour } => match behaviour {
EnumBehaviour::Union => ssz_decode_derive_enum_union(&item, data),
EnumBehaviour::Tag => ssz_decode_derive_enum_tag(&item, data),
EnumBehaviour::Transparent => panic!(
"Decode cannot be derived for enum_behaviour \"{}\", only \"{}\" and \"{}\" is valid.",
ENUM_TRANSPARENT, ENUM_UNION, ENUM_TAG,
),
EnumBehaviour::Transparent => ssz_decode_derive_enum_transparent(&item, data),
},
}
}
Expand Down Expand Up @@ -1060,7 +1058,7 @@ fn ssz_decode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum
let variant_name = &variant.ident;

if variant.fields.len() != 1 {
panic!("ssz::Encode can only be derived for enums with 1 field per variant");
panic!("ssz::Decode can only be derived for enums with 1 field per variant");
}

let constructor = quote! {
Expand Down Expand Up @@ -1101,6 +1099,65 @@ fn ssz_decode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum
output.into()
}

/// Derive `ssz::Decode` for an enum in the "transparent" method.
///
/// The "transparent" method attempts to decode into an enum by trying to decode each variant in
/// order until one is successful. If no variant decodes successfully,
/// `ssz::DecodeError::NoMatchingVariant` is returned.
///
/// ## Limitations
///
/// Only supports enums with a single field per variant.
///
/// ## Considerations
///
/// The ordering of the enum variants matters. For example, a variant containing a single
/// "variable-size" type may always result in a successful decoding (e.g. MyEnum::Foo(Vec<u8>)).
fn ssz_decode_derive_enum_transparent(
derive_input: &DeriveInput,
enum_data: &DataEnum,
) -> TokenStream {
let name = &derive_input.ident;
let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();

let (constructors, var_types): (Vec<_>, Vec<_>) = enum_data
.variants
.iter()
.map(|variant| {
let variant_name = &variant.ident;

if variant.fields.len() != 1 {
panic!("ssz::Decode can only be derived for enums with 1 field per variant");
}

let constructor = quote! {
#name::#variant_name
};

let ty = &(&variant.fields).into_iter().next().unwrap().ty;
(constructor, ty)
})
.unzip();

let output = quote! {
impl #impl_generics ssz::Decode for #name #ty_generics #where_clause {
fn is_ssz_fixed_len() -> bool {
false
}

fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
#(
if let Ok(var) = <#var_types as ssz::Decode>::from_ssz_bytes(bytes) {
return Ok(#constructors(var));
}
)*
Err(ssz::DecodeError::NoMatchingVariant)
}
}
};
output.into()
}

fn compute_union_selectors(num_variants: usize) -> Vec<u8> {
let union_selectors = (0..num_variants)
.map(|i| {
Expand Down
22 changes: 15 additions & 7 deletions ssz_derive/tests/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ssz::{Decode, Encode};
use ssz::{Decode, DecodeError, Encode};
use ssz_derive::{Decode, Encode};
use std::fmt::Debug;
use std::marker::PhantomData;
Expand Down Expand Up @@ -56,14 +56,14 @@ struct VariableB {
b: u8,
}

#[derive(PartialEq, Debug, Encode)]
#[derive(PartialEq, Debug, Encode, Decode)]
#[ssz(enum_behaviour = "transparent")]
enum TwoVariableTrans {
A(VariableA),
B(VariableB),
}

#[derive(PartialEq, Debug, Encode)]
#[derive(PartialEq, Debug, Encode, Decode)]
struct TwoVariableTransStruct {
a: TwoVariableTrans,
}
Expand Down Expand Up @@ -91,19 +91,27 @@ fn two_variable_trans() {
b: 3,
});

assert_encode(&trans_a, &[1, 5, 0, 0, 0, 2, 3]);
assert_encode(&trans_b, &[5, 0, 0, 0, 3, 1, 2]);
assert_encode_decode(&trans_a, &[1, 5, 0, 0, 0, 2, 3]);
assert_encode_decode(&trans_b, &[5, 0, 0, 0, 3, 1, 2]);

assert_encode(
assert_encode_decode(
&TwoVariableTransStruct { a: trans_a },
&[4, 0, 0, 0, 1, 5, 0, 0, 0, 2, 3],
);
assert_encode(
assert_encode_decode(
&TwoVariableTransStruct { a: trans_b },
&[4, 0, 0, 0, 5, 0, 0, 0, 3, 1, 2],
);
}

#[test]
fn trans_enum_error() {
assert_eq!(
TwoVariableTrans::from_ssz_bytes(&[1, 3, 0, 0, 0]).unwrap_err(),
DecodeError::NoMatchingVariant,
);
}

#[test]
fn two_variable_union() {
let union_a = TwoVariableUnion::A(VariableA {
Expand Down