Skip to content

Commit

Permalink
perf improvements to solve-eqs and euf-completion
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Nov 17, 2022
1 parent 2c77999 commit 6662afd
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 51 deletions.
70 changes: 48 additions & 22 deletions src/ast/simplifiers/euf_completion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,43 @@ namespace euf {
}

void completion::reduce() {
unsigned rounds = 0;
do {
++m_epoch;
++rounds;
m_has_new_eq = false;
add_egraph();
map_canonical();
read_egraph();
IF_VERBOSE(11, verbose_stream() << "(euf.completion :rounds " << rounds << ")\n");
}
while (m_has_new_eq);
while (m_has_new_eq && rounds <= 3);
}

void completion::add_egraph() {
m_nodes.reset();
m_nodes_to_canonize.reset();
unsigned sz = m_fmls.size();
auto add_children = [&](enode* n) {
for (auto* ch : enode_args(n))
m_nodes_to_canonize.push_back(ch);
};

for (unsigned i = m_qhead; i < sz; ++i) {
auto [f,d] = m_fmls[i]();
auto* n = mk_enode(f);
if (m.is_eq(f)) {
m_egraph.merge(n->get_arg(0), n->get_arg(1), d);
m_nodes.push_back(n->get_arg(0));
m_nodes.push_back(n->get_arg(1));
add_children(n->get_arg(0));
add_children(n->get_arg(1));
}
if (m.is_not(f))
if (m.is_not(f)) {
m_egraph.merge(n->get_arg(0), m_ff, d);
else
add_children(n->get_arg(0));
}
else {
m_egraph.merge(n, m_tt, d);
add_children(n);
}
}
m_egraph.propagate();
}
Expand All @@ -106,28 +118,42 @@ namespace euf {
m_fmls.update(i, dependent_expr(m, g, dep));
m_stats.m_num_rewrites++;
IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(f, m, 3) << " -> " << mk_bounded_pp(g, m, 3) << "\n");
expr* x, * y;
if (m.is_eq(g, x, y) && new_eq(x, y))
m_has_new_eq = true;
if (m.is_and(g) && !m_has_new_eq)
for (expr* arg : *to_app(g))
if (m.is_eq(arg, x, y))
m_has_new_eq |= new_eq(x, y);
else if (!m.is_true(arg))
m_has_new_eq = true;
update_has_new_eq(g);
}
CTRACE("euf_completion", g != f, tout << mk_bounded_pp(f, m) << " -> " << mk_bounded_pp(g, m) << "\n");
}
if (!m_has_new_eq)
advance_qhead(m_fmls.size());
}

bool completion::new_eq(expr* a, expr* b) {
bool completion::is_new_eq(expr* a, expr* b) {
enode* na = m_egraph.find(a);
enode* nb = m_egraph.find(b);
if (!na)
IF_VERBOSE(11, verbose_stream() << "not internalied " << mk_bounded_pp(a, m) << "\n");
if (!nb)
IF_VERBOSE(11, verbose_stream() << "not internalied " << mk_bounded_pp(b, m) << "\n");
if (na && nb && na->get_root() != nb->get_root())
IF_VERBOSE(11, verbose_stream() << m_egraph.bpp(na) << " " << m_egraph.bpp(nb) << "\n");
return !na || !nb || na->get_root() != nb->get_root();
}

void completion::update_has_new_eq(expr* g) {
expr* x, * y;
if (m_has_new_eq)
return;
else if (m.is_eq(g, x, y))
m_has_new_eq |= is_new_eq(x, y);
else if (m.is_and(g)) {
for (expr* arg : *to_app(g))
update_has_new_eq(arg);
}
else if (m.is_not(g, g))
m_has_new_eq |= is_new_eq(g, m.mk_false());
else
m_has_new_eq |= is_new_eq(g, m.mk_true());
}

enode* completion::mk_enode(expr* e) {
m_todo.push_back(e);
enode* n;
Expand All @@ -138,7 +164,7 @@ namespace euf {
continue;
}
if (!is_app(e)) {
m_nodes.push_back(m_egraph.mk(e, 0, 0, nullptr));
m_nodes_to_canonize.push_back(m_egraph.mk(e, 0, 0, nullptr));
m_todo.pop_back();
continue;
}
Expand All @@ -152,7 +178,7 @@ namespace euf {
m_todo.push_back(arg);
}
if (sz == m_todo.size()) {
m_nodes.push_back(m_egraph.mk(e, 0, m_args.size(), m_args.data()));
m_nodes_to_canonize.push_back(m_egraph.mk(e, 0, m_args.size(), m_args.data()));
m_todo.pop_back();
}
}
Expand Down Expand Up @@ -314,10 +340,10 @@ namespace euf {
void completion::map_canonical() {
m_todo.reset();
enode_vector roots;
if (m_nodes.empty())
if (m_nodes_to_canonize.empty())
return;
for (unsigned i = 0; i < m_nodes.size(); ++i) {
enode* n = m_nodes[i]->get_root();
for (unsigned i = 0; i < m_nodes_to_canonize.size(); ++i) {
enode* n = m_nodes_to_canonize[i]->get_root();
if (n->is_marked1())
continue;
n->mark1();
Expand All @@ -334,7 +360,7 @@ namespace euf {
for (enode* arg : enode_args(n)) {
arg = arg->get_root();
if (!arg->is_marked1())
m_nodes.push_back(arg);
m_nodes_to_canonize.push_back(arg);
}
}
for (enode* r : roots)
Expand Down
6 changes: 3 additions & 3 deletions src/ast/simplifiers/euf_completion.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace euf {
egraph m_egraph;
enode* m_tt, *m_ff;
ptr_vector<expr> m_todo;
enode_vector m_args, m_reps, m_nodes;
enode_vector m_args, m_reps, m_nodes_to_canonize;
expr_ref_vector m_canonical, m_eargs;
expr_dependency_ref_vector m_deps;
unsigned m_epoch = 0;
Expand All @@ -43,11 +43,11 @@ namespace euf {
bool m_has_new_eq = false;

enode* mk_enode(expr* e);
bool new_eq(expr* a, expr* b);
bool is_new_eq(expr* a, expr* b);
void update_has_new_eq(expr* g);
expr_ref mk_and(expr* a, expr* b);
void add_egraph();
void map_canonical();
void saturate();
void read_egraph();
expr_ref canonize(expr* f, expr_dependency_ref& dep);
expr_ref canonize_fml(expr* f, expr_dependency_ref& dep);
Expand Down
54 changes: 40 additions & 14 deletions src/ast/simplifiers/solve_context_eqs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ namespace euf {

void solve_context_eqs::collect_nested_equalities(dep_eq_vector& eqs) {
expr_mark visited;
for (unsigned i = m_solve_eqs.m_qhead; i < m_fmls.size(); ++i)
unsigned sz = m_fmls.size();
for (unsigned i = m_solve_eqs.m_qhead; i < sz; ++i)
collect_nested_equalities(m_fmls[i], visited, eqs);

if (eqs.empty())
Expand All @@ -156,29 +157,55 @@ namespace euf {
std::stable_sort(eqs.begin(), eqs.end(), [&](dependent_eq const& e1, dependent_eq const& e2) {
return e1.var->get_id() < e2.var->get_id(); });

// quickly weed out variables that occur in more than two assertions.
unsigned_vector refcount;

// record the first and last occurrence of variables
// if the first and last occurrence coincide, the variable occurs in only one formula.
// otherwise it occurs in multiple formulas and should not be considered for solving.
unsigned_vector occurs1(m.get_num_asts() + 1, sz);
unsigned_vector occurs2(m.get_num_asts() + 1, sz);

struct visitor {
unsigned_vector& occurrence;
unsigned i = 0;
unsigned sz = 0;
visitor(unsigned_vector& occurrence) : occurrence(occurrence), i(0), sz(0) {}
void operator()(expr* t) {
occurrence.setx(t->get_id(), i, sz);
}
};

{
expr_mark visited;
for (unsigned i = m_solve_eqs.m_qhead; i < m_fmls.size(); ++i) {
visited.reset();
expr* f = m_fmls[i].fml();
for (expr* t : subterms::all(expr_ref(f, m), &m_todo, &visited))
refcount.setx(t->get_id(), refcount.get(t->get_id(), 0) + 1, 0);
visitor visitor1(occurs1);
visitor visitor2(occurs2);
visitor1.sz = sz;
visitor2.sz = sz;
expr_fast_mark1 fast_visited;
for (unsigned i = 0; i < sz; ++i) {
visitor1.i = i;
quick_for_each_expr(visitor1, fast_visited, m_fmls[i].fml());
}
fast_visited.reset();
for (unsigned i = sz; i-- > 0; ) {
visitor2.i = i;
quick_for_each_expr(visitor2, fast_visited, m_fmls[i].fml());
}
}

unsigned j = 0;
expr* last_var = nullptr;
bool was_unsafe = false;
for (auto const& eq : eqs) {

if (refcount.get(eq.var->get_id(), 0) > 1)
if (!eq.var)
continue;
unsigned occ1 = occurs1.get(eq.var->get_id(), sz);
unsigned occ2 = occurs2.get(eq.var->get_id(), sz);
if (occ1 >= sz)
continue;
if (occ1 != occ2)
continue;

SASSERT(!m.is_bool(eq.var));


if (eq.var != last_var) {

m_contains_v.reset();
Expand All @@ -195,8 +222,7 @@ namespace euf {
}

// then mark occurrences
for (unsigned i = 0; i < m_fmls.size(); ++i)
m_todo.push_back(m_fmls[i].fml());
m_todo.push_back(m_fmls[occ1].fml());
mark_occurs(m_todo, eq.var, m_contains_v);
SASSERT(m_todo.empty());
}
Expand Down
50 changes: 38 additions & 12 deletions src/ast/simplifiers/solve_eqs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,6 @@ namespace euf {
return m_id2level[id] != UINT_MAX;
};

auto is_safe = [&](unsigned lvl, expr* t) {
for (auto* e : subterms::all(expr_ref(t, m), &m_todo, &m_visited))
if (is_var(e) && m_id2level[var2id(e)] < lvl)
return false;
return true;
};

unsigned init_level = UINT_MAX;
unsigned_vector todo;

Expand All @@ -94,26 +87,59 @@ namespace euf {
init_level -= m_id2var.size() + 1;
unsigned curr_level = init_level;
todo.push_back(id);

while (!todo.empty()) {
unsigned j = todo.back();
todo.pop_back();
if (is_explored(j))
continue;
m_id2level[j] = curr_level++;

for (auto const& eq : m_next[j]) {
auto const& [orig, v, t, d] = eq;
SASSERT(j == var2id(v));
if (!is_safe(curr_level, t))
bool is_safe = true;
unsigned todo_sz = todo.size();

// determine if substitution is safe.
// all time-stamps must be at or above current level
// unexplored variables that are part of substitution are appended to work list.
SASSERT(m_todo.empty());
m_todo.push_back(t);
expr_fast_mark1 visited;
while (!m_todo.empty()) {
expr* e = m_todo.back();
m_todo.pop_back();
if (visited.is_marked(e))
continue;
visited.mark(e, true);
if (is_app(e)) {
for (expr* arg : *to_app(e))
m_todo.push_back(arg);
}
else if (is_quantifier(e))
m_todo.push_back(to_quantifier(e)->get_expr());
if (!is_var(e))
continue;
if (m_id2level[var2id(e)] < curr_level) {
is_safe = false;
break;
}
if (!is_explored(var2id(e)))
todo.push_back(var2id(e));
}
m_todo.reset();

if (!is_safe) {
todo.shrink(todo_sz);
continue;
}
SASSERT(!occurs(v, t));
m_next[j][0] = eq;
m_subst_ids.push_back(j);
for (expr* e : subterms::all(expr_ref(t, m), &m_todo, &m_visited))
if (is_var(e) && !is_explored(var2id(e)))
todo.push_back(var2id(e));
break;
}
}
}
}
}

Expand Down

0 comments on commit 6662afd

Please sign in to comment.