diff --git a/diesel_compile_tests/tests/compile-fail/queryable_by_name_requires_table_name_or_sql_type_annotation.rs b/diesel_compile_tests/tests/compile-fail/queryable_by_name_requires_table_name_or_sql_type_annotation.rs deleted file mode 100644 index d1187c3a3bfd..000000000000 --- a/diesel_compile_tests/tests/compile-fail/queryable_by_name_requires_table_name_or_sql_type_annotation.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[macro_use] extern crate diesel; - -#[derive(QueryableByName)] -//~^ ERROR Your struct must either be annotated with `#[table_name = "foo"]` or have all of its fields annotated with `#[sql_type = "Integer"]` -struct Foo { - a: i32, -} - -fn main() {} diff --git a/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.rs b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.rs new file mode 100644 index 000000000000..935000545f30 --- /dev/null +++ b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.rs @@ -0,0 +1,12 @@ +#[macro_use] extern crate diesel; + +#[derive(QueryableByName)] +struct Foo { + foo: i32, + bar: String, +} + +#[derive(QueryableByName)] +struct Bar(i32, String); + +fn main() {} diff --git a/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr new file mode 100644 index 000000000000..9197218fdd09 --- /dev/null +++ b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr @@ -0,0 +1,46 @@ +error: Cannot determine the SQL type of foo + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:5:5 + | +5 | foo: i32, + | ^^^ + | + = help: Your struct must either be annotated with `#[table_name = "foo"]` or have all of its fields annotated with `#[sql_type = "Integer"]` + +error: Cannot determine the SQL type of bar + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:6:5 + | +6 | bar: String, + | ^^^ + | + = help: Your struct must either be annotated with `#[table_name = "foo"]` or have all of its fields annotated with `#[sql_type = "Integer"]` + +error: All fields of tuple structs must be annotated with `#[column_name]` + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:12 + | +10 | struct Bar(i32, String); + | ^^^ + +error: Cannot determine the SQL type of field + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:12 + | +10 | struct Bar(i32, String); + | ^^^ + | + = help: Your struct must either be annotated with `#[table_name = "foo"]` or have all of its fields annotated with `#[sql_type = "Integer"]` + +error: All fields of tuple structs must be annotated with `#[column_name]` + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 + | +10 | struct Bar(i32, String); + | ^^^^^^ + +error: Cannot determine the SQL type of field + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 + | +10 | struct Bar(i32, String); + | ^^^^^^ + | + = help: Your struct must either be annotated with `#[table_name = "foo"]` or have all of its fields annotated with `#[sql_type = "Integer"]` + +error: aborting due to 6 previous errors + diff --git a/diesel_derives/src/attr.rs b/diesel_derives/src/attr.rs index de3cc43e9efa..5725c83dacfb 100644 --- a/diesel_derives/src/attr.rs +++ b/diesel_derives/src/attr.rs @@ -49,18 +49,6 @@ impl Attr { ) } - pub fn sql_type(&self) -> Option<&syn::Ty> { - self.sql_type.as_ref() - } - - pub fn has_flag(&self, flag: &T) -> bool - where - T: ?Sized, - syn::Ident: PartialEq, - { - self.flags.iter().any(|f| f == flag) - } - fn field_kind(&self) -> &str { if is_option_ty(&self.ty) { "option" diff --git a/diesel_derives/src/lib.rs b/diesel_derives/src/lib.rs index f149bdb28528..91c7a3b7aadf 100644 --- a/diesel_derives/src/lib.rs +++ b/diesel_derives/src/lib.rs @@ -33,7 +33,6 @@ mod insertable; mod model; mod query_id; mod queryable; -mod queryable_by_name; mod sql_type; mod util; @@ -45,11 +44,6 @@ pub fn derive_queryable(input: TokenStream) -> TokenStream { expand_derive(input, queryable::derive_queryable) } -#[proc_macro_derive(QueryableByName, attributes(table_name, column_name, sql_type, diesel))] -pub fn derive_queryable_by_name(input: TokenStream) -> TokenStream { - expand_derive(input, queryable_by_name::derive) -} - #[proc_macro_derive(Insertable, attributes(table_name, column_name))] pub fn derive_insertable(input: TokenStream) -> TokenStream { expand_derive(input, insertable::derive_insertable) diff --git a/diesel_derives/src/queryable_by_name.rs b/diesel_derives/src/queryable_by_name.rs deleted file mode 100644 index 6f757aaaabd8..000000000000 --- a/diesel_derives/src/queryable_by_name.rs +++ /dev/null @@ -1,76 +0,0 @@ -use quote::Tokens; -use syn; - -use attr::Attr; -use model::Model; -use util::wrap_item_in_const; - -pub fn derive(item: syn::DeriveInput) -> Tokens { - let model = t!(Model::from_item(&item, "QueryableByName")); - - let generics = syn::aster::from_generics(model.generics.clone()) - .ty_param_id("__DB") - .build(); - let struct_ty = &model.ty; - - let attr_where_clause = model.attrs.iter().map(|attr| { - let attr_ty = &attr.ty; - if attr.has_flag("embed") { - quote!(#attr_ty: diesel::deserialize::QueryableByName<__DB>,) - } else { - let st = sql_type(attr, &model); - quote!( - #attr_ty: diesel::deserialize::FromSql<#st, __DB>, - ) - } - }); - - let build_expr = build_expr_for_model(&model); - - wrap_item_in_const( - model.dummy_const_name("QUERYABLE_BY_NAME"), - quote!( - impl#generics diesel::deserialize::QueryableByName<__DB> for #struct_ty where - __DB: diesel::backend::Backend, - #(#attr_where_clause)* - { - fn build<__R: diesel::row::NamedRow<__DB>>(row: &__R) -> diesel::deserialize::Result { - Ok(#build_expr) - } - } - ), - ) -} - -fn build_expr_for_model(model: &Model) -> Tokens { - let attr_exprs = model.attrs.iter().map(|attr| { - let name = attr.field_name(); - if attr.has_flag("embed") { - quote!(#name: diesel::deserialize::QueryableByName::build(row)?) - } else { - let column_name = attr.column_name(); - let st = sql_type(attr, model); - quote!(#name: diesel::row::NamedRow::get::<#st, _>(row, stringify!(#column_name))?) - } - }); - - quote!(Self { - #(#attr_exprs,)* - }) -} - -fn sql_type(attr: &Attr, model: &Model) -> Tokens { - let table_name = model.table_name(); - let column_name = attr.column_name(); - - match attr.sql_type() { - Some(st) => quote!(#st), - None => if model.has_table_name_annotation() { - quote!(diesel::dsl::SqlTypeOf<#table_name::#column_name>) - } else { - quote!(compile_error!( - "Your struct must either be annotated with `#[table_name = \"foo\"]` or have all of its fields annotated with `#[sql_type = \"Integer\"]`" - )) - }, - } -} diff --git a/diesel_derives/tests/tests.rs b/diesel_derives/tests/tests.rs index 5d8b040b5e47..1800a7cfd425 100644 --- a/diesel_derives/tests/tests.rs +++ b/diesel_derives/tests/tests.rs @@ -6,6 +6,5 @@ extern crate diesel; extern crate diesel_derives; mod queryable; -mod queryable_by_name; mod associations; mod test_helpers; diff --git a/diesel_derives2/src/as_changeset.rs b/diesel_derives2/src/as_changeset.rs index 277dd60dfa78..5c13f9b1de9a 100644 --- a/diesel_derives2/src/as_changeset.rs +++ b/diesel_derives2/src/as_changeset.rs @@ -86,7 +86,7 @@ fn field_changeset_expr( table_name: syn::Ident, treat_none_as_null: bool, ) -> syn::Expr { - let field_access = &field.name; + let field_access = field.name.access(); let column_name = field.column_name(); if !treat_none_as_null && is_option_ty(&field.ty) { parse_quote!(self#field_access.as_ref().map(|x| #table_name::#column_name.eq(x))) diff --git a/diesel_derives2/src/field.rs b/diesel_derives2/src/field.rs index df9ac47c81d9..6fdea942f115 100644 --- a/diesel_derives2/src/field.rs +++ b/diesel_derives2/src/field.rs @@ -1,15 +1,18 @@ use proc_macro2::Span; use quote; -use syn; use syn::spanned::Spanned; +use syn; -use diagnostic_shim::*; use meta::*; +use util::*; pub struct Field { pub ty: syn::Type, pub name: FieldName, + pub span: Span, + pub sql_type: Option, column_name_from_attribute: Option, + flags: MetaItem, } impl Field { @@ -17,18 +20,30 @@ impl Field { let column_name_from_attribute = MetaItem::with_name(&field.attrs, "column_name").map(|m| m.expect_ident_value()); let name = match field.ident { - Some(x) => FieldName::Named(x), + Some(mut x) => { + // https://github.com/rust-lang/rust/issues/47983#issuecomment-362817105 + x.span = fix_span(x.span, Span::call_site()); + FieldName::Named(x) + } None => FieldName::Unnamed(syn::Index { index: index as u32, // https://github.com/rust-lang/rust/issues/47312 span: Span::call_site(), }), }; + let sql_type = MetaItem::with_name(&field.attrs, "sql_type") + .and_then(|m| m.ty_value().map_err(Diagnostic::emit).ok()); + let flags = MetaItem::with_name(&field.attrs, "diesel") + .unwrap_or_else(|| MetaItem::empty("diesel")); + let span = field.span(); Self { ty: field.ty.clone(), column_name_from_attribute, name, + sql_type, + flags, + span, } } @@ -37,8 +52,7 @@ impl Field { .unwrap_or_else(|| match self.name { FieldName::Named(x) => x, _ => { - self.ty - .span() + self.span .error( "All fields of tuple structs must be annotated with `#[column_name]`", ) @@ -47,6 +61,10 @@ impl Field { } }) } + + pub fn has_flag(&self, flag: &str) -> bool { + self.flags.has_flag(flag) + } } pub enum FieldName { @@ -54,15 +72,29 @@ pub enum FieldName { Unnamed(syn::Index), } +impl FieldName { + pub fn assign(&self, expr: syn::Expr) -> syn::FieldValue { + let span = self.span(); + // Parens are to work around https://github.com/rust-lang/rust/issues/47311 + let tokens = quote_spanned!(span=> #self: (#expr)); + parse_quote!(#tokens) + } + + pub fn access(&self) -> quote::Tokens { + let span = self.span(); + quote_spanned!(span=> .#self) + } + + pub fn span(&self) -> Span { + match *self { + FieldName::Named(x) => x.span, + FieldName::Unnamed(ref x) => x.span, + } + } +} + impl quote::ToTokens for FieldName { fn to_tokens(&self, tokens: &mut quote::Tokens) { - use proc_macro2::{Spacing, TokenNode, TokenTree}; - - // https://github.com/rust-lang/rust/issues/47312 - tokens.append(TokenTree { - span: Span::call_site(), - kind: TokenNode::Op('.', Spacing::Alone), - }); match *self { FieldName::Named(x) => x.to_tokens(tokens), FieldName::Unnamed(ref x) => x.to_tokens(tokens), diff --git a/diesel_derives2/src/identifiable.rs b/diesel_derives2/src/identifiable.rs index e86492d71f9a..4602ab6a3413 100644 --- a/diesel_derives2/src/identifiable.rs +++ b/diesel_derives2/src/identifiable.rs @@ -18,7 +18,7 @@ pub fn derive(item: syn::DeriveInput) -> Result { .primary_key_names .iter() .filter_map(|&pk| model.find_column(pk).map_err(Diagnostic::emit).ok()) - .map(|f| (&f.ty, &f.name)) + .map(|f| (&f.ty, f.name.access())) .unzip(); Ok(wrap_in_dummy_mod( diff --git a/diesel_derives2/src/lib.rs b/diesel_derives2/src/lib.rs index e1241cf857bf..f4dcf33f3e98 100644 --- a/diesel_derives2/src/lib.rs +++ b/diesel_derives2/src/lib.rs @@ -28,6 +28,7 @@ mod util; mod as_changeset; mod identifiable; +mod queryable_by_name; use diagnostic_shim::*; @@ -42,6 +43,11 @@ pub fn derive_identifiable(input: TokenStream) -> TokenStream { expand_derive(input, identifiable::derive) } +#[proc_macro_derive(QueryableByName, attributes(table_name, column_name, sql_type, diesel))] +pub fn derive_queryable_by_name(input: TokenStream) -> TokenStream { + expand_derive(input, queryable_by_name::derive) +} + fn expand_derive( input: TokenStream, f: fn(syn::DeriveInput) -> Result, diff --git a/diesel_derives2/src/meta.rs b/diesel_derives2/src/meta.rs index 56680171f13b..607c9a9ef5f9 100644 --- a/diesel_derives2/src/meta.rs +++ b/diesel_derives2/src/meta.rs @@ -2,7 +2,7 @@ use proc_macro2::Span; use syn; use syn::spanned::Spanned; -use diagnostic_shim::*; +use util::*; pub struct MetaItem { // Due to https://github.com/rust-lang/rust/issues/47941 @@ -20,6 +20,17 @@ impl MetaItem { .map(|(pound_span, meta)| Self { pound_span, meta }) } + pub fn empty(name: &str) -> Self { + Self { + pound_span: Span::call_site(), + meta: syn::Meta::List(syn::MetaList { + ident: name.into(), + paren_token: Default::default(), + nested: Default::default(), + }), + } + } + pub fn nested_item<'a>(&self, name: &'a str) -> Result { self.nested().and_then(|mut i| { i.nth(0).ok_or_else(|| { @@ -101,6 +112,21 @@ impl MetaItem { } } + pub fn has_flag(&self, flag: &str) -> bool { + self.nested() + .map(|mut n| n.any(|m| m.expect_word() == flag)) + .unwrap_or_else(|e| { + e.emit(); + false + }) + } + + pub fn ty_value(&self) -> Result { + let mut str = self.lit_str_value()?.clone(); + str.span = self.span_or_pound_token(str.span); + str.parse().map_err(|_| str.span.error("Invalid Rust type")) + } + fn expect_str_value(&self) -> String { self.str_value().unwrap_or_else(|e| { e.emit(); @@ -113,6 +139,10 @@ impl MetaItem { } fn str_value(&self) -> Result { + self.lit_str_value().map(syn::LitStr::value) + } + + fn lit_str_value(&self) -> Result<&syn::LitStr, Diagnostic> { use syn::Meta::*; use syn::MetaNameValue; use syn::Lit::*; @@ -120,7 +150,7 @@ impl MetaItem { match self.meta { NameValue(MetaNameValue { lit: Str(ref s), .. - }) => Ok(s.value()), + }) => Ok(s), _ => Err(self.span().error(format!( "`{0}` must be in the form `{0} = \"value\"`", self.name() @@ -147,12 +177,7 @@ impl MetaItem { /// https://github.com/rust-lang/rust/issues/47941, /// returns the span of the pound token fn span_or_pound_token(&self, span: Span) -> Span { - let bad_span_debug = "Span(Span { lo: BytePos(0), hi: BytePos(0), ctxt: #0 })"; - if format!("{:?}", span) == bad_span_debug { - self.pound_span - } else { - span - } + fix_span(span, self.pound_span) } } diff --git a/diesel_derives2/src/model.rs b/diesel_derives2/src/model.rs index fad6afb65447..8d9fc1d85727 100644 --- a/diesel_derives2/src/model.rs +++ b/diesel_derives2/src/model.rs @@ -56,6 +56,10 @@ impl Model { .error(format!("No field with column name {}", column_name)) }) } + + pub fn has_table_name_attribute(&self) -> bool { + self.table_name_from_attribute.is_some() + } } pub fn camel_to_snake(name: &str) -> String { diff --git a/diesel_derives2/src/queryable_by_name.rs b/diesel_derives2/src/queryable_by_name.rs new file mode 100644 index 000000000000..f319368c0b6e --- /dev/null +++ b/diesel_derives2/src/queryable_by_name.rs @@ -0,0 +1,95 @@ +use syn; +use quote; + +use field::*; +use model::*; +use util::*; + +pub fn derive(item: syn::DeriveInput) -> Result { + let model = Model::from_item(&item)?; + + let struct_name = item.ident; + let field_expr = model.fields().iter().map(|f| field_expr(f, &model)); + + let (_, ty_generics, ..) = item.generics.split_for_impl(); + let mut generics = item.generics.clone(); + generics + .params + .push(parse_quote!(__DB: diesel::backend::Backend)); + + for field in model.fields() { + let where_clause = generics.where_clause.get_or_insert(parse_quote!(where)); + let field_ty = &field.ty; + if field.has_flag("embed") { + where_clause + .predicates + .push(parse_quote!(#field_ty: QueryableByName<__DB>)); + } else { + let st = sql_type(field, &model); + where_clause + .predicates + .push(parse_quote!(#field_ty: diesel::deserialize::FromSql<#st, __DB>)); + } + } + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(wrap_in_dummy_mod( + model.dummy_mod_name("queryable_by_name"), + quote! { + use self::diesel::deserialize::{self, QueryableByName}; + use self::diesel::row::NamedRow; + + impl #impl_generics QueryableByName<__DB> + for #struct_name #ty_generics + #where_clause + { + fn build<__R: NamedRow<__DB>>(row: &__R) -> deserialize::Result { + std::result::Result::Ok(Self { + #(#field_expr,)* + }) + } + } + }, + )) +} + +fn field_expr(field: &Field, model: &Model) -> syn::FieldValue { + if field.has_flag("embed") { + field + .name + .assign(parse_quote!(QueryableByName::build(row)?)) + } else { + let column_name = field.column_name(); + let st = sql_type(field, model); + field + .name + .assign(parse_quote!(row.get::<#st, _>(stringify!(#column_name))?)) + } +} + +fn sql_type(field: &Field, model: &Model) -> syn::Type { + let table_name = model.table_name(); + let column_name = field.column_name(); + + match field.sql_type { + Some(ref st) => st.clone(), + None => if model.has_table_name_attribute() { + parse_quote!(diesel::dsl::SqlTypeOf<#table_name::#column_name>) + } else { + let field_name = match field.name { + FieldName::Named(ref x) => x.as_ref(), + _ => "field", + }; + field + .span + .error(format!("Cannot determine the SQL type of {}", field_name)) + .help( + "Your struct must either be annotated with `#[table_name = \"foo\"]` \ + or have all of its fields annotated with `#[sql_type = \"Integer\"]`", + ) + .emit(); + parse_quote!(()) + }, + } +} diff --git a/diesel_derives2/src/util.rs b/diesel_derives2/src/util.rs index 76806ba4cfb2..da0a172dec7b 100644 --- a/diesel_derives2/src/util.rs +++ b/diesel_derives2/src/util.rs @@ -1,5 +1,6 @@ use syn::*; use quote::Tokens; +use proc_macro2::Span; pub use diagnostic_shim::*; @@ -44,3 +45,12 @@ fn option_ty_arg(ty: &Type) -> Option<&Type> { _ => None, } } + +pub fn fix_span(maybe_bad_span: Span, fallback: Span) -> Span { + let bad_span_debug = "Span(Span { lo: BytePos(0), hi: BytePos(0), ctxt: #0 })"; + if format!("{:?}", maybe_bad_span) == bad_span_debug { + fallback + } else { + maybe_bad_span + } +} diff --git a/diesel_derives/tests/queryable_by_name.rs b/diesel_derives2/tests/queryable_by_name.rs similarity index 95% rename from diesel_derives/tests/queryable_by_name.rs rename to diesel_derives2/tests/queryable_by_name.rs index fb8cd06907b2..0b53e5188ade 100644 --- a/diesel_derives/tests/queryable_by_name.rs +++ b/diesel_derives2/tests/queryable_by_name.rs @@ -1,6 +1,6 @@ use diesel::*; -use test_helpers::connection; +use helpers::connection; #[cfg(feature = "mysql")] type IntSql = ::diesel::sql_types::BigInt; @@ -38,7 +38,10 @@ fn named_struct_definition() { fn tuple_struct() { #[derive(Debug, Clone, Copy, PartialEq, Eq, QueryableByName)] #[table_name = "my_structs"] - struct MyStruct(#[column_name(foo)] IntRust, #[column_name(bar)] IntRust); + struct MyStruct( + #[column_name = "foo"] IntRust, + #[column_name = "bar"] IntRust, + ); let conn = connection(); let data = sql_query("SELECT 1 AS foo, 2 AS bar").get_result(&conn); diff --git a/diesel_derives2/tests/tests.rs b/diesel_derives2/tests/tests.rs index e6ec84be39fc..f0b1b0c3b906 100644 --- a/diesel_derives2/tests/tests.rs +++ b/diesel_derives2/tests/tests.rs @@ -12,3 +12,4 @@ mod schema; mod as_changeset; mod identifiable; +mod queryable_by_name;