Skip to content

Commit

Permalink
Refactor where predicates, and reserve for attributes support
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-king committed Nov 17, 2024
1 parent 328b759 commit d71d19e
Show file tree
Hide file tree
Showing 44 changed files with 395 additions and 409 deletions.
23 changes: 9 additions & 14 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,15 @@ impl Default for WhereClause {

/// A single predicate in a where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicate {
pub struct WherePredicate {
pub kind: WherePredicateKind,
pub id: NodeId,
pub span: Span,
}

/// Predicate kind in where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicateKind {
/// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
BoundPredicate(WhereBoundPredicate),
/// A lifetime predicate (e.g., `'a: 'b + 'c`).
Expand All @@ -437,22 +445,11 @@ pub enum WherePredicate {
EqPredicate(WhereEqPredicate),
}

impl WherePredicate {
pub fn span(&self) -> Span {
match self {
WherePredicate::BoundPredicate(p) => p.span,
WherePredicate::RegionPredicate(p) => p.span,
WherePredicate::EqPredicate(p) => p.span,
}
}
}

/// A type bound.
///
/// E.g., `for<'c> Foo: Send + Clone + 'c`.
#[derive(Clone, Encodable, Decodable, Debug)]
pub struct WhereBoundPredicate {
pub span: Span,
/// Any generics from a `for` binding.
pub bound_generic_params: ThinVec<GenericParam>,
/// The type being bounded.
Expand All @@ -466,7 +463,6 @@ pub struct WhereBoundPredicate {
/// E.g., `'a: 'b + 'c`.
#[derive(Clone, Encodable, Decodable, Debug)]
pub struct WhereRegionPredicate {
pub span: Span,
pub lifetime: Lifetime,
pub bounds: GenericBounds,
}
Expand All @@ -476,7 +472,6 @@ pub struct WhereRegionPredicate {
/// E.g., `T = int`.
#[derive(Clone, Encodable, Decodable, Debug)]
pub struct WhereEqPredicate {
pub span: Span,
pub lhs_ty: P<Ty>,
pub rhs_ty: P<Ty>,
}
Expand Down
32 changes: 20 additions & 12 deletions compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,11 @@ pub trait MutVisitor: Sized {
}

fn visit_where_predicate(&mut self, where_predicate: &mut WherePredicate) {
walk_where_predicate(self, where_predicate);
walk_where_predicate(self, where_predicate)
}

fn visit_where_predicate_kind(&mut self, kind: &mut WherePredicateKind) {
walk_where_predicate_kind(self, kind)
}

fn visit_vis(&mut self, vis: &mut Visibility) {
Expand Down Expand Up @@ -991,26 +995,30 @@ fn walk_where_clause<T: MutVisitor>(vis: &mut T, wc: &mut WhereClause) {
vis.visit_span(span);
}

fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
match pred {
WherePredicate::BoundPredicate(bp) => {
let WhereBoundPredicate { span, bound_generic_params, bounded_ty, bounds } = bp;
pub fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
let WherePredicate { kind, id, span } = pred;
vis.visit_id(id);
vis.visit_where_predicate_kind(kind);
vis.visit_span(span);
}

pub fn walk_where_predicate_kind<T: MutVisitor>(vis: &mut T, kind: &mut WherePredicateKind) {
match kind {
WherePredicateKind::BoundPredicate(bp) => {
let WhereBoundPredicate { bound_generic_params, bounded_ty, bounds } = bp;
bound_generic_params.flat_map_in_place(|param| vis.flat_map_generic_param(param));
vis.visit_ty(bounded_ty);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::RegionPredicate(rp) => {
let WhereRegionPredicate { span, lifetime, bounds } = rp;
WherePredicateKind::RegionPredicate(rp) => {
let WhereRegionPredicate { lifetime, bounds } = rp;
vis.visit_lifetime(lifetime);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::EqPredicate(ep) => {
let WhereEqPredicate { span, lhs_ty, rhs_ty } = ep;
WherePredicateKind::EqPredicate(ep) => {
let WhereEqPredicate { lhs_ty, rhs_ty } = ep;
vis.visit_ty(lhs_ty);
vis.visit_ty(rhs_ty);
vis.visit_span(span);
}
}
}
Expand Down
20 changes: 15 additions & 5 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ pub trait Visitor<'ast>: Sized {
fn visit_where_predicate(&mut self, p: &'ast WherePredicate) -> Self::Result {
walk_where_predicate(self, p)
}
fn visit_where_predicate_kind(&mut self, k: &'ast WherePredicateKind) -> Self::Result {
walk_where_predicate_kind(self, k)
}
fn visit_fn(&mut self, fk: FnKind<'ast>, _: Span, _: NodeId) -> Self::Result {
walk_fn(self, fk)
}
Expand Down Expand Up @@ -786,22 +789,29 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
visitor: &mut V,
predicate: &'a WherePredicate,
) -> V::Result {
match predicate {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let WherePredicate { kind, id: _, span: _ } = predicate;
visitor.visit_where_predicate_kind(kind)
}

pub fn walk_where_predicate_kind<'a, V: Visitor<'a>>(
visitor: &mut V,
kind: &'a WherePredicateKind,
) -> V::Result {
match kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bounded_ty,
bounds,
bound_generic_params,
span: _,
}) => {
walk_list!(visitor, visit_generic_param, bound_generic_params);
try_visit!(visitor.visit_ty(bounded_ty));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds }) => {
try_visit!(visitor.visit_lifetime(lifetime, LifetimeCtxt::Bound));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty }) => {
try_visit!(visitor.visit_ty(lhs_ty));
try_visit!(visitor.visit_ty(rhs_ty));
}
Expand Down
13 changes: 4 additions & 9 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,10 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
}

fn visit_where_predicate(&mut self, predicate: &'hir WherePredicate<'hir>) {
match predicate {
WherePredicate::BoundPredicate(pred) => {
self.insert(pred.span, pred.hir_id, Node::WhereBoundPredicate(pred));
self.with_parent(pred.hir_id, |this| {
intravisit::walk_where_predicate(this, predicate)
})
}
_ => intravisit::walk_where_predicate(self, predicate),
}
self.insert(predicate.span, predicate.hir_id, Node::WherePredicate(predicate));
self.with_parent(predicate.hir_id, |this| {
intravisit::walk_where_predicate(this, predicate)
});
}

fn visit_array_length(&mut self, len: &'hir ArrayLen<'hir>) {
Expand Down
46 changes: 21 additions & 25 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
// keep track of the Span info. Now, `<dyn HirTyLowerer>::add_implicit_sized_bound`
// checks both param bounds and where clauses for `?Sized`.
for pred in &generics.where_clause.predicates {
let WherePredicate::BoundPredicate(bound_pred) = pred else {
let WherePredicateKind::BoundPredicate(bound_pred) = &pred.kind else {
continue;
};
let compute_is_param = || {
Expand Down Expand Up @@ -1537,9 +1537,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
});
let span = self.lower_span(span);

match kind {
GenericParamKind::Const { .. } => None,
let hir_id = self.next_id();
let kind = self.arena.alloc(match kind {
GenericParamKind::Const { .. } => return None,
GenericParamKind::Type { .. } => {
let def_id = self.local_def_id(id).to_def_id();
let hir_id = self.next_id();
Expand All @@ -1554,38 +1554,36 @@ impl<'hir> LoweringContext<'_, 'hir> {
let ty_id = self.next_id();
let bounded_ty =
self.ty_path(ty_id, param_span, hir::QPath::Resolved(None, ty_path));
Some(hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
bounded_ty: self.arena.alloc(bounded_ty),
bounds,
span,
bound_generic_params: &[],
origin,
}))
})
}
GenericParamKind::Lifetime => {
let ident = self.lower_ident(ident);
let lt_id = self.next_node_id();
let lifetime = self.new_named_lifetime(id, lt_id, ident);
Some(hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
lifetime,
span,
bounds,
in_where_clause: false,
}))
})
}
}
});
Some(hir::WherePredicate { hir_id, span, kind })
}

fn lower_where_predicate(&mut self, pred: &WherePredicate) -> hir::WherePredicate<'hir> {
match pred {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let hir_id = self.lower_node_id(pred.id);
let span = self.lower_span(pred.span);
let kind = self.arena.alloc(match &pred.kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bound_generic_params,
bounded_ty,
bounds,
span,
}) => hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
}) => hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
bound_generic_params: self
.lower_generic_params(bound_generic_params, hir::GenericParamSource::Binder),
bounded_ty: self
Expand All @@ -1594,12 +1592,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
),
span: self.lower_span(*span),
origin: PredicateOrigin::WhereClause,
}),
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span }) => {
hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
span: self.lower_span(*span),
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds }) => {
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
lifetime: self.lower_lifetime(lifetime),
bounds: self.lower_param_bounds(
bounds,
Expand All @@ -1608,15 +1604,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
in_where_clause: true,
})
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
hir::WherePredicate::EqPredicate(hir::WhereEqPredicate {
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty }) => {
hir::WherePredicateKind::EqPredicate(hir::WhereEqPredicate {
lhs_ty: self
.lower_ty(lhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
rhs_ty: self
.lower_ty(rhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
span: self.lower_span(*span),
})
}
}
});
hir::WherePredicate { hir_id, span, kind }
}
}
32 changes: 17 additions & 15 deletions compiler/rustc_ast_passes/src/ast_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1201,14 +1201,15 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
validate_generic_param_order(self.dcx(), &generics.params, generics.span);

for predicate in &generics.where_clause.predicates {
if let WherePredicate::EqPredicate(predicate) = predicate {
deny_equality_constraints(self, predicate, generics);
let span = predicate.span;
if let WherePredicateKind::EqPredicate(predicate) = &predicate.kind {
deny_equality_constraints(self, predicate, span, generics);
}
}
walk_list!(self, visit_generic_param, &generics.params);
for predicate in &generics.where_clause.predicates {
match predicate {
WherePredicate::BoundPredicate(bound_pred) => {
match &predicate.kind {
WherePredicateKind::BoundPredicate(bound_pred) => {
// This is slightly complicated. Our representation for poly-trait-refs contains a single
// binder and thus we only allow a single level of quantification. However,
// the syntax of Rust permits quantification in two places in where clauses,
Expand Down Expand Up @@ -1511,9 +1512,10 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
fn deny_equality_constraints(
this: &AstValidator<'_>,
predicate: &WhereEqPredicate,
predicate_span: Span,
generics: &Generics,
) {
let mut err = errors::EqualityInWhere { span: predicate.span, assoc: None, assoc2: None };
let mut err = errors::EqualityInWhere { span: predicate_span, assoc: None, assoc2: None };

// Given `<A as Foo>::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
if let TyKind::Path(Some(qself), full_path) = &predicate.lhs_ty.kind
Expand Down Expand Up @@ -1557,7 +1559,7 @@ fn deny_equality_constraints(
}
}
err.assoc = Some(errors::AssociatedSuggestion {
span: predicate.span,
span: predicate_span,
ident: *ident,
param: param.ident,
path: pprust::path_to_string(&assoc_path),
Expand Down Expand Up @@ -1587,23 +1589,23 @@ fn deny_equality_constraints(
// We're removing th eonly where bound left, remove the whole thing.
generics.where_clause.span
} else {
let mut span = predicate.span;
let mut span = predicate_span;
let mut prev: Option<Span> = None;
let mut preds = generics.where_clause.predicates.iter().peekable();
// Find the predicate that shouldn't have been in the where bound list.
while let Some(pred) = preds.next() {
if let WherePredicate::EqPredicate(pred) = pred
&& pred.span == predicate.span
if let WherePredicateKind::EqPredicate(_) = pred.kind
&& pred.span == predicate_span
{
if let Some(next) = preds.peek() {
// This is the first predicate, remove the trailing comma as well.
span = span.with_hi(next.span().lo());
span = span.with_hi(next.span.lo());
} else if let Some(prev) = prev {
// Remove the previous comma as well.
span = span.with_lo(prev.hi());
}
}
prev = Some(pred.span());
prev = Some(pred.span);
}
span
};
Expand All @@ -1620,8 +1622,8 @@ fn deny_equality_constraints(
if let TyKind::Path(None, full_path) = &predicate.lhs_ty.kind {
// Given `A: Foo, Foo::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
for bounds in generics.params.iter().map(|p| &p.bounds).chain(
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
WherePredicate::BoundPredicate(p) => Some(&p.bounds),
generics.where_clause.predicates.iter().filter_map(|pred| match &pred.kind {
WherePredicateKind::BoundPredicate(p) => Some(&p.bounds),
_ => None,
}),
) {
Expand All @@ -1644,8 +1646,8 @@ fn deny_equality_constraints(
// Given `A: Foo, A::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
if let [potential_param, potential_assoc] = &full_path.segments[..] {
for (ident, bounds) in generics.params.iter().map(|p| (p.ident, &p.bounds)).chain(
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
WherePredicate::BoundPredicate(p)
generics.where_clause.predicates.iter().filter_map(|pred| match &pred.kind {
WherePredicateKind::BoundPredicate(p)
if let ast::TyKind::Path(None, path) = &p.bounded_ty.kind
&& let [segment] = &path.segments[..] =>
{
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ast_passes/src/feature_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ impl<'a> Visitor<'a> for PostExpansionVisitor<'a> {

fn visit_generics(&mut self, g: &'a ast::Generics) {
for predicate in &g.where_clause.predicates {
match predicate {
ast::WherePredicate::BoundPredicate(bound_pred) => {
match &predicate.kind {
ast::WherePredicateKind::BoundPredicate(bound_pred) => {
// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
self.check_late_bound_lifetime_defs(&bound_pred.bound_generic_params);
}
Expand Down
Loading

0 comments on commit d71d19e

Please sign in to comment.