Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly allow macro expanded format_args invocations to uses captures #106505

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 44 additions & 35 deletions compiler/rustc_builtin_macros/src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ enum PositionUsedAs {
}
use PositionUsedAs::*;

struct MacroInput {
fmtstr: P<Expr>,
args: FormatArguments,
/// Whether the first argument was a string literal or a result from eager macro expansion.
/// If it's not a string literal, we disallow implicit arugment capturing.
///
/// This does not correspond to whether we can treat spans to the literal normally, as the whole
/// invocation might be the result of another macro expansion, in which case this flag may still be true.
///
/// See [RFC 2795] for more information.
///
/// [RFC 2795]: https://rust-lang.github.io/rfcs/2795-format-args-implicit-identifiers.html#macro-hygiene
is_direct_literal: bool,
}

/// Parses the arguments from the given list of tokens, returning the diagnostic
/// if there's a parse error so we can continue parsing other format!
/// expressions.
Expand All @@ -45,11 +60,7 @@ use PositionUsedAs::*;
/// ```text
/// Ok((fmtstr, parsed arguments))
/// ```
fn parse_args<'a>(
ecx: &mut ExtCtxt<'a>,
sp: Span,
tts: TokenStream,
) -> PResult<'a, (P<Expr>, FormatArguments)> {
fn parse_args<'a>(ecx: &mut ExtCtxt<'a>, sp: Span, tts: TokenStream) -> PResult<'a, MacroInput> {
let mut args = FormatArguments::new();

let mut p = ecx.new_parser_from_tts(tts);
Expand All @@ -59,25 +70,21 @@ fn parse_args<'a>(
}

let first_token = &p.token;
let fmtstr = match first_token.kind {
token::TokenKind::Literal(token::Lit {
kind: token::LitKind::Str | token::LitKind::StrRaw(_),
..
}) => {
// If the first token is a string literal, then a format expression
// is constructed from it.
//
// This allows us to properly handle cases when the first comma
// after the format string is mistakenly replaced with any operator,
// which cause the expression parser to eat too much tokens.
p.parse_literal_maybe_minus()?
}
_ => {
// Otherwise, we fall back to the expression parser.
p.parse_expr()?
}

let fmtstr = if let token::Literal(lit) = first_token.kind && matches!(lit.kind, token::Str | token::StrRaw(_)) {
// This allows us to properly handle cases when the first comma
// after the format string is mistakenly replaced with any operator,
// which cause the expression parser to eat too much tokens.
p.parse_literal_maybe_minus()?
} else {
// Otherwise, we fall back to the expression parser.
p.parse_expr()?
};

// Only allow implicit captures to be used when the argument is a direct literal
// instead of a macro expanding to one.
let is_direct_literal = matches!(fmtstr.kind, ExprKind::Lit(_));

let mut first = true;

while p.token != token::Eof {
Expand Down Expand Up @@ -147,17 +154,19 @@ fn parse_args<'a>(
}
}
}
Ok((fmtstr, args))
Ok(MacroInput { fmtstr, args, is_direct_literal })
}

pub fn make_format_args(
fn make_format_args(
ecx: &mut ExtCtxt<'_>,
efmt: P<Expr>,
mut args: FormatArguments,
input: MacroInput,
append_newline: bool,
) -> Result<FormatArgs, ()> {
let msg = "format argument must be a string literal";
let unexpanded_fmt_span = efmt.span;
let unexpanded_fmt_span = input.fmtstr.span;

let MacroInput { fmtstr: efmt, mut args, is_direct_literal } = input;

let (fmt_str, fmt_style, fmt_span) = match expr_to_spanned_string(ecx, efmt, msg) {
Ok(mut fmt) if append_newline => {
fmt.0 = Symbol::intern(&format!("{}\n", fmt.0));
Expand Down Expand Up @@ -208,11 +217,11 @@ pub fn make_format_args(
}
}

let is_literal = parser.is_literal;
let is_source_literal = parser.is_source_literal;

if !parser.errors.is_empty() {
let err = parser.errors.remove(0);
let sp = if is_literal {
let sp = if is_source_literal {
fmt_span.from_inner(InnerSpan::new(err.span.start, err.span.end))
} else {
// The format string could be another macro invocation, e.g.:
Expand All @@ -230,7 +239,7 @@ pub fn make_format_args(
if let Some(note) = err.note {
e.note(&note);
}
if let Some((label, span)) = err.secondary_label && is_literal {
if let Some((label, span)) = err.secondary_label && is_source_literal {
e.span_label(fmt_span.from_inner(InnerSpan::new(span.start, span.end)), label);
}
if err.should_be_replaced_with_positional_argument {
Expand All @@ -256,7 +265,7 @@ pub fn make_format_args(
}

let to_span = |inner_span: rustc_parse_format::InnerSpan| {
is_literal.then(|| {
is_source_literal.then(|| {
fmt_span.from_inner(InnerSpan { start: inner_span.start, end: inner_span.end })
})
};
Expand Down Expand Up @@ -304,7 +313,7 @@ pub fn make_format_args(
// Name not found in `args`, so we add it as an implicitly captured argument.
let span = span.unwrap_or(fmt_span);
let ident = Ident::new(name, span);
let expr = if is_literal {
let expr = if is_direct_literal {
ecx.expr_ident(span, ident)
} else {
// For the moment capturing variables from format strings expanded from macros is
Expand Down Expand Up @@ -814,7 +823,7 @@ fn report_invalid_references(
// for `println!("{7:7$}", 1);`
indexes.sort();
indexes.dedup();
let span: MultiSpan = if !parser.is_literal || parser.arg_places.is_empty() {
let span: MultiSpan = if !parser.is_source_literal || parser.arg_places.is_empty() {
MultiSpan::from_span(fmt_span)
} else {
MultiSpan::from_spans(invalid_refs.iter().filter_map(|&(_, span, _, _)| span).collect())
Expand Down Expand Up @@ -855,8 +864,8 @@ fn expand_format_args_impl<'cx>(
) -> Box<dyn base::MacResult + 'cx> {
sp = ecx.with_def_site_ctxt(sp);
match parse_args(ecx, sp, tts) {
Ok((efmt, args)) => {
if let Ok(format_args) = make_format_args(ecx, efmt, args, nl) {
Ok(input) => {
if let Ok(format_args) = make_format_args(ecx, input, nl) {
MacEager::expr(ecx.expr(sp, ExprKind::FormatArgs(P(format_args))))
} else {
MacEager::expr(DummyResult::raw_expr(sp, true))
Expand Down
54 changes: 45 additions & 9 deletions compiler/rustc_parse_format/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// We want to be able to build this crate with a stable compiler, so no
// `#![feature]` attributes should be added.

use rustc_lexer::unescape;
pub use Alignment::*;
pub use Count::*;
pub use Piece::*;
Expand Down Expand Up @@ -234,8 +235,10 @@ pub struct Parser<'a> {
last_opening_brace: Option<InnerSpan>,
/// Whether the source string is comes from `println!` as opposed to `format!` or `print!`
append_newline: bool,
/// Whether this formatting string is a literal or it comes from a macro.
pub is_literal: bool,
/// Whether this formatting string was written directly in the source. This controls whether we
/// can use spans to refer into it and give better error messages.
/// N.B: This does _not_ control whether implicit argument captures can be used.
pub is_source_literal: bool,
/// Start position of the current line.
cur_line_start: usize,
/// Start and end byte offset of every line of the format string. Excludes
Expand All @@ -262,7 +265,7 @@ impl<'a> Iterator for Parser<'a> {
} else {
let arg = self.argument(lbrace_end);
if let Some(rbrace_pos) = self.must_consume('}') {
if self.is_literal {
if self.is_source_literal {
let lbrace_byte_pos = self.to_span_index(pos);
let rbrace_byte_pos = self.to_span_index(rbrace_pos);

Expand Down Expand Up @@ -302,7 +305,7 @@ impl<'a> Iterator for Parser<'a> {
_ => Some(String(self.string(pos))),
}
} else {
if self.is_literal {
if self.is_source_literal {
let span = self.span(self.cur_line_start, self.input.len());
if self.line_spans.last() != Some(&span) {
self.line_spans.push(span);
Expand All @@ -322,8 +325,8 @@ impl<'a> Parser<'a> {
append_newline: bool,
mode: ParseMode,
) -> Parser<'a> {
let input_string_kind = find_width_map_from_snippet(snippet, style);
let (width_map, is_literal) = match input_string_kind {
let input_string_kind = find_width_map_from_snippet(s, snippet, style);
let (width_map, is_source_literal) = match input_string_kind {
InputStringKind::Literal { width_mappings } => (width_mappings, true),
InputStringKind::NotALiteral => (Vec::new(), false),
};
Expand All @@ -339,7 +342,7 @@ impl<'a> Parser<'a> {
width_map,
last_opening_brace: None,
append_newline,
is_literal,
is_source_literal,
cur_line_start: 0,
line_spans: vec![],
}
Expand Down Expand Up @@ -532,13 +535,13 @@ impl<'a> Parser<'a> {
'{' | '}' => {
return &self.input[start..pos];
}
'\n' if self.is_literal => {
'\n' if self.is_source_literal => {
self.line_spans.push(self.span(self.cur_line_start, pos));
self.cur_line_start = pos + 1;
self.cur.next();
}
_ => {
if self.is_literal && pos == self.cur_line_start && c.is_whitespace() {
if self.is_source_literal && pos == self.cur_line_start && c.is_whitespace() {
self.cur_line_start = pos + c.len_utf8();
}
self.cur.next();
Expand Down Expand Up @@ -890,6 +893,7 @@ impl<'a> Parser<'a> {
/// written code (code snippet) and the `InternedString` that gets processed in the `Parser`
/// in order to properly synthesise the intra-string `Span`s for error diagnostics.
fn find_width_map_from_snippet(
input: &str,
snippet: Option<string::String>,
str_style: Option<usize>,
) -> InputStringKind {
Expand All @@ -902,8 +906,27 @@ fn find_width_map_from_snippet(
return InputStringKind::Literal { width_mappings: Vec::new() };
}

// Strip quotes.
let snippet = &snippet[1..snippet.len() - 1];

// Macros like `println` add a newline at the end. That technically doens't make them "literals" anymore, but it's fine
// since we will never need to point our spans there, so we lie about it here by ignoring it.
// Since there might actually be newlines in the source code, we need to normalize away all trailing newlines.
// If we only trimmed it off the input, `format!("\n")` would cause a mismatch as here we they actually match up.
// Alternatively, we could just count the trailing newlines and only trim one from the input if they don't match up.
petrochenkov marked this conversation as resolved.
Show resolved Hide resolved
let input_no_nl = input.trim_end_matches('\n');
let Some(unescaped) = unescape_string(snippet) else {
return InputStringKind::NotALiteral;
};

let unescaped_no_nl = unescaped.trim_end_matches('\n');

if unescaped_no_nl != input_no_nl {
// The source string that we're pointing at isn't our input, so spans pointing at it will be incorrect.
// This can for example happen with proc macros that respan generated literals.
return InputStringKind::NotALiteral;
}

let mut s = snippet.char_indices();
let mut width_mappings = vec![];
while let Some((pos, c)) = s.next() {
Expand Down Expand Up @@ -986,6 +1009,19 @@ fn find_width_map_from_snippet(
InputStringKind::Literal { width_mappings }
}

fn unescape_string(string: &str) -> Option<string::String> {
let mut buf = string::String::new();
let mut ok = true;
unescape::unescape_literal(string, unescape::Mode::Str, &mut |_, unescaped_char| {
match unescaped_char {
Ok(c) => buf.push(c),
Err(_) => ok = false,
}
});

ok.then_some(buf)
}

// Assert a reasonable size for `Piece`
#[cfg(all(target_arch = "x86_64", target_pointer_width = "64"))]
rustc_data_structures::static_assert_size!(Piece<'_>, 16);
Expand Down
36 changes: 26 additions & 10 deletions tests/ui/fmt/auxiliary/format-string-proc-macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,41 @@ pub fn err_with_input_span(input: TokenStream) -> TokenStream {
TokenStream::from(TokenTree::Literal(lit))
}

fn build_format(args: impl Into<TokenStream>) -> TokenStream {
TokenStream::from_iter([
TokenTree::from(Ident::new("format", Span::call_site())),
TokenTree::from(Punct::new('!', Spacing::Alone)),
TokenTree::from(Group::new(Delimiter::Parenthesis, args.into())),
])
}

#[proc_macro]
pub fn respan_to_invalid_format_literal(input: TokenStream) -> TokenStream {
let mut s = Literal::string("{");
s.set_span(input.into_iter().next().unwrap().span());
TokenStream::from_iter([
TokenTree::from(Ident::new("format", Span::call_site())),
TokenTree::from(Punct::new('!', Spacing::Alone)),
TokenTree::from(Group::new(Delimiter::Parenthesis, TokenTree::from(s).into())),
])

build_format(TokenTree::from(s))
}

#[proc_macro]
pub fn capture_a_with_prepended_space_preserve_span(input: TokenStream) -> TokenStream {
let mut s = Literal::string(" {a}");
s.set_span(input.into_iter().next().unwrap().span());
TokenStream::from_iter([
TokenTree::from(Ident::new("format", Span::call_site())),
TokenTree::from(Punct::new('!', Spacing::Alone)),
TokenTree::from(Group::new(Delimiter::Parenthesis, TokenTree::from(s).into())),
])

build_format(TokenTree::from(s))
}

#[proc_macro]
pub fn format_args_captures(_: TokenStream) -> TokenStream {
r#"{ let x = 5; format!("{x}") }"#.parse().unwrap()
}

#[proc_macro]
pub fn bad_format_args_captures(_: TokenStream) -> TokenStream {
r#"{ let x = 5; format!(concat!("{x}")) }"#.parse().unwrap()
}

#[proc_macro]
pub fn identity_pm(input: TokenStream) -> TokenStream {
input
}
21 changes: 21 additions & 0 deletions tests/ui/fmt/format-args-capture-first-literal-is-macro.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// aux-build:format-string-proc-macro.rs

#[macro_use]
extern crate format_string_proc_macro;

macro_rules! identity_mbe {
($tt:tt) => {
$tt
//~^ ERROR there is no argument named `a`
};
}

fn main() {
let a = 0;

format!(identity_pm!("{a}"));
//~^ ERROR there is no argument named `a`
format!(identity_mbe!("{a}"));
format!(concat!("{a}"));
//~^ ERROR there is no argument named `a`
}
30 changes: 30 additions & 0 deletions tests/ui/fmt/format-args-capture-first-literal-is-macro.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
error: there is no argument named `a`
--> $DIR/format-args-capture-first-literal-is-macro.rs:16:26
|
LL | format!(identity_pm!("{a}"));
| ^^^^^
|
= note: did you intend to capture a variable `a` from the surrounding scope?
= note: to avoid ambiguity, `format_args!` cannot capture variables when the format string is expanded from a macro

error: there is no argument named `a`
--> $DIR/format-args-capture-first-literal-is-macro.rs:8:9
|
LL | $tt
| ^^^
|
= note: did you intend to capture a variable `a` from the surrounding scope?
= note: to avoid ambiguity, `format_args!` cannot capture variables when the format string is expanded from a macro

error: there is no argument named `a`
--> $DIR/format-args-capture-first-literal-is-macro.rs:19:13
|
LL | format!(concat!("{a}"));
| ^^^^^^^^^^^^^^
|
= note: did you intend to capture a variable `a` from the surrounding scope?
= note: to avoid ambiguity, `format_args!` cannot capture variables when the format string is expanded from a macro
= note: this error originates in the macro `concat` (in Nightly builds, run with -Z macro-backtrace for more info)

error: aborting due to 3 previous errors

Loading