Skip to content

Commit

Permalink
Allow non-ident's as column names for QueryableByName
Browse files Browse the repository at this point in the history
  • Loading branch information
weiznich authored and pksunkara committed Nov 24, 2021
1 parent 683c8a7 commit 30075e2
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 47 deletions.
5 changes: 4 additions & 1 deletion diesel/src/type_impls/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ mod foreign_impls {
struct U16Proxy(u16);

#[derive(FromSqlRow)]
#[cfg_attr(any(feature = "mysql_backend", feature = "postgres"), derive(AsExpression))]
#[cfg_attr(
any(feature = "mysql_backend", feature = "postgres"),
derive(AsExpression)
)]
#[diesel(foreign_derive)]
#[cfg_attr(feature = "mysql_backend", diesel(sql_type = crate::sql_types::Unsigned<Integer>))]
#[cfg_attr(feature = "postgres_backend", diesel(sql_type = crate::sql_types::Oid))]
Expand Down
7 changes: 0 additions & 7 deletions diesel_compile_tests/tests/fail/derive/bad_column_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,4 @@ struct User3 {
name: String,
}

#[derive(Queryable)]
#[diesel(table_name = users)]
struct User4 {
#[diesel(column_name = "name")]
name: String,
}

fn main() {}
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,8 @@ error: expected `=`
21 | #[diesel(column_name(another))]
| ^^^^^^^^^

error: expected identifier
error: expected string literal
--> $DIR/bad_column_name.rs:28:28
|
28 | #[diesel(column_name = true)]
| ^^^^

error: expected identifier
--> $DIR/bad_column_name.rs:35:28
|
35 | #[diesel(column_name = "name")]
| ^^^^^^
7 changes: 6 additions & 1 deletion diesel_derives/src/as_changeset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ pub fn derive(item: DeriveInput) -> TokenStream {
let fields_for_update = model
.fields()
.iter()
.filter(|f| !model.primary_key_names.contains(f.column_name()))
.filter(|f| {
!model
.primary_key_names
.iter()
.any(|p| f.column_name() == *p)
})
.collect::<Vec<_>>();

if fields_for_update.is_empty() {
Expand Down
65 changes: 63 additions & 2 deletions diesel_derives/src/attrs.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::fmt::{Display, Formatter};

use proc_macro2::{Span, TokenStream};
use proc_macro_error::ResultExt;
use quote::ToTokens;
use syn::parse::discouraged::Speculative;
use syn::parse::{Parse, ParseStream, Parser, Result};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{parenthesized, Attribute, Ident, LitBool, Path, Type};
use syn::{parenthesized, Attribute, Ident, LitBool, LitStr, Path, Type};

use deprecated::ParseDeprecated;
use parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
Expand All @@ -11,12 +16,68 @@ use util::{parse_eq, parse_paren, unknown_attribute};
pub enum FieldAttr {
Embed(Ident),

ColumnName(Ident, Ident),
ColumnName(Ident, SqlIdentifier),
SqlType(Ident, Type),
SerializeAs(Ident, Type),
DeserializeAs(Ident, Type),
}

#[derive(Clone)]
pub struct SqlIdentifier {
field_name: String,
span: Span,
}

impl SqlIdentifier {
pub fn span(&self) -> Span {
self.span
}
}

impl ToTokens for SqlIdentifier {
fn to_tokens(&self, tokens: &mut TokenStream) {
Ident::new(&self.field_name, self.span).to_tokens(tokens)
}
}

impl Display for SqlIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.field_name)
}
}

impl PartialEq<Ident> for SqlIdentifier {
fn eq(&self, other: &Ident) -> bool {
*other == self.field_name
}
}

impl From<&'_ Ident> for SqlIdentifier {
fn from(ident: &'_ Ident) -> Self {
Self {
span: ident.span(),
field_name: ident.to_string(),
}
}
}

impl Parse for SqlIdentifier {
fn parse(input: ParseStream) -> Result<Self> {
let fork = input.fork();

if let Ok(ident) = fork.parse::<Ident>() {
input.advance_to(&fork);
Ok((&ident).into())
} else {
let name = input.parse::<LitStr>()?;
Ok(Self {
field_name: name.value(),
span: name.span(),
})
}
}
}

impl Parse for FieldAttr {
fn parse(input: ParseStream) -> Result<Self> {
let name: Ident = input.parse()?;
Expand Down
22 changes: 12 additions & 10 deletions diesel_derives/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ use quote::ToTokens;
use syn::spanned::Spanned;
use syn::{Field as SynField, Ident, Index, Type};

use attrs::{parse_attributes, FieldAttr};
use attrs::{parse_attributes, FieldAttr, SqlIdentifier};

pub struct Field {
pub ty: Type,
pub span: Span,
pub name: FieldName,
column_name: Option<Ident>,
column_name: Option<SqlIdentifier>,
pub sql_type: Option<Type>,
pub serialize_as: Option<Type>,
pub deserialize_as: Option<Type>,
Expand Down Expand Up @@ -55,14 +55,16 @@ impl Field {
}
}

pub fn column_name(&self) -> &Ident {
self.column_name.as_ref()
.unwrap_or_else(|| match self.name {
FieldName::Named(ref x) => x,
FieldName::Unnamed(ref x) => {
abort!(x, "All fields of tuple structs must be annotated with `#[diesel(column_name)]`");
}
})
pub fn column_name(&self) -> SqlIdentifier {
self.column_name.clone().unwrap_or_else(|| match self.name {
FieldName::Named(ref x) => x.into(),
FieldName::Unnamed(ref x) => {
abort!(
x,
"All fields of tuple structs must be annotated with `#[diesel(column_name)]`"
);
}
})
}

pub fn ty_for_deserialize(&self) -> &Type {
Expand Down
2 changes: 1 addition & 1 deletion diesel_derives/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl Model {
pub fn find_column(&self, column_name: &Ident) -> &Field {
self.fields()
.iter()
.find(|f| f.column_name() == column_name)
.find(|f| f.column_name() == *column_name)
.unwrap_or_else(|| abort!(column_name, "No field with column name {}", column_name))
}

Expand Down
36 changes: 18 additions & 18 deletions diesel_derives/tests/queryable_by_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ fn struct_with_no_table() {
assert_eq!(Ok(MyStructNamedSoYouCantInferIt { foo: 1, bar: 2 }), data);
}

// #[test]
// fn struct_with_non_ident_column_name() {
// #[derive(Debug, Clone, PartialEq, Eq, QueryableByName)]
// struct QueryPlan {
// #[diesel(sql_type = diesel::sql_types::Text)]
// #[diesel(column_name = "QUERY PLAN")]
// qp: String,
// }

// let conn = &mut connection();
// let data = sql_query("SELECT 'some plan' AS \"QUERY PLAN\"").get_result(conn);
// assert_eq!(
// Ok(QueryPlan {
// qp: "some plan".to_string()
// }),
// data
// );
// }
#[test]
fn struct_with_non_ident_column_name() {
#[derive(Debug, Clone, PartialEq, Eq, QueryableByName)]
struct QueryPlan {
#[diesel(sql_type = diesel::sql_types::Text)]
#[diesel(column_name = "QUERY PLAN")]
qp: String,
}

let conn = &mut connection();
let data = sql_query("SELECT 'some plan' AS \"QUERY PLAN\"").get_result(conn);
assert_eq!(
Ok(QueryPlan {
qp: "some plan".to_string()
}),
data
);
}

#[test]
fn embedded_struct() {
Expand Down

0 comments on commit 30075e2

Please sign in to comment.