diff --git a/arrow2_convert_derive/src/derive_struct.rs b/arrow2_convert_derive/src/derive_struct.rs index 04e0e4d..844f0be 100644 --- a/arrow2_convert_derive/src/derive_struct.rs +++ b/arrow2_convert_derive/src/derive_struct.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream; use proc_macro_error::abort; -use quote::{quote, quote_spanned}; +use quote::{format_ident, quote, quote_spanned}; use syn::spanned::Spanned; use super::input::*; @@ -8,8 +8,9 @@ use super::input::*; struct Common<'a> { original_name: &'a proc_macro2::Ident, visibility: &'a syn::Visibility, - field_names: Vec<&'a proc_macro2::Ident>, - field_names_str: Vec, + field_members: Vec, + field_idents: Vec, + skipped_field_names: Vec, field_indices: Vec, field_types: Vec<&'a syn::TypePath>, } @@ -18,32 +19,53 @@ impl<'a> From<&'a DeriveStruct> for Common<'a> { fn from(input: &'a DeriveStruct) -> Self { let original_name = &input.common.name; let visibility = &input.common.visibility; - let fields = &input.fields; - //let (gen_serialize, gen_deserialize) = input.common.traits_to_derive.to_flags(); - let field_names = fields + let (skipped_fields, fields): (Vec<_>, Vec<_>) = + input.fields.iter().partition(|field| field.skip); + + let field_members = fields + .iter() + .enumerate() + .map(|(id, field)| { + field + .syn + .ident + .as_ref() + .cloned() + .map_or_else(|| syn::Member::Unnamed(id.into()), syn::Member::Named) + }) + .collect::>(); + + let field_idents = field_members + .iter() + .map(|f| match f { + // `Member` doesn't impl `IdentFragment` in a way that preserves the "r#" prefix stripping of `Ident`, so we go one level inside. + syn::Member::Named(ident) => format_ident!("field_{}", ident), + syn::Member::Unnamed(index) => format_ident!("field_{}", index), + }) + .collect::>(); + + let skipped_field_names = skipped_fields .iter() - .map(|field| field.syn.ident.as_ref().unwrap()) + .enumerate() + .map(|(id, field)| { + field + .syn + .ident + .as_ref() + .cloned() + .map_or_else(|| syn::Member::Unnamed(id.into()), syn::Member::Named) + }) .collect::>(); - if field_names.is_empty() { + if field_members.is_empty() { abort!( original_name.span(), "Expected struct to have more than one field" ); } - let field_names_str = field_names - .iter() - .map(|field| { - syn::LitStr::new( - strip_escape_prefix(&format!("{}", field)), - proc_macro2::Span::call_site(), - ) - }) - .collect::>(); - - let field_indices = field_names + let field_indices = field_members .iter() .enumerate() .map(|(idx, _ident)| { @@ -62,8 +84,9 @@ impl<'a> From<&'a DeriveStruct> for Common<'a> { Self { original_name, visibility, - field_names, - field_names_str, + field_members, + field_idents, + skipped_field_names, field_indices, field_types, } @@ -73,23 +96,39 @@ impl<'a> From<&'a DeriveStruct> for Common<'a> { pub fn expand_field(input: DeriveStruct) -> TokenStream { let Common { original_name, - field_names_str, + field_members, + //field_names_str, field_types, .. } = (&input).into(); + let data_type_impl = { + if input.fields.len() == 1 && input.is_transparent { + // Special case for single-field (tuple) structs + let field = &input.fields[0]; + let ty = &field.field_type; + quote! ( + <#ty as arrow2_convert::field::ArrowField>::data_type() + ) + } else { + let field_names = field_members.iter().map(|field| match field { + syn::Member::Named(ident) => format_ident!("{}", ident), + syn::Member::Unnamed(index) => format_ident!("field_{}", index), + }); + quote!(arrow2::datatypes::DataType::Struct(vec![ + #( + <#field_types as arrow2_convert::field::ArrowField>::field(stringify!(#field_names)), + )* + ])) + } + }; + quote!( impl arrow2_convert::field::ArrowField for #original_name { type Type = Self; fn data_type() -> arrow2::datatypes::DataType { - arrow2::datatypes::DataType::Struct( - vec![ - #( - <#field_types as arrow2_convert::field::ArrowField>::field(#field_names_str), - )* - ] - ) + #data_type_impl } } @@ -101,12 +140,13 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { let Common { original_name, visibility, - field_names, + field_members: field_names, + field_idents, field_types, .. } = (&input).into(); - let first_field = field_names[0]; + let first_field = &field_names[0]; let mutable_array_name = &input.common.mutable_array_name(); let mutable_field_array_types = field_types @@ -118,7 +158,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { #[derive(Debug)] #visibility struct #mutable_array_name { #( - #field_names: #mutable_field_array_types, + #field_idents: #mutable_field_array_types, )* data_type: arrow2::datatypes::DataType, validity: Option, @@ -129,7 +169,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { impl #mutable_array_name { pub fn new() -> Self { Self { - #(#field_names: <#field_types as arrow2_convert::serialize::ArrowSerialize>::new_array(),)* + #(#field_idents: <#field_types as arrow2_convert::serialize::ArrowSerialize>::new_array(),)* data_type: <#original_name as arrow2_convert::field::ArrowField>::data_type(), validity: None, } @@ -162,7 +202,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { Some(i) => { let i = i.borrow(); #( - <#field_types as arrow2_convert::serialize::ArrowSerialize>::arrow_serialize(i.#field_names.borrow(), &mut self.#field_names)?; + <#field_types as arrow2_convert::serialize::ArrowSerialize>::arrow_serialize(i.#field_names.borrow(), &mut self.#field_idents)?; )*; match &mut self.validity { Some(validity) => validity.push(true), @@ -171,7 +211,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { }, None => { #( - <#mutable_field_array_types as MutableArray>::push_null(&mut self.#field_names); + <#mutable_field_array_types as MutableArray>::push_null(&mut self.#field_idents); )*; match &mut self.validity { Some(validity) => validity.push(false), @@ -198,6 +238,8 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { } }; + let first_ident = &field_idents[0]; + let array_mutable_array_impl = quote! { impl arrow2::array::MutableArray for #mutable_array_name { fn data_type(&self) -> &arrow2::datatypes::DataType { @@ -205,7 +247,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { } fn len(&self) -> usize { - self.#first_field.len() + self.#first_ident.len() } fn validity(&self) -> Option<&arrow2::bitmap::MutableBitmap> { @@ -214,7 +256,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { fn as_box(&mut self) -> Box { let values = vec![#( - <#mutable_field_array_types as arrow2::array::MutableArray>::as_box(&mut self.#field_names), + <#mutable_field_array_types as arrow2::array::MutableArray>::as_box(&mut self.#field_idents), )*]; Box::new(arrow2::array::StructArray::new( @@ -226,7 +268,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { fn as_arc(&mut self) -> std::sync::Arc { let values = vec![#( - <#mutable_field_array_types as arrow2::array::MutableArray>::as_box(&mut self.#field_names), + <#mutable_field_array_types as arrow2::array::MutableArray>::as_box(&mut self.#field_idents), )*]; std::sync::Arc::new(arrow2::array::StructArray::new( @@ -251,7 +293,7 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { fn shrink_to_fit(&mut self) { #( - <#mutable_field_array_types as arrow2::array::MutableArray>::shrink_to_fit(&mut self.#field_names); + <#mutable_field_array_types as arrow2::array::MutableArray>::shrink_to_fit(&mut self.#field_idents); )* if let Some(validity) = &mut self.validity { validity.shrink_to_fit(); @@ -262,44 +304,66 @@ pub fn expand_serialize(input: DeriveStruct) -> TokenStream { if let Some(x) = self.validity.as_mut() { x.reserve(additional) } - #(<<#field_types as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType as arrow2::array::MutableArray>::reserve(&mut self.#field_names, additional);)* + #(<<#field_types as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType as arrow2::array::MutableArray>::reserve(&mut self.#field_idents, additional);)* } } }; - let field_arrow_serialize_impl = quote! { - impl arrow2_convert::serialize::ArrowSerialize for #original_name { - type MutableArrayType = #mutable_array_name; - - #[inline] - fn new_array() -> Self::MutableArrayType { - Self::MutableArrayType::default() - } + // Special case for single-field (tuple) structs. + if input.fields.len() == 1 && input.is_transparent { + let first_type = &field_types[0]; + // Everything delegates to first field. + quote! { + impl arrow2_convert::serialize::ArrowSerialize for #original_name { + type MutableArrayType = <#first_type as arrow2_convert::serialize::ArrowSerialize>::MutableArrayType; + + #[inline] + fn new_array() -> Self::MutableArrayType { + <#first_type as arrow2_convert::serialize::ArrowSerialize>::new_array() + } - #[inline] - fn arrow_serialize(v: &Self, array: &mut Self::MutableArrayType) -> arrow2::error::Result<()> { - use arrow2::array::TryPush; - array.try_push(Some(v)) + #[inline] + fn arrow_serialize(v: &Self, array: &mut Self::MutableArrayType) -> arrow2::error::Result<()> { + <#first_type as arrow2_convert::serialize::ArrowSerialize>::arrow_serialize(&v.#first_field, array) + } } } - }; + } else { + let field_arrow_serialize_impl = quote! { + impl arrow2_convert::serialize::ArrowSerialize for #original_name { + type MutableArrayType = #mutable_array_name; + + #[inline] + fn new_array() -> Self::MutableArrayType { + Self::MutableArrayType::default() + } - TokenStream::from_iter([ - array_decl, - array_impl, - array_default_impl, - array_try_push_impl, - array_try_extend_impl, - array_mutable_array_impl, - field_arrow_serialize_impl, - ]) + #[inline] + fn arrow_serialize(v: &Self, array: &mut Self::MutableArrayType) -> arrow2::error::Result<()> { + use arrow2::array::TryPush; + array.try_push(Some(v)) + } + } + }; + TokenStream::from_iter([ + array_decl, + array_impl, + array_default_impl, + array_try_push_impl, + array_try_extend_impl, + array_mutable_array_impl, + field_arrow_serialize_impl, + ]) + } } pub fn expand_deserialize(input: DeriveStruct) -> TokenStream { let Common { original_name, visibility, - field_names, + field_members: field_names, + field_idents, + skipped_field_names, field_indices, field_types, .. @@ -307,6 +371,7 @@ pub fn expand_deserialize(input: DeriveStruct) -> TokenStream { let array_name = &input.common.array_name(); let iterator_name = &input.common.iterator_name(); + let is_tuple_struct = matches!(field_names[0], syn::Member::Unnamed(_)); let array_decl = quote! { #visibility struct #array_name @@ -328,7 +393,7 @@ pub fn expand_deserialize(input: DeriveStruct) -> TokenStream { // for now do a straight comp #iterator_name { #( - #field_names: <<#field_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(values[#field_indices].deref()), + #field_idents: <<#field_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(values[#field_indices].deref()), )* has_validity: validity.as_ref().is_some(), validity_iter: validity.as_ref().map(|x| x.iter()).unwrap_or_else(|| arrow2::bitmap::utils::BitmapIter::new(&[], 0, 0)) @@ -352,35 +417,45 @@ pub fn expand_deserialize(input: DeriveStruct) -> TokenStream { let iterator_decl = quote! { #visibility struct #iterator_name<'a> { #( - #field_names: <&'a <#field_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as IntoIterator>::IntoIter, + #field_idents: <&'a <#field_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as IntoIterator>::IntoIter, )* validity_iter: arrow2::bitmap::utils::BitmapIter<'a>, has_validity: bool } }; + let struct_inst: syn::Pat = if is_tuple_struct { + // If the fields are unnamed, we create a tuple-struct + syn::parse_quote! { + #original_name ( + #(<#field_types as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize_internal(#field_idents),)* + ) + } + } else { + syn::parse_quote! { + #original_name { + #(#field_names: <#field_types as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize_internal(#field_idents),)* + #(#skipped_field_names: std::default::Default::default(),)* + } + } + }; + let iterator_impl = quote! { impl<'a> #iterator_name<'a> { #[inline] fn return_next(&mut self) -> Option<#original_name> { if let (#( - Some(#field_names), + Some(#field_idents), )*) = ( - #(self.#field_names.next(),)* + #(self.#field_idents.next(),)* ) - { - Some(#original_name { - #(#field_names: <#field_types as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize_internal(#field_names),)* - }) - } - else { - None - } + { Some(#struct_inst) } + else { None } } #[inline] fn consume_next(&mut self) { - #(let _ = self.#field_names.next();)* + #(let _ = self.#field_idents.next();)* } } }; @@ -402,29 +477,48 @@ pub fn expand_deserialize(input: DeriveStruct) -> TokenStream { } }; - let field_arrow_deserialize_impl = quote! { - impl arrow2_convert::deserialize::ArrowDeserialize for #original_name { - type ArrayType = #array_name; - - #[inline] - fn arrow_deserialize<'a>(v: Option) -> Option { - v + // Special case for single-field (tuple) structs. + if input.fields.len() == 1 && input.is_transparent { + let first_type = &field_types[0]; + + let deser_body_mapper = if is_tuple_struct { + quote! { #original_name } + } else { + let first_name = &field_names[0]; + quote! { |v| #original_name { #first_name: v } } + }; + + // Everything delegates to first field. + quote! { + impl arrow2_convert::deserialize::ArrowDeserialize for #original_name { + type ArrayType = <#first_type as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType; + + #[inline] + fn arrow_deserialize<'a>(v: <&Self::ArrayType as IntoIterator>::Item) -> Option { + <#first_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(#deser_body_mapper) + } } } - }; - - TokenStream::from_iter([ - array_decl, - array_impl, - array_into_iterator_impl, - iterator_decl, - iterator_impl, - iterator_iterator_impl, - field_arrow_deserialize_impl, - ]) -} - -/// Removes the 'r#' from escaped identifiers. -fn strip_escape_prefix(name: &str) -> &str { - name.strip_prefix("r#").unwrap_or(name) + } else { + let field_arrow_deserialize_impl = quote! { + impl arrow2_convert::deserialize::ArrowDeserialize for #original_name { + type ArrayType = #array_name; + + #[inline] + fn arrow_deserialize<'a>(v: Option) -> Option { + v + } + } + }; + + TokenStream::from_iter([ + array_decl, + array_impl, + array_into_iterator_impl, + iterator_decl, + iterator_impl, + iterator_iterator_impl, + field_arrow_deserialize_impl, + ]) + } } diff --git a/arrow2_convert_derive/src/input.rs b/arrow2_convert_derive/src/input.rs index d148772..3235e7b 100644 --- a/arrow2_convert_derive/src/input.rs +++ b/arrow2_convert_derive/src/input.rs @@ -5,10 +5,13 @@ use syn::{DeriveInput, Ident, Lit, Meta, MetaNameValue, Visibility}; pub const ARROW_FIELD: &str = "arrow_field"; pub const FIELD_TYPE: &str = "type"; +pub const FIELD_SKIP: &str = "skip"; pub const UNION_TYPE: &str = "type"; pub const UNION_TYPE_SPARSE: &str = "sparse"; pub const UNION_TYPE_DENSE: &str = "dense"; +pub const TRANSPARENT: &str = "transparent"; +#[derive(Debug)] pub struct DeriveCommon { /// The input name pub name: Ident, @@ -16,33 +19,44 @@ pub struct DeriveCommon { pub visibility: Visibility, } +#[derive(Debug)] pub struct DeriveStruct { pub common: DeriveCommon, /// The list of fields in the struct pub fields: Vec, + pub is_transparent: bool, } +#[derive(Debug)] pub struct DeriveEnum { pub common: DeriveCommon, /// The list of variants in the enum pub variants: Vec, pub is_dense: bool, } + /// All container attributes +#[derive(Debug)] pub struct ContainerAttrs { pub is_dense: Option, + pub transparent: Option, } /// All field attributes +#[derive(Debug)] pub struct FieldAttrs { pub field_type: Option, + pub skip: bool, } +#[derive(Debug)] pub struct DeriveField { pub syn: syn::Field, pub field_type: syn::Type, + pub skip: bool, } +#[derive(Debug)] pub struct DeriveVariant { pub syn: syn::Variant, pub field_type: syn::Type, @@ -73,6 +87,7 @@ impl DeriveCommon { impl ContainerAttrs { pub fn from_ast(attrs: &[syn::Attribute]) -> ContainerAttrs { let mut is_dense: Option = None; + let mut is_transparent: Option = None; for attr in attrs { if let Ok(meta) = attr.parse_meta() { @@ -85,26 +100,24 @@ impl ContainerAttrs { lit: Lit::Str(string), path, .. - }) => { - if path.is_ident(UNION_TYPE) { - match string.value().as_ref() { - UNION_TYPE_DENSE => { - is_dense = Some(true); - } - UNION_TYPE_SPARSE => { - is_dense = Some(false); - } - _ => { - abort!( - path.span(), - "Unexpected value for mode" - ); - } + }) if path.is_ident(UNION_TYPE) => { + match string.value().as_ref() { + UNION_TYPE_DENSE => { + is_dense = Some(true); + } + UNION_TYPE_SPARSE => { + is_dense = Some(false); + } + _ => { + abort!(path.span(), "Unexpected value for mode"); } - } else { - abort!(path.span(), "Unexpected attribute"); } } + + Meta::Path(path) if path.is_ident(TRANSPARENT) => { + is_transparent = Some(path.span()); + } + _ => { abort!(meta.span(), "Unexpected attribute"); } @@ -116,13 +129,17 @@ impl ContainerAttrs { } } - ContainerAttrs { is_dense } + ContainerAttrs { + is_dense, + transparent: is_transparent, + } } } impl FieldAttrs { pub fn from_ast(input: &[syn::Attribute]) -> FieldAttrs { let mut field_type: Option = None; + let mut skip = false; for attr in input { if let Ok(meta) = attr.parse_meta() { @@ -135,13 +152,11 @@ impl FieldAttrs { lit: Lit::Str(string), path, .. - }) => { - if path.is_ident(FIELD_TYPE) { - field_type = Some( - syn::parse_str(&string.value()).unwrap_or_abort(), - ); - } + }) if path.is_ident(FIELD_TYPE) => { + field_type = + Some(syn::parse_str(&string.value()).unwrap_or_abort()); } + Meta::Path(path) if path.is_ident(FIELD_SKIP) => skip = true, _ => { abort!(meta.span(), "Unexpected attribute"); } @@ -153,7 +168,7 @@ impl FieldAttrs { } } - FieldAttrs { field_type } + FieldAttrs { field_type, skip } } } @@ -162,6 +177,15 @@ impl DeriveStruct { let container_attrs = ContainerAttrs::from_ast(&input.attrs); let common = DeriveCommon::from_ast(input, &container_attrs); + let is_transparent = if let Some(span) = container_attrs.transparent { + if ast.fields.len() > 1 { + abort!(span, "'transparent' is only supported on length-1 structs!"); + } + true + } else { + false + }; + DeriveStruct { common, fields: ast @@ -169,6 +193,7 @@ impl DeriveStruct { .iter() .map(DeriveField::from_ast) .collect::>(), + is_transparent, } } } @@ -199,6 +224,7 @@ impl DeriveField { DeriveField { syn: input.clone(), field_type: attrs.field_type.unwrap_or_else(|| input.ty.clone()), + skip: attrs.skip, } } }