Skip to content

Commit

Permalink
Refactor where predicates, and reserve for attributes support
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-king committed Nov 11, 2024
1 parent 328b759 commit 78668c5
Show file tree
Hide file tree
Showing 45 changed files with 478 additions and 372 deletions.
37 changes: 31 additions & 6 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use rustc_macros::{Decodable, Encodable, HashStable_Generic};
pub use rustc_span::AttrId;
use rustc_span::source_map::{Spanned, respan};
use rustc_span::symbol::{Ident, Symbol, kw, sym};
use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span};
use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span, SyntaxContext};
use thin_vec::{ThinVec, thin_vec};

pub use crate::format::*;
Expand Down Expand Up @@ -428,7 +428,32 @@ impl Default for WhereClause {

/// A single predicate in a where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicate {
pub struct WherePredicate {
pub kind: WherePredicateKind,
pub id: NodeId,
pub span: Span,
}

impl WherePredicate {
pub fn with_kind(&self, kind: WherePredicateKind) -> WherePredicate {
self.map_kind(None, |_| kind)
}
pub fn map_kind(
&self,
ctxt: Option<SyntaxContext>,
f: impl FnOnce(&WherePredicateKind) -> WherePredicateKind,
) -> WherePredicate {
WherePredicate {
kind: f(&self.kind),
id: DUMMY_NODE_ID,
span: ctxt.map_or(self.span, |ctxt| self.span.with_ctxt(ctxt)),
}
}
}

/// Predicate kind in where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicateKind {
/// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
BoundPredicate(WhereBoundPredicate),
/// A lifetime predicate (e.g., `'a: 'b + 'c`).
Expand All @@ -437,12 +462,12 @@ pub enum WherePredicate {
EqPredicate(WhereEqPredicate),
}

impl WherePredicate {
impl WherePredicateKind {
pub fn span(&self) -> Span {
match self {
WherePredicate::BoundPredicate(p) => p.span,
WherePredicate::RegionPredicate(p) => p.span,
WherePredicate::EqPredicate(p) => p.span,
WherePredicateKind::BoundPredicate(p) => p.span,
WherePredicateKind::RegionPredicate(p) => p.span,
WherePredicateKind::EqPredicate(p) => p.span,
}
}
}
Expand Down
34 changes: 26 additions & 8 deletions compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,15 @@ pub trait MutVisitor: Sized {
walk_where_clause(self, where_clause);
}

fn visit_where_predicate(&mut self, where_predicate: &mut WherePredicate) {
walk_where_predicate(self, where_predicate);
fn flat_map_where_predicate(
&mut self,
where_predicate: WherePredicate,
) -> SmallVec<[WherePredicate; 1]> {
walk_flat_map_where_predicate(self, where_predicate)
}

fn visit_where_predicate_kind(&mut self, kind: &mut WherePredicateKind) {
walk_where_predicate_kind(self, kind)
}

fn visit_vis(&mut self, vis: &mut Visibility) {
Expand Down Expand Up @@ -987,26 +994,37 @@ fn walk_ty_alias_where_clauses<T: MutVisitor>(vis: &mut T, tawcs: &mut TyAliasWh

fn walk_where_clause<T: MutVisitor>(vis: &mut T, wc: &mut WhereClause) {
let WhereClause { has_where_token: _, predicates, span } = wc;
visit_thin_vec(predicates, |predicate| vis.visit_where_predicate(predicate));
predicates.flat_map_in_place(|predicate| vis.flat_map_where_predicate(predicate));
vis.visit_span(span);
}

fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
match pred {
WherePredicate::BoundPredicate(bp) => {
pub fn walk_flat_map_where_predicate<T: MutVisitor>(
vis: &mut T,
mut pred: WherePredicate,
) -> SmallVec<[WherePredicate; 1]> {
let WherePredicate { ref mut kind, ref mut id, ref mut span } = pred;
vis.visit_id(id);
vis.visit_where_predicate_kind(kind);
vis.visit_span(span);
smallvec![pred]
}

pub fn walk_where_predicate_kind<T: MutVisitor>(vis: &mut T, kind: &mut WherePredicateKind) {
match kind {
WherePredicateKind::BoundPredicate(bp) => {
let WhereBoundPredicate { span, bound_generic_params, bounded_ty, bounds } = bp;
bound_generic_params.flat_map_in_place(|param| vis.flat_map_generic_param(param));
vis.visit_ty(bounded_ty);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::RegionPredicate(rp) => {
WherePredicateKind::RegionPredicate(rp) => {
let WhereRegionPredicate { span, lifetime, bounds } = rp;
vis.visit_lifetime(lifetime);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::EqPredicate(ep) => {
WherePredicateKind::EqPredicate(ep) => {
let WhereEqPredicate { span, lhs_ty, rhs_ty } = ep;
vis.visit_ty(lhs_ty);
vis.visit_ty(rhs_ty);
Expand Down
19 changes: 15 additions & 4 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ pub trait Visitor<'ast>: Sized {
fn visit_where_predicate(&mut self, p: &'ast WherePredicate) -> Self::Result {
walk_where_predicate(self, p)
}
fn visit_where_predicate_kind(&mut self, k: &'ast WherePredicateKind) -> Self::Result {
walk_where_predicate_kind(self, k)
}
fn visit_fn(&mut self, fk: FnKind<'ast>, _: Span, _: NodeId) -> Self::Result {
walk_fn(self, fk)
}
Expand Down Expand Up @@ -786,8 +789,16 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
visitor: &mut V,
predicate: &'a WherePredicate,
) -> V::Result {
match predicate {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let WherePredicate { kind, id: _, span: _ } = predicate;
visitor.visit_where_predicate_kind(kind)
}

pub fn walk_where_predicate_kind<'a, V: Visitor<'a>>(
visitor: &mut V,
kind: &'a WherePredicateKind,
) -> V::Result {
match kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bounded_ty,
bounds,
bound_generic_params,
Expand All @@ -797,11 +808,11 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
try_visit!(visitor.visit_ty(bounded_ty));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
try_visit!(visitor.visit_lifetime(lifetime, LifetimeCtxt::Bound));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
try_visit!(visitor.visit_ty(lhs_ty));
try_visit!(visitor.visit_ty(rhs_ty));
}
Expand Down
13 changes: 4 additions & 9 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,10 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
}

fn visit_where_predicate(&mut self, predicate: &'hir WherePredicate<'hir>) {
match predicate {
WherePredicate::BoundPredicate(pred) => {
self.insert(pred.span, pred.hir_id, Node::WhereBoundPredicate(pred));
self.with_parent(pred.hir_id, |this| {
intravisit::walk_where_predicate(this, predicate)
})
}
_ => intravisit::walk_where_predicate(self, predicate),
}
self.insert(predicate.span, predicate.hir_id, Node::WherePredicate(predicate));
self.with_parent(predicate.hir_id, |this| {
intravisit::walk_where_predicate(this, predicate)
});
}

fn visit_array_length(&mut self, len: &'hir ArrayLen<'hir>) {
Expand Down
59 changes: 32 additions & 27 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
// keep track of the Span info. Now, `<dyn HirTyLowerer>::add_implicit_sized_bound`
// checks both param bounds and where clauses for `?Sized`.
for pred in &generics.where_clause.predicates {
let WherePredicate::BoundPredicate(bound_pred) = pred else {
let WherePredicateKind::BoundPredicate(ref bound_pred) = pred.kind else {
continue;
};
let compute_is_param = || {
Expand Down Expand Up @@ -1538,8 +1538,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
});
let span = self.lower_span(span);

match kind {
GenericParamKind::Const { .. } => None,
let kind = match kind {
GenericParamKind::Const { .. } => return None,
GenericParamKind::Type { .. } => {
let def_id = self.local_def_id(id).to_def_id();
let hir_id = self.next_id();
Expand All @@ -1554,38 +1554,38 @@ impl<'hir> LoweringContext<'_, 'hir> {
let ty_id = self.next_id();
let bounded_ty =
self.ty_path(ty_id, param_span, hir::QPath::Resolved(None, ty_path));
Some(hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
bounded_ty: self.arena.alloc(bounded_ty),
bounds,
span,
bound_generic_params: &[],
origin,
}))
})
}
GenericParamKind::Lifetime => {
let ident = self.lower_ident(ident);
let lt_id = self.next_node_id();
let lifetime = self.new_named_lifetime(id, lt_id, ident);
Some(hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
lifetime,
span,
bounds,
in_where_clause: false,
}))
})
}
}
};
Some(hir::WherePredicate { hir_id: self.next_id(), kind: self.arena.alloc(kind), span })
}

fn lower_where_predicate(&mut self, pred: &WherePredicate) -> hir::WherePredicate<'hir> {
match pred {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let hir_id = self.lower_node_id(pred.id);
let kind = match &pred.kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bound_generic_params,
bounded_ty,
bounds,
span,
}) => hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
}) => hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
bound_generic_params: self
.lower_generic_params(bound_generic_params, hir::GenericParamSource::Binder),
bounded_ty: self
Expand All @@ -1597,26 +1597,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
span: self.lower_span(*span),
origin: PredicateOrigin::WhereClause,
}),
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span }) => {
hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
span: self.lower_span(*span),
lifetime: self.lower_lifetime(lifetime),
bounds: self.lower_param_bounds(
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
),
in_where_clause: true,
})
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
hir::WherePredicate::EqPredicate(hir::WhereEqPredicate {
WherePredicateKind::RegionPredicate(WhereRegionPredicate {
lifetime,
bounds,
span,
}) => hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
span: self.lower_span(*span),
lifetime: self.lower_lifetime(lifetime),
bounds: self.lower_param_bounds(
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
),
in_where_clause: true,
}),
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
hir::WherePredicateKind::EqPredicate(hir::WhereEqPredicate {
lhs_ty: self
.lower_ty(lhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
rhs_ty: self
.lower_ty(rhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
span: self.lower_span(*span),
})
}
}
};
let kind = self.arena.alloc(kind);
let span = self.lower_span(pred.span);
hir::WherePredicate { hir_id, kind, span }
}
}
20 changes: 10 additions & 10 deletions compiler/rustc_ast_passes/src/ast_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1201,14 +1201,14 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
validate_generic_param_order(self.dcx(), &generics.params, generics.span);

for predicate in &generics.where_clause.predicates {
if let WherePredicate::EqPredicate(predicate) = predicate {
if let WherePredicateKind::EqPredicate(ref predicate) = predicate.kind {
deny_equality_constraints(self, predicate, generics);
}
}
walk_list!(self, visit_generic_param, &generics.params);
for predicate in &generics.where_clause.predicates {
match predicate {
WherePredicate::BoundPredicate(bound_pred) => {
match predicate.kind {
WherePredicateKind::BoundPredicate(ref bound_pred) => {
// This is slightly complicated. Our representation for poly-trait-refs contains a single
// binder and thus we only allow a single level of quantification. However,
// the syntax of Rust permits quantification in two places in where clauses,
Expand Down Expand Up @@ -1592,18 +1592,18 @@ fn deny_equality_constraints(
let mut preds = generics.where_clause.predicates.iter().peekable();
// Find the predicate that shouldn't have been in the where bound list.
while let Some(pred) = preds.next() {
if let WherePredicate::EqPredicate(pred) = pred
if let WherePredicateKind::EqPredicate(ref pred) = pred.kind
&& pred.span == predicate.span
{
if let Some(next) = preds.peek() {
// This is the first predicate, remove the trailing comma as well.
span = span.with_hi(next.span().lo());
span = span.with_hi(next.kind.span().lo());
} else if let Some(prev) = prev {
// Remove the previous comma as well.
span = span.with_lo(prev.hi());
}
}
prev = Some(pred.span());
prev = Some(pred.kind.span());
}
span
};
Expand All @@ -1620,8 +1620,8 @@ fn deny_equality_constraints(
if let TyKind::Path(None, full_path) = &predicate.lhs_ty.kind {
// Given `A: Foo, Foo::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
for bounds in generics.params.iter().map(|p| &p.bounds).chain(
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
WherePredicate::BoundPredicate(p) => Some(&p.bounds),
generics.where_clause.predicates.iter().filter_map(|pred| match pred.kind {
WherePredicateKind::BoundPredicate(ref p) => Some(&p.bounds),
_ => None,
}),
) {
Expand All @@ -1644,8 +1644,8 @@ fn deny_equality_constraints(
// Given `A: Foo, A::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
if let [potential_param, potential_assoc] = &full_path.segments[..] {
for (ident, bounds) in generics.params.iter().map(|p| (p.ident, &p.bounds)).chain(
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
WherePredicate::BoundPredicate(p)
generics.where_clause.predicates.iter().filter_map(|pred| match pred.kind {
WherePredicateKind::BoundPredicate(ref p)
if let ast::TyKind::Path(None, path) = &p.bounded_ty.kind
&& let [segment] = &path.segments[..] =>
{
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ast_passes/src/feature_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ impl<'a> Visitor<'a> for PostExpansionVisitor<'a> {

fn visit_generics(&mut self, g: &'a ast::Generics) {
for predicate in &g.where_clause.predicates {
match predicate {
ast::WherePredicate::BoundPredicate(bound_pred) => {
match predicate.kind {
ast::WherePredicateKind::BoundPredicate(ref bound_pred) => {
// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
self.check_late_bound_lifetime_defs(&bound_pred.bound_generic_params);
}
Expand Down
Loading

0 comments on commit 78668c5

Please sign in to comment.