Skip to content

Commit

Permalink
Remove unnecessary Visitor usage
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Nak committed Oct 9, 2023
1 parent f49a638 commit 6da8ccf
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 109 deletions.
144 changes: 57 additions & 87 deletions crates/hir-analysis/src/ty/constraint.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
use std::collections::{BTreeMap, BTreeSet};

use hir::{
hir_def::{
self, scope_graph::ScopeId, GenericParamOwner, IngotId, TraitRefId, TypeBound,
WherePredicate,
},
visitor::prelude::*,
};
use std::collections::BTreeSet;

use hir::hir_def::{scope_graph::ScopeId, GenericParam, GenericParamOwner, IngotId, TypeBound};
use rustc_hash::FxHashMap;
use salsa::function::Configuration;

Expand All @@ -19,7 +13,7 @@ use super::{
constraint_solver::{is_goal_satisfiable, GoalSatisfiability},
trait_::{Implementor, TraitDef, TraitInstId},
trait_lower::lower_trait_ref,
ty_def::{AdtDef, AdtRef, InvalidCause, Subst, TyConcrete, TyData, TyId},
ty_def::{AdtDef, Subst, TyConcrete, TyData, TyId},
ty_lower::{collect_generic_params, lower_hir_ty, GenericParamOwnerId},
};

Expand Down Expand Up @@ -153,34 +147,19 @@ pub(crate) fn collect_trait_constraints(
trait_: TraitDef,
) -> ConstraintListId {
let hir_trait = trait_.trait_(db);
let mut collector =
ConstraintCollector::new(db, GenericParamOwnerId::new(db, hir_trait.into()));
let collector = ConstraintCollector::new(db, GenericParamOwnerId::new(db, hir_trait.into()));

let mut ctxt = VisitorCtxt::with_trait(db.as_hir_db(), hir_trait);
collector.visit_trait(&mut ctxt, hir_trait);

collector.finalize()
collector.collect()
}

#[salsa::tracked]
pub(crate) fn collect_adt_constraints(db: &dyn HirAnalysisDb, adt: AdtDef) -> ConstraintListId {
let Some(owner) = adt.as_generic_param_owner(db) else {
return ConstraintListId::empty_list(db);
};
let mut collector = ConstraintCollector::new(db, owner);
match adt.adt_ref(db).data(db) {
AdtRef::Contract(_) => return ConstraintListId::empty_list(db),
AdtRef::Enum(enum_) => {
let mut ctxt = VisitorCtxt::with_enum(db.as_hir_db(), enum_);
collector.visit_enum(&mut ctxt, enum_);
}
AdtRef::Struct(struct_) => {
let mut ctxt = VisitorCtxt::with_struct(db.as_hir_db(), struct_);
collector.visit_struct(&mut ctxt, struct_);
}
}
let collector = ConstraintCollector::new(db, owner);

collector.finalize()
collector.collect()
}

#[salsa::tracked]
Expand All @@ -189,12 +168,9 @@ pub(crate) fn collect_implementor_constraints(
implementor: Implementor,
) -> ConstraintListId {
let impl_trait = implementor.impl_trait(db);
let mut collector =
ConstraintCollector::new(db, GenericParamOwnerId::new(db, impl_trait.into()));
let mut ctxt = VisitorCtxt::with_impl_trait(db.as_hir_db(), impl_trait);
collector.visit_impl_trait(&mut ctxt, impl_trait);
let collector = ConstraintCollector::new(db, GenericParamOwnerId::new(db, impl_trait.into()));

collector.finalize()
collector.collect()
}

/// Returns a list of assumptions obtained by the given assumptions by looking
Expand Down Expand Up @@ -349,13 +325,8 @@ impl SuperTraitCycle {

struct ConstraintCollector<'db> {
db: &'db dyn HirAnalysisDb,

owner: GenericParamOwnerId,

predicates: BTreeSet<PredicateId>,
predicate_span_map: BTreeMap<PredicateId, DynLazySpan>,

current_ty: TyId,
}

impl<'db> ConstraintCollector<'db> {
Expand All @@ -365,13 +336,12 @@ impl<'db> ConstraintCollector<'db> {
owner,

predicates: BTreeSet::new(),
predicate_span_map: BTreeMap::new(),

current_ty: TyId::invalid(db, InvalidCause::Other),
}
}

fn finalize(mut self) -> ConstraintListId {
fn collect(mut self) -> ConstraintListId {
self.collect_constraints_from_generic_params();
self.collect_constraints_from_where_clause();
self.simplify()
}

Expand All @@ -384,7 +354,7 @@ impl<'db> ConstraintCollector<'db> {
let trait_def = lower_trait(self.db, trait_);
let self_param = trait_def.self_param(self.db);
for &inst in trait_def.super_traits(self.db).iter() {
self.push_predicate(self_param, inst, DynLazySpan::invalid());
self.push_predicate(self_param, inst)
}
}

Expand Down Expand Up @@ -413,63 +383,63 @@ impl<'db> ConstraintCollector<'db> {
}
}

fn push_predicate(&mut self, ty: TyId, trait_inst: TraitInstId, span: DynLazySpan) {
fn push_predicate(&mut self, ty: TyId, trait_inst: TraitInstId) {
let pred = PredicateId::new(self.db, ty, trait_inst);
self.predicates.insert(pred);
self.predicate_span_map.insert(pred, span);
}
}

impl<'db> Visitor for ConstraintCollector<'db> {
fn visit_where_predicate(
&mut self,
ctxt: &mut VisitorCtxt<'_, LazyWherePredicateSpan>,
pred: &WherePredicate,
) {
let Some(hir_ty) = pred.ty.to_opt() else {
fn collect_constraints_from_where_clause(&mut self) {
let Some(where_clause) = self.owner.where_clause(self.db) else {
return;
};

let ty = lower_hir_ty(self.db, hir_ty, self.owner.scope(self.db));
for hir_pred in where_clause.data(self.db.as_hir_db()) {
let Some(hir_ty) = hir_pred.ty.to_opt() else {
continue;
};

// We don't need to collect super traits, please refer to
// `collect_super_traits` // function for details.
if ty.is_invalid(self.db) || ty.is_trait_self(self.db) {
return;
}
let ty = lower_hir_ty(self.db, hir_ty, self.owner.scope(self.db));

self.current_ty = ty;
walk_where_predicate(self, ctxt, pred);
}

fn visit_trait_ref(
&mut self,
ctxt: &mut VisitorCtxt<'_, LazyTraitRefSpan>,
trait_ref: TraitRefId,
) {
let trait_inst = lower_trait_ref(self.db, trait_ref, self.owner.scope(self.db));
// We don't need to collect super traits, please refer to
// `collect_super_traits` // function for details.
if ty.is_invalid(self.db) || ty.is_trait_self(self.db) {
return;
}

let Ok(trait_ref) = trait_inst else {
return;
};
self.add_bounds(ty, &hir_pred.bounds);
}
}

self.push_predicate(self.current_ty, trait_ref, ctxt.span().unwrap().into());
fn collect_constraints_from_generic_params(&mut self) {
let param_set = collect_generic_params(self.db, self.owner);
let params_list = self.owner.params(self.db);
assert!(param_set.params.len() == params_list.len(self.db.as_hir_db()));
for (&ty, hir_param) in param_set
.params
.iter()
.zip(params_list.data(self.db.as_hir_db()))
{
let GenericParam::Type(hir_param) = hir_param else {
continue;
};

let bounds = &hir_param.bounds;
self.add_bounds(ty, bounds)
}
}

fn visit_generic_param(
&mut self,
ctxt: &mut VisitorCtxt<'_, LazyGenericParamSpan>,
param: &hir_def::GenericParam,
) {
let ScopeId::GenericParam(_, param_idx) = ctxt.scope() else {
unreachable!()
};
self.current_ty = collect_generic_params(self.db, self.owner).params[param_idx];
fn add_bounds(&mut self, bound_ty: TyId, bounds: &[TypeBound]) {
for bound in bounds {
let TypeBound::Trait(trait_ref) = bound else {
continue;
};

walk_generic_param(self, ctxt, param);
}
let Ok(trait_inst) = lower_trait_ref(self.db, *trait_ref, self.owner.scope(self.db))
else {
continue;
};

fn visit_item(&mut self, _: &mut VisitorCtxt<'_, LazyItemSpan>, _: hir::hir_def::ItemKind) {
// We don't want to visit nested items.
self.push_predicate(bound_ty, trait_inst);
}
}
}
16 changes: 13 additions & 3 deletions crates/hir-analysis/src/ty/ty_lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::collections::BTreeSet;
use either::Either;
use hir::hir_def::{
kw, scope_graph::ScopeId, FieldDefListId, GenericArg, GenericArgListId, GenericParam,
GenericParamOwner, IdentId, IngotId, ItemKind, KindBound as HirKindBound, Partial, PathId,
TupleTypeId, TypeAlias as HirTypeAlias, TypeBound, TypeId as HirTyId, TypeKind as HirTyKind,
VariantDefListId, VariantKind,
GenericParamListId, GenericParamOwner, IdentId, IngotId, ItemKind, KindBound as HirKindBound,
Partial, PathId, TupleTypeId, TypeAlias as HirTypeAlias, TypeBound, TypeId as HirTyId,
TypeKind as HirTyKind, VariantDefListId, VariantKind, WhereClauseId,
};
use rustc_hash::FxHashMap;
use salsa::function::Configuration;
Expand Down Expand Up @@ -654,4 +654,14 @@ impl GenericParamOwnerId {
pub(super) fn ingot(self, db: &dyn HirAnalysisDb) -> IngotId {
self.data(db).top_mod(db.as_hir_db()).ingot(db.as_hir_db())
}

pub(super) fn where_clause(self, db: &dyn HirAnalysisDb) -> Option<WhereClauseId> {
self.data(db)
.where_clause_owner()
.map(|owner| owner.where_clause(db.as_hir_db()))
}

pub(super) fn params(self, db: &dyn HirAnalysisDb) -> GenericParamListId {
self.data(db).params(db.as_hir_db())
}
}
32 changes: 13 additions & 19 deletions crates/uitest/tests/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ fn run_ty_def(fixture: Fixture<&str>) {
snap_test!(diags, fixture.path());
}

#[dir_test(
dir: "$CARGO_MANIFEST_DIR/fixtures/ty/trait_bound",
glob: "*.fe"
)]
fn run_trait_bound(fixture: Fixture<&str>) {
let mut driver = DriverDataBase::default();
let path = Path::new(fixture.path());
let top_mod = driver.top_mod_from_file(path, fixture.content());
driver.run_on_top_mod(top_mod);
let diags = driver.format_diags();
snap_test!(diags, fixture.path());
}

#[cfg(target_family = "wasm")]
mod wasm {
use super::*;
Expand All @@ -36,25 +49,6 @@ mod wasm {
let top_mod = driver.top_mod_from_file(path, fixture.content());
driver.run_on_top_mod(top_mod);
}
}

#[dir_test(
dir: "$CARGO_MANIFEST_DIR/fixtures/ty/trait_bound",
glob: "*.fe"
)]
fn run_trait_bound(fixture: Fixture<&str>) {
let mut driver = DriverDataBase::default();
let path = Path::new(fixture.path());
let top_mod = driver.top_mod_from_file(path, fixture.content());
driver.run_on_top_mod(top_mod);
let diags = driver.format_diags();
snap_test!(diags, fixture.path());
}

#[cfg(target_family = "wasm")]
mod wasm {
use super::*;
use wasm_bindgen_test::wasm_bindgen_test;

#[dir_test(
dir: "$CARGO_MANIFEST_DIR/fixtures/ty/trait_bound",
Expand Down

0 comments on commit 6da8ccf

Please sign in to comment.