Skip to content

Commit

Permalink
add some simple derive tests :)
Browse files Browse the repository at this point in the history
  • Loading branch information
mumbleskates committed May 17, 2024
1 parent dbd5111 commit 4c00d4e
Showing 1 changed file with 54 additions and 17 deletions.
71 changes: 54 additions & 17 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ extern crate proc_macro;

use anyhow::{bail, Error};
use itertools::Itertools;
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
Expand All @@ -19,7 +18,7 @@ mod field;
use crate::field::Field;

fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -258,12 +257,12 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
}

#[proc_macro_derive(Message, attributes(prost))]
pub fn message(input: TokenStream) -> TokenStream {
try_message(input).unwrap()
pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_message(input.into()).unwrap().into()
}

fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;
let ident = input.ident;

let generics = &input.generics;
Expand Down Expand Up @@ -366,12 +365,12 @@ fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
}

#[proc_macro_derive(Enumeration, attributes(prost))]
pub fn enumeration(input: TokenStream) -> TokenStream {
try_enumeration(input).unwrap()
pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_enumeration(input.into()).unwrap().into()
}

fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -415,12 +414,8 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}
}

if let Some((invalid_variant, _)) = fields.iter().find(|(_, field)| field.tags().len() > 1) {
bail!(
"invalid oneof variant {}::{}: oneof variants may only have a single tag",
ident,
invalid_variant
);
if fields.iter().any(|(_, field)| field.tags().len() > 1) {
panic!("variant with multiple tags"); // Not clear if this is possible, but good to be safe
}
if let Some((duplicate_tag, _)) = fields
.iter()
Expand Down Expand Up @@ -528,6 +523,48 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}

#[proc_macro_derive(Oneof, attributes(prost))]
pub fn oneof(input: TokenStream) -> TokenStream {
try_oneof(input).unwrap()
pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_oneof(input.into()).unwrap().into()
}

#[cfg(test)]
mod test {
use crate::{try_message, try_oneof};
use quote::quote;

#[test]
fn test_rejects_colliding_message_fields() {
let output = try_message(quote!(
struct Invalid {
#[prost(bool, tag = "1")]
a: bool,
#[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
b: Option<super::Whatever>,
}
));
assert!(output.is_err());
assert_eq!(
output.unwrap_err().to_string(),
"message Invalid has multiple fields with tag 1"
);
}

#[test]
fn test_rejects_colliding_oneof_variants() {
let output = try_oneof(quote!(
pub enum Invalid {
#[prost(bool, tag = "1")]
A(bool),
#[prost(bool, tag = "3")]
B(bool),
#[prost(bool, tag = "1")]
C(bool),
}
));
assert!(output.is_err());
assert_eq!(
output.unwrap_err().to_string(),
"invalid oneof Invalid: multiple variants have tag 1"
);
}
}

0 comments on commit 4c00d4e

Please sign in to comment.