Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LP extractor #128

Merged
merged 7 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ jobs:
with:
path: ~/.cargo/bin
key: ${{ runner.os }}-cargo-bin
- name: Install graphviz
run: sudo apt-get install graphviz
- name: Install graphviz and cbc
run: sudo apt-get install graphviz coinor-libcbc-dev
- name: Test
run: make test
nits:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
persist-credentials: false
- name: Build Docs
run: |
cargo doc --no-deps
cargo doc --no-deps --all-features
touch target/doc/.nojekyll # prevent jekyll from running
- name: Deploy 🚀
uses: JamesIves/[email protected]
Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ fxhash = "0.2"
hashbrown = { version = "0.11", default-features = false, features = ["inline-more"] }
thiserror = "1"

# for the lp feature
coin_cbc = { version = "0.1.6", optional = true }

# for the serde-1 feature
serde = { version = "1", features = ["derive"], optional = true }
vectorize = { version = "0.2", optional = true }
Expand All @@ -33,6 +36,7 @@ env_logger = {version = "0.9", default-features = false}
ordered-float = "2"

[features]
lp = [ "coin_cbc" ]
wasm-bindgen = [ "instant/wasm-bindgen" ]
serde-1 = [ "serde", "indexmap/serde-1", "hashbrown/serde", "vectorize" ]
reports = [ "serde-1", "serde_json" ]
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ all: test nits

.PHONY: test
test:
cargo build --release
cargo test --release
cargo test --release --features=lp
# don't run examples in proof-production mode
cargo test --release --features "test-explanations"

Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ mod egraph;
mod explain;
mod extract;
mod language;
#[cfg(feature = "lp")]
mod lp_extract;
mod machine;
mod pattern;
mod rewrite;
Expand Down Expand Up @@ -91,6 +93,9 @@ pub use {
util::*,
};

#[cfg(feature = "lp")]
pub use lp_extract::*;

#[cfg(test)]
fn init_logger() {
let _ = env_logger::builder().is_test(true).try_init();
Expand Down
260 changes: 260 additions & 0 deletions src/lp_extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
use coin_cbc::{Col, Model, Sense};

use crate::*;

/// Per-e-node cost model consumed by an [`LpExtractor`].
///
/// The extractor queries this once for every e-node of every e-class while
/// it builds the objective of its linear program; the returned value becomes
/// the objective coefficient of that e-node's selection variable.
pub trait LpCostFunction<L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Returns the cost of selecting `enode`, which lives in e-class `eclass`.
    ///
    /// The whole `egraph` is passed in, so an implementation is free to
    /// inspect other parts of it when pricing the node.
    fn node_cost(&mut self, egraph: &EGraph<L, N>, eclass: Id, enode: &L) -> f64;
}

/// With [`AstSize`], every e-node costs exactly `1.0`, regardless of its
/// operator, its e-class, or the rest of the e-graph.
impl<L, N> LpCostFunction<L, N> for AstSize
where
    L: Language,
    N: Analysis<L>,
{
    fn node_cost(&mut self, _egraph: &EGraph<L, N>, _eclass: Id, _enode: &L) -> f64 {
        1.0
    }
}

/// Extraction driven by integer linear programming (ILP).
///
/// The program is solved with the [`cbc`](https://projects.coin-or.org/Cbc)
/// solver, which has to be installed on the machine for this feature to
/// build and run. Typical installation commands:
///
/// | OS               | Command                                  |
/// |------------------|------------------------------------------|
/// | Fedora / Red Hat | `sudo dnf install coin-or-Cbc-devel`     |
/// | Ubuntu / Debian  | `sudo apt-get install coinor-libcbc-dev` |
/// | macOS            | `brew install cbc`                       |
///
/// # Example
/// ```
/// use egg::*;
/// let mut egraph = EGraph::<SymbolLang, ()>::default();
///
/// let f = egraph.add_expr(&"(f x x x)".parse().unwrap());
/// let g = egraph.add_expr(&"(g (g x))".parse().unwrap());
/// egraph.union(f, g);
/// egraph.rebuild();
///
/// let best = Extractor::new(&egraph, AstSize).find_best(f).1;
/// let lp_best = LpExtractor::new(&egraph, AstSize).solve(f);
///
/// // In regular extraction, cost is measured on the tree.
/// assert_eq!(best.to_string(), "(g (g x))");
///
/// // Using ILP only counts common sub-expressions once,
/// // so it can lead to a smaller DAG expression.
/// assert_eq!(lp_best.to_string(), "(f x x x)");
/// assert_eq!(lp_best.as_ref().len(), 2);
/// ```
pub struct LpExtractor<'a, L, N>
where
    L: Language,
    N: Analysis<L>,
{
    // The e-graph terms are extracted from.
    egraph: &'a EGraph<L, N>,
    // The cbc model holding every variable, constraint and objective weight.
    model: Model,
    // Solver columns associated with each e-class, keyed by e-class id.
    vars: HashMap<Id, ClassVars>,
}

// Solver columns that `LpExtractor::new` allocates for a single e-class.
struct ClassVars {
    // Binary column: whether this e-class is part of the extracted term.
    active: Col,
    // Continuous column whose upper bound is set when the model is built.
    // NOTE(review): no constraint in this file references it yet — confirm
    // whether it is reserved for an ordering-based acyclicity encoding.
    order: Col,
    // One binary column per e-node of the class, in the class's node order.
    nodes: Vec<Col>,
}

impl<'a, L, N> LpExtractor<'a, L, N>
where
    L: Language,
    N: Analysis<L>,
{
    /// Create an [`LpExtractor`] using costs from the given [`LpCostFunction`].
    /// See those docs for details.
    ///
    /// This builds the whole linear program up front:
    ///
    /// * one binary "active" column per e-class and one binary column per
    ///   e-node,
    /// * for each e-class, a row stating that the sum of its e-node columns
    ///   equals its "active" column,
    /// * for each e-node and each of its children, a row stating that the
    ///   e-node may only be selected if the child's e-class is active,
    /// * e-nodes reported by the cycle search are pinned to zero so they can
    ///   never be selected,
    /// * a minimization objective whose coefficients come from
    ///   `cost_function`.
    ///
    /// # Panics
    ///
    /// Panics if an e-node refers to a child id that is not the id of one of
    /// the e-graph's classes (the per-class variable lookup fails).
    pub fn new<CF>(egraph: &'a EGraph<L, N>, mut cost_function: CF) -> Self
    where
        CF: LpCostFunction<L, N>,
    {
        let max_order = egraph.total_number_of_nodes() as f64 * 10.0;

        let mut model = Model::default();

        // Allocate the solver columns for every e-class. The allocation order
        // (active, order, then one column per node) is kept stable on purpose.
        let vars: HashMap<Id, ClassVars> = egraph
            .classes()
            .map(|class| {
                let cvars = ClassVars {
                    active: model.add_binary(),
                    order: model.add_col(),
                    nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
                };
                model.set_col_upper(cvars.order, max_order);
                (class.id, cvars)
            })
            .collect();

        // Collect every (e-class, node index) pair flagged by the cycle search.
        let mut cycles: HashSet<(Id, usize)> = Default::default();
        find_cycles(egraph, |id, i| {
            cycles.insert((id, i));
        });

        for (&id, class) in &vars {
            // class active == some node active
            // sum(for node_active in class) == class_active
            let row = model.add_row();
            model.set_row_equal(row, 0.0);
            model.set_weight(row, class.active, -1.0);
            for &node_active in &class.nodes {
                model.set_weight(row, node_active, 1.0);
            }

            for (i, (node, &node_active)) in egraph[id].iter().zip(&class.nodes).enumerate() {
                if cycles.contains(&(id, i)) {
                    // Pin the column to zero: this node must never be chosen.
                    model.set_col_upper(node_active, 0.0);
                    model.set_col_lower(node_active, 0.0);
                    continue;
                }

                for child in node.children() {
                    let child_active = vars[child].active;
                    // node active implies child active, encoded as:
                    //   node_active <= child_active
                    //   node_active - child_active <= 0
                    let row = model.add_row();
                    model.set_row_upper(row, 0.0);
                    model.set_weight(row, node_active, 1.0);
                    model.set_weight(row, child_active, -1.0);
                }
            }
        }

        // Objective: minimize the summed cost of all selected e-nodes.
        model.set_obj_sense(Sense::Minimize);
        for class in egraph.classes() {
            for (node, &node_active) in class.iter().zip(&vars[&class.id].nodes) {
                model.set_obj_coeff(node_active, cost_function.node_cost(egraph, class.id, node));
            }
        }

        Self {
            egraph,
            model,
            vars,
        }
    }

    /// Set the cbc timeout in seconds.
    ///
    /// The value is forwarded to the solver as its `"seconds"` parameter.
    /// Returns `self` so calls can be chained.
    pub fn timeout(&mut self, seconds: f64) -> &mut Self {
        self.model.set_parameter("seconds", &seconds.to_string());
        self
    }

    /// Extract a single rooted term.
    ///
    /// This is just a shortcut for [`LpExtractor::solve_multiple`].
    pub fn solve(&mut self, root: Id) -> RecExpr<L> {
        self.solve_multiple(&[root]).0
    }

    /// Extract (potentially multiple) roots.
    ///
    /// Returns the extracted expression together with, for each entry of
    /// `roots` (in the same order), the id inside that expression of the node
    /// chosen for the root's e-class. Roots are canonicalized with
    /// [`EGraph::find`] before use, so non-canonical ids are accepted.
    ///
    /// # Panics
    ///
    /// Panics if the solution returned by the solver does not mark a visited
    /// e-class as active, selects no e-node for it, or yields an expression
    /// that is not a DAG.
    pub fn solve_multiple(&mut self, roots: &[Id]) -> (RecExpr<L>, Vec<Id>) {
        let egraph = self.egraph;

        // Re-declare every "active" column as binary before forcing the roots.
        // NOTE(review): presumably this also restores the 0..1 bounds, which
        // releases roots forced active by an earlier call — confirm against
        // the coin_cbc documentation.
        for class in self.vars.values() {
            self.model.set_binary(class.active);
        }

        // Force each requested root e-class to be part of the solution.
        for root in roots {
            let root = &egraph.find(*root);
            self.model.set_col_lower(self.vars[root].active, 1.0);
        }

        let solution = self.model.solve();
        log::info!(
            "CBC status {:?}, {:?}",
            solution.raw().status(),
            solution.raw().secondary_status()
        );

        let mut todo: Vec<Id> = roots.iter().map(|id| egraph.find(*id)).collect();
        let mut expr = RecExpr::default();
        // converts e-class ids to e-node ids
        let mut ids: HashMap<Id, Id> = HashMap::default();

        // Post-order walk over the selected nodes: a class is only emitted
        // once all children of its chosen node have been emitted.
        while let Some(&id) = todo.last() {
            if ids.contains_key(&id) {
                todo.pop();
                continue;
            }
            let v = &self.vars[&id];
            assert!(solution.col(v.active) > 0.0);
            let node_idx = v
                .nodes
                .iter()
                .position(|&n| solution.col(n) > 0.0)
                .expect("an active e-class must have a selected e-node in the solution");
            let node = &egraph[id].nodes[node_idx];
            if node.all(|child| ids.contains_key(&child)) {
                let new_id = expr.add(node.clone().map_children(|i| ids[&egraph.find(i)]));
                ids.insert(id, new_id);
                todo.pop();
            } else {
                todo.extend_from_slice(node.children())
            }
        }

        // `ids` is keyed by canonical e-class ids, so each root has to be
        // canonicalized before the lookup; otherwise a non-canonical root
        // would panic here even though it was accepted above.
        let root_idxs = roots.iter().map(|root| ids[&egraph.find(*root)]).collect();

        assert!(expr.is_dag(), "LpExtract found a cyclic term!: {:?}", expr);
        (expr, root_idxs)
    }
}

fn find_cycles<L, N>(egraph: &EGraph<L, N>, mut f: impl FnMut(Id, usize))
where
L: Language,
N: Analysis<L>,
{
enum Color {
White,
Gray,
Black,
}
type Enter = bool;

let mut color: HashMap<Id, Color> = egraph.classes().map(|c| (c.id, Color::White)).collect();
let mut stack: Vec<(Enter, Id)> = egraph.classes().map(|c| (true, c.id)).collect();

while let Some((enter, id)) = stack.pop() {
if enter {
*color.get_mut(&id).unwrap() = Color::Gray;
stack.push((false, id));
for (i, node) in egraph[id].iter().enumerate() {
for child in node.children() {
match &color[child] {
Color::White => stack.push((true, *child)),
Color::Gray => f(id, i),
Color::Black => (),
}
}
}
} else {
*color.get_mut(&id).unwrap() = Color::Black;
}
}
}

#[cfg(test)]
mod tests {
    use crate::{SymbolLang as S, *};

    /// Extracting two roots that share the sub-term `(+ a a)` must yield a
    /// four-node expression and one result id per root.
    #[test]
    fn simple_lp_extract_two() {
        let mut egraph: EGraph<S, ()> = EGraph::default();
        let leaf = egraph.add(S::leaf("a"));
        let sum = egraph.add(S::new("+", vec![leaf, leaf]));
        let root_f = egraph.add(S::new("f", vec![sum]));
        let root_g = egraph.add(S::new("g", vec![sum]));

        let mut extractor = LpExtractor::new(&egraph, AstSize);
        extractor.timeout(10.0); // way too much time
        let (extracted, root_ids) = extractor.solve_multiple(&[root_f, root_g]);
        println!("{:?}", extracted);
        println!("{}", extracted);
        assert_eq!(extracted.as_ref().len(), 4);
        assert_eq!(root_ids.len(), 2);
    }
}
22 changes: 22 additions & 0 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,28 @@ fn assoc_mul_saturates() {
assert!(matches!(runner.stop_reason, Some(StopReason::Saturated)));
}

// Compares tree-based extraction with LP extraction on a term containing
// repeated sub-expressions: the two results must differ, and the LP result
// must consist of exactly four nodes.
#[cfg(feature = "lp")]
#[test]
fn math_lp_extract() {
    let expr: RecExpr<Math> = "(pow (+ x (+ x x)) (+ x x))".parse().unwrap();

    let runner: Runner<Math, ConstantFold> = Runner::default()
        .with_iter_limit(3)
        .with_expr(&expr)
        .run(&rules());
    let root = runner.roots[0];
    let egraph = &runner.egraph;

    let (_, best) = Extractor::new(egraph, AstSize).find_best(root);
    let mut lp_extractor = LpExtractor::new(egraph, AstSize);
    let lp_best = lp_extractor.solve(root);

    println!("input   [{}] {}", expr.as_ref().len(), expr);
    println!("normal  [{}] {}", best.as_ref().len(), best);
    println!("ilp cse [{}] {}", lp_best.as_ref().len(), lp_best);

    assert_ne!(best, lp_best);
    assert_eq!(lp_best.as_ref().len(), 4);
}

#[test]
fn math_ematching_bench() {
let exprs = &[
Expand Down