diff --git a/src/doc.rs b/src/doc.rs index 81d4184..5fa2ad9 100644 --- a/src/doc.rs +++ b/src/doc.rs @@ -1,5 +1,8 @@ +use crate::error::Result; +use crate::segment::{self, Segment}; use proc_macro::{Delimiter, Span, TokenStream, TokenTree}; use std::iter; +use std::mem; use std::str::FromStr; pub fn is_pasted_doc(input: &TokenStream) -> bool { @@ -26,16 +29,33 @@ pub fn is_pasted_doc(input: &TokenStream) -> bool { state == State::Rest } -pub fn do_paste_doc(attr: &TokenStream, span: Span) -> TokenStream { +pub fn do_paste_doc(attr: &TokenStream, span: Span) -> Result { let mut expanded = TokenStream::new(); - let mut tokens = attr.clone().into_iter(); + let mut tokens = attr.clone().into_iter().peekable(); expanded.extend(tokens.by_ref().take(2)); // `doc =` - let mut lit = String::new(); - lit.push('"'); - for token in tokens { - lit += &escaped_string_value(&token).unwrap(); + let mut segments = segment::parse(&mut tokens)?; + + for segment in &mut segments { + if let Segment::String(string) = segment { + if let Some(open_quote) = string.value.find('"') { + if open_quote == 0 { + string.value.truncate(string.value.len() - 1); + string.value.remove(0); + } else { + let begin = open_quote + 1; + let end = string.value.rfind('"').unwrap(); + let raw_string = mem::replace(&mut string.value, String::new()); + for ch in raw_string[begin..end].chars() { + string.value.extend(ch.escape_default()); + } + } + } + } } + + let mut lit = segment::paste(&segments)?; + lit.insert(0, '"'); lit.push('"'); let mut lit = TokenStream::from_str(&lit) @@ -45,48 +65,26 @@ pub fn do_paste_doc(attr: &TokenStream, span: Span) -> TokenStream { .unwrap(); lit.set_span(span); expanded.extend(iter::once(lit)); - expanded + Ok(expanded) } fn is_stringlike(token: &TokenTree) -> bool { - escaped_string_value(token).is_some() -} - -fn escaped_string_value(token: &TokenTree) -> Option { match token { - TokenTree::Ident(ident) => Some(ident.to_string()), + TokenTree::Ident(_) => true, TokenTree::Literal(literal) => { - let mut repr = literal.to_string(); - if repr.starts_with('b') || repr.starts_with('\'') { - None - } else if repr.starts_with('"') { - repr.truncate(repr.len() - 1); - repr.remove(0); - Some(repr) - } else if repr.starts_with('r') { - let begin = repr.find('"').unwrap() + 1; - let end = repr.rfind('"').unwrap(); - let mut escaped = String::new(); - for ch in repr[begin..end].chars() { - escaped.extend(ch.escape_default()); - } - Some(escaped) - } else { - Some(repr) - } + let repr = literal.to_string(); + !repr.starts_with('b') && !repr.starts_with('\'') } TokenTree::Group(group) => { if group.delimiter() != Delimiter::None { - return None; + return false; } let mut inner = group.stream().into_iter(); - let first = inner.next()?; - if inner.next().is_none() { - escaped_string_value(&first) - } else { - None + match inner.next() { + Some(first) => inner.next().is_none() && is_stringlike(&first), + None => false, } } - TokenTree::Punct(_) => None, + TokenTree::Punct(punct) => punct.as_char() == '\'' || punct.as_char() == ':', } } diff --git a/src/lib.rs b/src/lib.rs index e38d034..e342a33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -210,7 +210,7 @@ fn expand(input: TokenStream, contains_paste: &mut bool) -> Result && (lookbehind == Lookbehind::Pound || lookbehind == Lookbehind::PoundBang) && is_pasted_doc(&content) { - let pasted = do_paste_doc(&content, span); + let pasted = do_paste_doc(&content, span)?; let mut group = Group::new(delimiter, pasted); group.set_span(span); expanded.extend(iter::once(TokenTree::Group(group))); diff --git a/tests/test_doc.rs b/tests/test_doc.rs index 96fe3a0..1ceaf23 100644 --- a/tests/test_doc.rs +++ b/tests/test_doc.rs @@ -42,3 +42,13 @@ fn test_literals() { let expected = "int=0x1 bool=true float=0.01"; assert_eq!(doc, expected); } + +#[test] +fn test_case() { + let doc = paste! { + get_doc!(#[doc = "HTTP " get:upper "!"]) + }; + + let expected = "HTTP GET!"; + assert_eq!(doc, expected); +}