Skip to content

Commit

Permalink
Fix compile fail message for mixed #[diesel(embed, serialize_as = _)]
Browse files Browse the repository at this point in the history
  • Loading branch information
weiznich committed Dec 3, 2021
1 parent 24fcf3c commit dba9212
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
error: `#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`
--> $DIR/embed_and_serialize_as_cannot_be_mixed.rs:22:36
--> tests/fail/derive/embed_and_serialize_as_cannot_be_mixed.rs:22:13
|
22 | #[diesel(embed, serialize_as = SomeType)]
| ^^^^^^^^
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
61 changes: 57 additions & 4 deletions diesel_derives/src/attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::{Display, Formatter};

use proc_macro2::{Span, TokenStream};
use proc_macro_error::ResultExt;
use quote::spanned::Spanned;
use quote::ToTokens;
use syn::parse::discouraged::Speculative;
use syn::parse::{Parse, ParseStream, Parser, Result};
Expand All @@ -13,6 +14,12 @@ use deprecated::ParseDeprecated;
use parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
use util::{parse_eq, parse_paren, unknown_attribute};

pub struct AttributeSpanWrapper<T> {
pub item: T,
pub attribute_span: Span,
pub ident_span: Span,
}

pub enum FieldAttr {
Embed(Ident),

Expand Down Expand Up @@ -96,6 +103,18 @@ impl Parse for FieldAttr {
}
}

impl Spanned for FieldAttr {
fn __span(&self) -> Span {
match self {
FieldAttr::Embed(ident)
| FieldAttr::ColumnName(ident, _)
| FieldAttr::SqlType(ident, _)
| FieldAttr::SerializeAs(ident, _)
| FieldAttr::DeserializeAs(ident, _) => ident.span(),
}
}
}

#[allow(clippy::large_enum_variant)]
pub enum StructAttr {
Aggregate(Ident),
Expand Down Expand Up @@ -146,15 +165,46 @@ impl Parse for StructAttr {
}
}

pub fn parse_attributes<T: Parse + ParseDeprecated>(attrs: &[Attribute]) -> Vec<T> {
impl Spanned for StructAttr {
fn __span(&self) -> Span {
match self {
StructAttr::Aggregate(ident)
| StructAttr::NotSized(ident)
| StructAttr::ForeignDerive(ident)
| StructAttr::TableName(ident, _)
| StructAttr::SqlType(ident, _)
| StructAttr::TreatNoneAsDefaultValue(ident, _)
| StructAttr::TreatNoneAsNull(ident, _)
| StructAttr::BelongsTo(ident, _)
| StructAttr::MysqlType(ident, _)
| StructAttr::SqliteType(ident, _)
| StructAttr::PostgresType(ident, _)
| StructAttr::PrimaryKey(ident, _) => ident.span(),
}
}
}

pub fn parse_attributes<T>(attrs: &[Attribute]) -> Vec<AttributeSpanWrapper<T>>
where
T: Parse + ParseDeprecated + Spanned,
{
use syn::spanned::Spanned;

attrs
.iter()
.flat_map(|attr| {
if attr.path.is_ident("diesel") {
attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)
.unwrap_or_abort()
.into_iter()
.map(|a| AttributeSpanWrapper {
ident_span: a.span(),
item: a,
attribute_span: attr.tokens.span(),
})
.collect::<Vec<_>>()
} else {
let mut p = Punctuated::new();
let mut p = Vec::new();
let Attribute { path, tokens, .. } = attr;
let ident = path.get_ident().map(|f| f.to_string());

Expand All @@ -166,10 +216,13 @@ pub fn parse_attributes<T: Parse + ParseDeprecated>(attrs: &[Attribute]) -> Vec<
let value = Parser::parse(T::parse_deprecated, ts).unwrap_or_abort();

if let Some(value) = value {
p.push_value(value);
p.push(AttributeSpanWrapper {
ident_span: value.span(),
item: value,
attribute_span: attr.tokens.span(),
});
}
}

p
}
})
Expand Down
86 changes: 66 additions & 20 deletions diesel_derives/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ use syn::{Field as SynField, Ident, Index, Type};

use attrs::{parse_attributes, FieldAttr, SqlIdentifier};

use crate::attrs::AttributeSpanWrapper;

pub struct Field {
pub ty: Type,
pub span: Span,
pub name: FieldName,
column_name: Option<SqlIdentifier>,
pub sql_type: Option<Type>,
pub serialize_as: Option<Type>,
pub deserialize_as: Option<Type>,
pub embed: bool,
column_name: Option<AttributeSpanWrapper<SqlIdentifier>>,
pub sql_type: Option<AttributeSpanWrapper<Type>>,
pub serialize_as: Option<AttributeSpanWrapper<Type>>,
pub deserialize_as: Option<AttributeSpanWrapper<Type>>,
pub embed: Option<AttributeSpanWrapper<bool>>,
}

impl Field {
Expand All @@ -26,15 +28,47 @@ impl Field {
let mut sql_type = None;
let mut serialize_as = None;
let mut deserialize_as = None;
let mut embed = false;
let mut embed = None;

for attr in parse_attributes(attrs) {
match attr {
FieldAttr::ColumnName(_, value) => column_name = Some(value),
FieldAttr::SqlType(_, value) => sql_type = Some(value),
FieldAttr::SerializeAs(_, value) => serialize_as = Some(value),
FieldAttr::DeserializeAs(_, value) => deserialize_as = Some(value),
FieldAttr::Embed(_) => embed = true,
let attribute_span = attr.attribute_span;
let ident_span = attr.ident_span;
match attr.item {
FieldAttr::ColumnName(_, value) => {
column_name = Some(AttributeSpanWrapper {
item: value,
attribute_span,
ident_span,
})
}
FieldAttr::SqlType(_, value) => {
sql_type = Some(AttributeSpanWrapper {
item: value,
attribute_span,
ident_span,
})
}
FieldAttr::SerializeAs(_, value) => {
serialize_as = Some(AttributeSpanWrapper {
item: value,
attribute_span,
ident_span,
})
}
FieldAttr::DeserializeAs(_, value) => {
deserialize_as = Some(AttributeSpanWrapper {
item: value,
attribute_span,
ident_span,
})
}
FieldAttr::Embed(_) => {
embed = Some(AttributeSpanWrapper {
item: true,
attribute_span,
ident_span,
})
}
}
}

Expand All @@ -43,9 +77,14 @@ impl Field {
None => FieldName::Unnamed(index.into()),
};

let span = match name {
FieldName::Named(ref ident) => ident.span(),
FieldName::Unnamed(_) => ty.span(),
};

Self {
ty: ty.clone(),
span: field.span(),
span,
name,
column_name,
sql_type,
Expand All @@ -56,24 +95,31 @@ impl Field {
}

pub fn column_name(&self) -> SqlIdentifier {
self.column_name.clone().unwrap_or_else(|| match self.name {
FieldName::Named(ref x) => x.into(),
FieldName::Unnamed(ref x) => {
abort!(
self.column_name
.as_ref()
.map(|a| a.item.clone())
.unwrap_or_else(|| match self.name {
FieldName::Named(ref x) => x.into(),
FieldName::Unnamed(ref x) => {
abort!(
x,
"All fields of tuple structs must be annotated with `#[diesel(column_name)]`"
);
}
})
}
})
}

pub fn ty_for_deserialize(&self) -> &Type {
if let Some(value) = &self.deserialize_as {
if let Some(AttributeSpanWrapper { item: value, .. }) = &self.deserialize_as {
value
} else {
&self.ty
}
}

pub(crate) fn embed(&self) -> bool {
self.embed.as_ref().map(|a| a.item).unwrap_or(false)
}
}

pub enum FieldName {
Expand Down
10 changes: 6 additions & 4 deletions diesel_derives/src/insertable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use field::Field;
use model::Model;
use util::{inner_of_option_ty, is_option_ty, wrap_in_dummy_mod};

use crate::attrs::AttributeSpanWrapper;

pub fn derive(item: DeriveInput) -> TokenStream {
let model = Model::from_item(&item, false);

Expand All @@ -25,7 +27,7 @@ pub fn derive(item: DeriveInput) -> TokenStream {
let mut ref_field_assign = Vec::with_capacity(model.fields().len());

for field in model.fields() {
match (field.serialize_as.as_ref(), field.embed) {
match (field.serialize_as.as_ref(), field.embed()) {
(None, true) => {
direct_field_ty.push(field_ty_embed(field, None));
direct_field_assign.push(field_expr_embed(field, None));
Expand Down Expand Up @@ -58,7 +60,7 @@ pub fn derive(item: DeriveInput) -> TokenStream {
treat_none_as_default_value,
));
}
(Some(ty), false) => {
(Some(AttributeSpanWrapper { item: ty, .. }), false) => {
direct_field_ty.push(field_ty_serialize_as(
field,
table_name,
Expand All @@ -74,9 +76,9 @@ pub fn derive(item: DeriveInput) -> TokenStream {

generate_borrowed_insert = false; // as soon as we hit one field with #[diesel(serialize_as)] there is no point in generating the impl of Insertable for borrowed structs
}
(Some(ty), true) => {
(Some(AttributeSpanWrapper { attribute_span, .. }), true) => {
abort!(
ty,
attribute_span,
"`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`"
)
}
Expand Down
2 changes: 1 addition & 1 deletion diesel_derives/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl Model {
let mut postgres_type = None;

for attr in parse_attributes(attrs) {
match attr {
match attr.item {
StructAttr::SqlType(_, value) => sql_types.push(value),
StructAttr::TableName(_, value) => table_name = Some(value),
StructAttr::PrimaryKey(_, keys) => {
Expand Down
8 changes: 5 additions & 3 deletions diesel_derives/src/queryable_by_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use field::{Field, FieldName};
use model::Model;
use util::wrap_in_dummy_mod;

use crate::attrs::AttributeSpanWrapper;

pub fn derive(item: DeriveInput) -> TokenStream {
let model = Model::from_item(&item, false);

Expand All @@ -15,7 +17,7 @@ pub fn derive(item: DeriveInput) -> TokenStream {
let initial_field_expr = model.fields().iter().map(|f| {
let field_ty = &f.ty;

if f.embed {
if f.embed() {
quote!(<#field_ty as QueryableByName<__DB>>::build(row)?)
} else {
let deserialize_ty = f.ty_for_deserialize();
Expand All @@ -39,7 +41,7 @@ pub fn derive(item: DeriveInput) -> TokenStream {
for field in model.fields() {
let where_clause = generics.where_clause.get_or_insert(parse_quote!(where));
let field_ty = field.ty_for_deserialize();
if field.embed {
if field.embed() {
where_clause
.predicates
.push(parse_quote!(#field_ty: QueryableByName<__DB>));
Expand Down Expand Up @@ -88,7 +90,7 @@ fn sql_type(field: &Field, model: &Model) -> Type {
let table_name = model.table_name();

match field.sql_type {
Some(ref st) => st.clone(),
Some(AttributeSpanWrapper { item: ref st, .. }) => st.clone(),
None => {
if model.has_table_name_attribute() {
let column_name = field.column_name();
Expand Down
6 changes: 3 additions & 3 deletions diesel_derives/src/selectable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub fn derive(item: DeriveInput) -> TokenStream {
.params
.push(parse_quote!(__DB: diesel::backend::Backend));

for embed_field in model.fields().iter().filter(|f| f.embed) {
for embed_field in model.fields().iter().filter(|f| f.embed()) {
let embed_ty = &embed_field.ty;
generics
.where_clause
Expand Down Expand Up @@ -48,7 +48,7 @@ pub fn derive(item: DeriveInput) -> TokenStream {
}

fn field_column_ty(field: &Field, model: &Model) -> TokenStream {
if field.embed {
if field.embed() {
let embed_ty = &field.ty;
quote!(<#embed_ty as Selectable<__DB>>::SelectExpression)
} else {
Expand All @@ -59,7 +59,7 @@ fn field_column_ty(field: &Field, model: &Model) -> TokenStream {
}

fn field_column_inst(field: &Field, model: &Model) -> TokenStream {
if field.embed {
if field.embed() {
let embed_ty = &field.ty;
quote!(<#embed_ty as Selectable<__DB>>::construct_selection())
} else {
Expand Down

0 comments on commit dba9212

Please sign in to comment.