From dba9212ffcec0ed3ea251b98b1f055e91997bbd1 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 3 Dec 2021 17:32:01 +0100 Subject: [PATCH] Fix compile fail message for mixed `#[diesel(embed, serialize_as = _)]` --- ...ed_and_serialize_as_cannot_be_mixed.stderr | 4 +- diesel_derives/src/attrs.rs | 61 ++++++++++++- diesel_derives/src/field.rs | 86 ++++++++++++++----- diesel_derives/src/insertable.rs | 10 ++- diesel_derives/src/model.rs | 2 +- diesel_derives/src/queryable_by_name.rs | 8 +- diesel_derives/src/selectable.rs | 6 +- 7 files changed, 140 insertions(+), 37 deletions(-) diff --git a/diesel_compile_tests/tests/fail/derive/embed_and_serialize_as_cannot_be_mixed.stderr b/diesel_compile_tests/tests/fail/derive/embed_and_serialize_as_cannot_be_mixed.stderr index 76a7e616691f..f0b777fdb085 100644 --- a/diesel_compile_tests/tests/fail/derive/embed_and_serialize_as_cannot_be_mixed.stderr +++ b/diesel_compile_tests/tests/fail/derive/embed_and_serialize_as_cannot_be_mixed.stderr @@ -1,5 +1,5 @@ error: `#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]` - --> $DIR/embed_and_serialize_as_cannot_be_mixed.rs:22:36 + --> tests/fail/derive/embed_and_serialize_as_cannot_be_mixed.rs:22:13 | 22 | #[diesel(embed, serialize_as = SomeType)] - | ^^^^^^^^ + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/diesel_derives/src/attrs.rs b/diesel_derives/src/attrs.rs index ff298aaadf5e..dc27b0f52a29 100644 --- a/diesel_derives/src/attrs.rs +++ b/diesel_derives/src/attrs.rs @@ -2,6 +2,7 @@ use std::fmt::{Display, Formatter}; use proc_macro2::{Span, TokenStream}; use proc_macro_error::ResultExt; +use quote::spanned::Spanned; use quote::ToTokens; use syn::parse::discouraged::Speculative; use syn::parse::{Parse, ParseStream, Parser, Result}; @@ -13,6 +14,12 @@ use deprecated::ParseDeprecated; use parsers::{BelongsTo, MysqlType, PostgresType, SqliteType}; use util::{parse_eq, parse_paren, unknown_attribute}; +pub struct AttributeSpanWrapper { + pub item: T, + pub attribute_span: Span, + pub ident_span: Span, +} + pub enum FieldAttr { Embed(Ident), @@ -96,6 +103,18 @@ impl Parse for FieldAttr { } } +impl Spanned for FieldAttr { + fn __span(&self) -> Span { + match self { + FieldAttr::Embed(ident) + | FieldAttr::ColumnName(ident, _) + | FieldAttr::SqlType(ident, _) + | FieldAttr::SerializeAs(ident, _) + | FieldAttr::DeserializeAs(ident, _) => ident.span(), + } + } +} + #[allow(clippy::large_enum_variant)] pub enum StructAttr { Aggregate(Ident), @@ -146,15 +165,46 @@ impl Parse for StructAttr { } } -pub fn parse_attributes(attrs: &[Attribute]) -> Vec { +impl Spanned for StructAttr { + fn __span(&self) -> Span { + match self { + StructAttr::Aggregate(ident) + | StructAttr::NotSized(ident) + | StructAttr::ForeignDerive(ident) + | StructAttr::TableName(ident, _) + | StructAttr::SqlType(ident, _) + | StructAttr::TreatNoneAsDefaultValue(ident, _) + | StructAttr::TreatNoneAsNull(ident, _) + | StructAttr::BelongsTo(ident, _) + | StructAttr::MysqlType(ident, _) + | StructAttr::SqliteType(ident, _) + | StructAttr::PostgresType(ident, _) + | StructAttr::PrimaryKey(ident, _) => ident.span(), + } + } +} + +pub fn parse_attributes(attrs: &[Attribute]) -> Vec> +where + T: Parse + ParseDeprecated + Spanned, +{ + use syn::spanned::Spanned; + attrs .iter() .flat_map(|attr| { if attr.path.is_ident("diesel") { attr.parse_args_with(Punctuated::::parse_terminated) .unwrap_or_abort() + .into_iter() + .map(|a| AttributeSpanWrapper { + ident_span: a.span(), + item: a, + attribute_span: attr.tokens.span(), + }) + .collect::>() } else { - let mut p = Punctuated::new(); + let mut p = Vec::new(); let Attribute { path, tokens, .. } = attr; let ident = path.get_ident().map(|f| f.to_string()); @@ -166,10 +216,13 @@ pub fn parse_attributes(attrs: &[Attribute]) -> Vec< let value = Parser::parse(T::parse_deprecated, ts).unwrap_or_abort(); if let Some(value) = value { - p.push_value(value); + p.push(AttributeSpanWrapper { + ident_span: value.span(), + item: value, + attribute_span: attr.tokens.span(), + }); } } - p } }) diff --git a/diesel_derives/src/field.rs b/diesel_derives/src/field.rs index a6fc9723da59..6c387e0752c4 100644 --- a/diesel_derives/src/field.rs +++ b/diesel_derives/src/field.rs @@ -5,15 +5,17 @@ use syn::{Field as SynField, Ident, Index, Type}; use attrs::{parse_attributes, FieldAttr, SqlIdentifier}; +use crate::attrs::AttributeSpanWrapper; + pub struct Field { pub ty: Type, pub span: Span, pub name: FieldName, - column_name: Option, - pub sql_type: Option, - pub serialize_as: Option, - pub deserialize_as: Option, - pub embed: bool, + column_name: Option>, + pub sql_type: Option>, + pub serialize_as: Option>, + pub deserialize_as: Option>, + pub embed: Option>, } impl Field { @@ -26,15 +28,47 @@ impl Field { let mut sql_type = None; let mut serialize_as = None; let mut deserialize_as = None; - let mut embed = false; + let mut embed = None; for attr in parse_attributes(attrs) { - match attr { - FieldAttr::ColumnName(_, value) => column_name = Some(value), - FieldAttr::SqlType(_, value) => sql_type = Some(value), - FieldAttr::SerializeAs(_, value) => serialize_as = Some(value), - FieldAttr::DeserializeAs(_, value) => deserialize_as = Some(value), - FieldAttr::Embed(_) => embed = true, + let attribute_span = attr.attribute_span; + let ident_span = attr.ident_span; + match attr.item { + FieldAttr::ColumnName(_, value) => { + column_name = Some(AttributeSpanWrapper { + item: value, + attribute_span, + ident_span, + }) + } + FieldAttr::SqlType(_, value) => { + sql_type = Some(AttributeSpanWrapper { + item: value, + attribute_span, + ident_span, + }) + } + FieldAttr::SerializeAs(_, value) => { + serialize_as = Some(AttributeSpanWrapper { + item: value, + attribute_span, + ident_span, + }) + } + FieldAttr::DeserializeAs(_, value) => { + deserialize_as = Some(AttributeSpanWrapper { + item: value, + attribute_span, + ident_span, + }) + } + FieldAttr::Embed(_) => { + embed = Some(AttributeSpanWrapper { + item: true, + attribute_span, + ident_span, + }) + } } } @@ -43,9 +77,14 @@ impl Field { None => FieldName::Unnamed(index.into()), }; + let span = match name { + FieldName::Named(ref ident) => ident.span(), + FieldName::Unnamed(_) => ty.span(), + }; + Self { ty: ty.clone(), - span: field.span(), + span, name, column_name, sql_type, @@ -56,24 +95,31 @@ impl Field { } pub fn column_name(&self) -> SqlIdentifier { - self.column_name.clone().unwrap_or_else(|| match self.name { - FieldName::Named(ref x) => x.into(), - FieldName::Unnamed(ref x) => { - abort!( + self.column_name + .as_ref() + .map(|a| a.item.clone()) + .unwrap_or_else(|| match self.name { + FieldName::Named(ref x) => x.into(), + FieldName::Unnamed(ref x) => { + abort!( x, "All fields of tuple structs must be annotated with `#[diesel(column_name)]`" ); - } - }) + } + }) } pub fn ty_for_deserialize(&self) -> &Type { - if let Some(value) = &self.deserialize_as { + if let Some(AttributeSpanWrapper { item: value, .. }) = &self.deserialize_as { value } else { &self.ty } } + + pub(crate) fn embed(&self) -> bool { + self.embed.as_ref().map(|a| a.item).unwrap_or(false) + } } pub enum FieldName { diff --git a/diesel_derives/src/insertable.rs b/diesel_derives/src/insertable.rs index 0c51b83776df..69612ac8850a 100644 --- a/diesel_derives/src/insertable.rs +++ b/diesel_derives/src/insertable.rs @@ -5,6 +5,8 @@ use field::Field; use model::Model; use util::{inner_of_option_ty, is_option_ty, wrap_in_dummy_mod}; +use crate::attrs::AttributeSpanWrapper; + pub fn derive(item: DeriveInput) -> TokenStream { let model = Model::from_item(&item, false); @@ -25,7 +27,7 @@ pub fn derive(item: DeriveInput) -> TokenStream { let mut ref_field_assign = Vec::with_capacity(model.fields().len()); for field in model.fields() { - match (field.serialize_as.as_ref(), field.embed) { + match (field.serialize_as.as_ref(), field.embed()) { (None, true) => { direct_field_ty.push(field_ty_embed(field, None)); direct_field_assign.push(field_expr_embed(field, None)); @@ -58,7 +60,7 @@ pub fn derive(item: DeriveInput) -> TokenStream { treat_none_as_default_value, )); } - (Some(ty), false) => { + (Some(AttributeSpanWrapper { item: ty, .. }), false) => { direct_field_ty.push(field_ty_serialize_as( field, table_name, @@ -74,9 +76,9 @@ pub fn derive(item: DeriveInput) -> TokenStream { generate_borrowed_insert = false; // as soon as we hit one field with #[diesel(serialize_as)] there is no point in generating the impl of Insertable for borrowed structs } - (Some(ty), true) => { + (Some(AttributeSpanWrapper { attribute_span, .. }), true) => { abort!( - ty, + attribute_span, "`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`" ) } diff --git a/diesel_derives/src/model.rs b/diesel_derives/src/model.rs index 498f25c173ab..39df87e1c41b 100644 --- a/diesel_derives/src/model.rs +++ b/diesel_derives/src/model.rs @@ -63,7 +63,7 @@ impl Model { let mut postgres_type = None; for attr in parse_attributes(attrs) { - match attr { + match attr.item { StructAttr::SqlType(_, value) => sql_types.push(value), StructAttr::TableName(_, value) => table_name = Some(value), StructAttr::PrimaryKey(_, keys) => { diff --git a/diesel_derives/src/queryable_by_name.rs b/diesel_derives/src/queryable_by_name.rs index e90e94377c1b..98fb0a8f6f12 100644 --- a/diesel_derives/src/queryable_by_name.rs +++ b/diesel_derives/src/queryable_by_name.rs @@ -5,6 +5,8 @@ use field::{Field, FieldName}; use model::Model; use util::wrap_in_dummy_mod; +use crate::attrs::AttributeSpanWrapper; + pub fn derive(item: DeriveInput) -> TokenStream { let model = Model::from_item(&item, false); @@ -15,7 +17,7 @@ pub fn derive(item: DeriveInput) -> TokenStream { let initial_field_expr = model.fields().iter().map(|f| { let field_ty = &f.ty; - if f.embed { + if f.embed() { quote!(<#field_ty as QueryableByName<__DB>>::build(row)?) } else { let deserialize_ty = f.ty_for_deserialize(); @@ -39,7 +41,7 @@ pub fn derive(item: DeriveInput) -> TokenStream { for field in model.fields() { let where_clause = generics.where_clause.get_or_insert(parse_quote!(where)); let field_ty = field.ty_for_deserialize(); - if field.embed { + if field.embed() { where_clause .predicates .push(parse_quote!(#field_ty: QueryableByName<__DB>)); @@ -88,7 +90,7 @@ fn sql_type(field: &Field, model: &Model) -> Type { let table_name = model.table_name(); match field.sql_type { - Some(ref st) => st.clone(), + Some(AttributeSpanWrapper { item: ref st, .. }) => st.clone(), None => { if model.has_table_name_attribute() { let column_name = field.column_name(); diff --git a/diesel_derives/src/selectable.rs b/diesel_derives/src/selectable.rs index 549e89f99aa8..48aebb09e184 100644 --- a/diesel_derives/src/selectable.rs +++ b/diesel_derives/src/selectable.rs @@ -15,7 +15,7 @@ pub fn derive(item: DeriveInput) -> TokenStream { .params .push(parse_quote!(__DB: diesel::backend::Backend)); - for embed_field in model.fields().iter().filter(|f| f.embed) { + for embed_field in model.fields().iter().filter(|f| f.embed()) { let embed_ty = &embed_field.ty; generics .where_clause @@ -48,7 +48,7 @@ pub fn derive(item: DeriveInput) -> TokenStream { } fn field_column_ty(field: &Field, model: &Model) -> TokenStream { - if field.embed { + if field.embed() { let embed_ty = &field.ty; quote!(<#embed_ty as Selectable<__DB>>::SelectExpression) } else { @@ -59,7 +59,7 @@ fn field_column_ty(field: &Field, model: &Model) -> TokenStream { } fn field_column_inst(field: &Field, model: &Model) -> TokenStream { - if field.embed { + if field.embed() { let embed_ty = &field.ty; quote!(<#embed_ty as Selectable<__DB>>::construct_selection()) } else {