Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generator_interior: Count match pattern bindings as borrowed for the whole guard expression #97029

Merged
merged 7 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,20 @@ pub enum Guard<'hir> {
IfLet(&'hir Let<'hir>),
}

impl<'hir> Guard<'hir> {
/// Returns the body of the guard
///
/// In other words, returns the e in either of the following:
///
/// - `if e`
/// - `if let x = e`
pub fn body(&self) -> &'hir Expr<'hir> {
match self {
Guard::If(e) | Guard::IfLet(Let { init: e, .. }) => e,
}
}
}

#[derive(Debug, HashStable_Generic)]
pub struct ExprField<'hir> {
#[stable_hasher(ignore)]
Expand Down
134 changes: 44 additions & 90 deletions compiler/rustc_typeck/src/check/generator_interior.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use rustc_middle::middle::region::{self, Scope, ScopeData, YieldData};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::symbol::sym;
use rustc_span::Span;
use smallvec::SmallVec;
use tracing::debug;

mod drop_ranges;
Expand All @@ -29,13 +28,6 @@ struct InteriorVisitor<'a, 'tcx> {
expr_count: usize,
kind: hir::GeneratorKind,
prev_unresolved_span: Option<Span>,
/// Match arm guards have temporary borrows from the pattern bindings.
/// In case there is a yield point in a guard with a reference to such bindings,
/// such borrows can span across this yield point.
/// As such, we need to track these borrows and record them despite of the fact
/// that they may succeed the said yield point in the post-order.
guard_bindings: SmallVec<[SmallVec<[HirId; 4]>; 1]>,
guard_bindings_set: HirIdSet,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yay for simpler

linted_values: HirIdSet,
drop_ranges: DropRanges,
}
Expand All @@ -48,7 +40,6 @@ impl<'a, 'tcx> InteriorVisitor<'a, 'tcx> {
scope: Option<region::Scope>,
expr: Option<&'tcx Expr<'tcx>>,
source_span: Span,
guard_borrowing_from_pattern: bool,
) {
use rustc_span::DUMMY_SP;

Expand Down Expand Up @@ -89,8 +80,7 @@ impl<'a, 'tcx> InteriorVisitor<'a, 'tcx> {
// If it is a borrowing happening in the guard,
// it needs to be recorded regardless because they
// do live across this yield point.
guard_borrowing_from_pattern
|| yield_data.expr_and_pat_count >= self.expr_count
yield_data.expr_and_pat_count >= self.expr_count
})
.cloned()
})
Expand Down Expand Up @@ -196,8 +186,6 @@ pub fn resolve_interior<'a, 'tcx>(
expr_count: 0,
kind,
prev_unresolved_span: None,
guard_bindings: <_>::default(),
guard_bindings_set: <_>::default(),
linted_values: <_>::default(),
drop_ranges: drop_ranges::compute_drop_ranges(fcx, def_id, body),
};
Expand Down Expand Up @@ -284,15 +272,47 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
let Arm { guard, pat, body, .. } = arm;
self.visit_pat(pat);
if let Some(ref g) = guard {
self.guard_bindings.push(<_>::default());
ArmPatCollector {
guard_bindings_set: &mut self.guard_bindings_set,
guard_bindings: self
.guard_bindings
.last_mut()
.expect("should have pushed at least one earlier"),
{
// If there is a guard, we need to count all variables bound in the pattern as
// borrowed for the entire guard body, regardless of whether they are accessed.
// We do this by walking the pattern bindings and recording `&T` for any `x: T`
// that is bound.

struct ArmPatCollector<'a, 'b, 'tcx> {
interior_visitor: &'a mut InteriorVisitor<'b, 'tcx>,
scope: Scope,
}

impl<'a, 'b, 'tcx> Visitor<'tcx> for ArmPatCollector<'a, 'b, 'tcx> {
fn visit_pat(&mut self, pat: &'tcx Pat<'tcx>) {
intravisit::walk_pat(self, pat);
if let PatKind::Binding(_, id, ident, ..) = pat.kind {
let ty =
self.interior_visitor.fcx.typeck_results.borrow().node_type(id);
let tcx = self.interior_visitor.fcx.tcx;
let ty = tcx.mk_ref(
// Use `ReErased` as `resolve_interior` is going to replace all the
// regions anyway.
tcx.mk_region(ty::ReErased),
ty::TypeAndMut { ty, mutbl: hir::Mutability::Not },
);
self.interior_visitor.record(
ty,
id,
Some(self.scope),
None,
ident.span,
);
}
}
}

ArmPatCollector {
interior_visitor: self,
scope: Scope { id: g.body().hir_id.local_id, data: ScopeData::Node },
}
.visit_pat(pat);
}
.visit_pat(pat);

match g {
Guard::If(ref e) => {
Expand All @@ -302,12 +322,6 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
self.visit_let_expr(l);
}
}

let mut scope_var_ids =
self.guard_bindings.pop().expect("should have pushed at least one earlier");
for var_id in scope_var_ids.drain(..) {
self.guard_bindings_set.remove(&var_id);
}
}
self.visit_expr(body);
}
Expand All @@ -320,13 +334,11 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
if let PatKind::Binding(..) = pat.kind {
let scope = self.region_scope_tree.var_scope(pat.hir_id.local_id).unwrap();
let ty = self.fcx.typeck_results.borrow().pat_ty(pat);
self.record(ty, pat.hir_id, Some(scope), None, pat.span, false);
self.record(ty, pat.hir_id, Some(scope), None, pat.span);
}
}

fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
let mut guard_borrowing_from_pattern = false;

match &expr.kind {
ExprKind::Call(callee, args) => match &callee.kind {
ExprKind::Path(qpath) => {
Expand All @@ -353,16 +365,6 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
}
_ => intravisit::walk_expr(self, expr),
},
ExprKind::Path(qpath) => {
intravisit::walk_expr(self, expr);
let res = self.fcx.typeck_results.borrow().qpath_res(qpath, expr.hir_id);
match res {
Res::Local(id) if self.guard_bindings_set.contains(&id) => {
guard_borrowing_from_pattern = true;
}
_ => {}
}
}
_ => intravisit::walk_expr(self, expr),
}

Expand Down Expand Up @@ -391,14 +393,7 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
// If there are adjustments, then record the final type --
// this is the actual value that is being produced.
if let Some(adjusted_ty) = self.fcx.typeck_results.borrow().expr_ty_adjusted_opt(expr) {
self.record(
adjusted_ty,
expr.hir_id,
scope,
Some(expr),
expr.span,
guard_borrowing_from_pattern,
);
self.record(adjusted_ty, expr.hir_id, scope, Some(expr), expr.span);
}

// Also record the unadjusted type (which is the only type if
Expand Down Expand Up @@ -426,54 +421,13 @@ impl<'a, 'tcx> Visitor<'tcx> for InteriorVisitor<'a, 'tcx> {
// The type table might not have information for this expression
// if it is in a malformed scope. (#66387)
if let Some(ty) = self.fcx.typeck_results.borrow().expr_ty_opt(expr) {
if guard_borrowing_from_pattern {
// Match guards create references to all the bindings in the pattern that are used
// in the guard, e.g. `y if is_even(y) => ...` becomes `is_even(*r_y)` where `r_y`
// is a reference to `y`, so we must record a reference to the type of the binding.
let tcx = self.fcx.tcx;
let ref_ty = tcx.mk_ref(
// Use `ReErased` as `resolve_interior` is going to replace all the regions anyway.
tcx.mk_region(ty::ReErased),
ty::TypeAndMut { ty, mutbl: hir::Mutability::Not },
);
self.record(
ref_ty,
expr.hir_id,
scope,
Some(expr),
expr.span,
guard_borrowing_from_pattern,
);
}
self.record(
ty,
expr.hir_id,
scope,
Some(expr),
expr.span,
guard_borrowing_from_pattern,
);
self.record(ty, expr.hir_id, scope, Some(expr), expr.span);
} else {
self.fcx.tcx.sess.delay_span_bug(expr.span, "no type for node");
}
}
}

struct ArmPatCollector<'a> {
guard_bindings_set: &'a mut HirIdSet,
guard_bindings: &'a mut SmallVec<[HirId; 4]>,
}

impl<'a, 'tcx> Visitor<'tcx> for ArmPatCollector<'a> {
fn visit_pat(&mut self, pat: &'tcx Pat<'tcx>) {
intravisit::walk_pat(self, pat);
if let PatKind::Binding(_, id, ..) = pat.kind {
self.guard_bindings.push(id);
self.guard_bindings_set.insert(id);
}
}
}

#[derive(Default)]
pub struct SuspendCheckData<'a, 'tcx> {
expr: Option<&'tcx Expr<'tcx>>,
Expand Down
21 changes: 17 additions & 4 deletions compiler/rustc_typeck/src/expr_use_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use rustc_middle::hir::place::ProjectionKind;
use rustc_middle::mir::FakeReadCause;
use rustc_middle::ty::{self, adjustment, AdtKind, Ty, TyCtxt};
use rustc_target::abi::VariantIdx;
use ty::BorrowKind::ImmBorrow;

use crate::mem_categorization as mc;

Expand Down Expand Up @@ -621,7 +622,7 @@ impl<'a, 'tcx> ExprUseVisitor<'a, 'tcx> {
FakeReadCause::ForMatchedPlace(closure_def_id),
discr_place.hir_id,
);
self.walk_pat(discr_place, arm.pat);
self.walk_pat(discr_place, arm.pat, arm.guard.is_some());

if let Some(hir::Guard::If(e)) = arm.guard {
self.consume_expr(e)
Expand All @@ -645,12 +646,17 @@ impl<'a, 'tcx> ExprUseVisitor<'a, 'tcx> {
FakeReadCause::ForLet(closure_def_id),
discr_place.hir_id,
);
self.walk_pat(discr_place, pat);
self.walk_pat(discr_place, pat, false);
}

/// The core driver for walking a pattern
fn walk_pat(&mut self, discr_place: &PlaceWithHirId<'tcx>, pat: &hir::Pat<'_>) {
debug!("walk_pat(discr_place={:?}, pat={:?})", discr_place, pat);
fn walk_pat(
&mut self,
discr_place: &PlaceWithHirId<'tcx>,
pat: &hir::Pat<'_>,
has_guard: bool,
) {
debug!("walk_pat(discr_place={:?}, pat={:?}, has_guard={:?})", discr_place, pat, has_guard);

let tcx = self.tcx();
let ExprUseVisitor { ref mc, body_owner: _, ref mut delegate } = *self;
Expand All @@ -671,6 +677,13 @@ impl<'a, 'tcx> ExprUseVisitor<'a, 'tcx> {
delegate.bind(binding_place, binding_place.hir_id);
}

// Subtle: MIR desugaring introduces immutable borrows for each pattern
// binding when lowering pattern guards to ensure that the guard does not
// modify the scrutinee.
if has_guard {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests pass with or without this change, so arguably we should remove it. That said, this feels like the more correct behavior, so I think we should leave it in.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. Yes, it seems correct, but it's interesting that the tests pass either way! There is a certain amount of redundancy, it seems, between the analysis of what is borrowed etc and the "what types appear in the generator".

I think perhaps we could observe the difference with this test, but maybe not?

I think this will compile successfully:

#![feature(generators)]

fn is_send<T: Send>(_: T) { }

fn main() {
    is_send(async move {
        let x = std::rc::Rc::new(22);
        //match x { ref y if foo().await => () }
        drop(x);
        foo().await;
    });
}

async fn foo() -> bool {
    true 
}

but if you remove the comment from //match, it will not. I'm curious what happens if you remove this line here?

I guess that to really eliminate this redundancy we would drive the "do we need to include &T" check based on some set generated by this code (e.g., "is lvalue borrowed" or whatever).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test you provided does compile successfully. When I uncomment the match line, it gives a compiler error saying that the async block is not Send due to y and x being borrowed across the await point.

The behavior is the same with and without the if has_guard { ... } section in expr_use_visitor.rs.

delegate.borrow(place, discr_place.hir_id, ImmBorrow);
}

// It is also a borrow or copy/move of the value being matched.
// In a cases of pattern like `let pat = upvar`, don't use the span
// of the pattern, as this just looks confusing, instead use the span
Expand Down
12 changes: 12 additions & 0 deletions src/test/ui/generator/drop-tracking-yielding-in-match-guards.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// build-pass
// edition:2018
// compile-flags: -Zdrop-tracking

#![feature(generators)]

fn main() {
let _ = static |x: u8| match x {
y if { yield } == y + 1 => (),
_ => (),
};
}