diff --git a/NEWS.md b/NEWS.md index 35609225..623e2ab3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,6 @@ +# 3.0 +- Updated TermInterface to 1.0.1 + # 2.0 - No longer dispatch against types, but instead dispatch against objects. - Faster E-Graph Analysis @@ -6,6 +9,7 @@ - New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression` - Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses. - Remove duplicates in E-Graph analyses data. + ## 1.2 - Fixes when printing patterns - Can pass custom `similarterm` to `SaturationParams` by using `SaturationParams.simterm`. diff --git a/Project.toml b/Project.toml index 0d8e3fc1..b9c345e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,24 +1,28 @@ name = "Metatheory" uuid = "e9d8d322-4543-424a-9be4-0cc815abe26c" authors = ["Alessandro Cheli - 0x0f0f0f "] -version = "2.0.2" +version = "3.0.0" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" -DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +[weakdeps] +GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" + +[extensions] +Plotting = ["GraphViz"] + [compat] AutoHashEquals = "2.1.0" -DataStructures = "0.18" DocStringExtensions = "0.8, 0.9" Reexport = "0.2, 1" -TermInterface = "0.3.3" +TermInterface = "2.0" TimerOutputs = "0.5" -julia = "1.8" +julia = "1.9" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/README.md b/README.md index a93386b5..cd836431 100644 --- a/README.md +++ b/README.md @@ -18,36 +18,154 @@ [![status](https://joss.theoj.org/papers/3266e8a08a75b9be2f194126a9c6f0e9/status.svg)](https://joss.theoj.org/papers/3266e8a08a75b9be2f194126a9c6f0e9) [![Zulip](https://img.shields.io/badge/Chat-Zulip-blue)](https://julialang.zulipchat.com/#narrow/stream/277860-metatheory.2Ejl) -**Metatheory.jl** is a general purpose term rewriting, metaprogramming and algebraic computation library for the Julia programming language, designed to take advantage of the powerful reflection capabilities to bridge the gap between symbolic mathematics, abstract interpretation, equational reasoning, optimization, composable compiler transforms, and advanced -homoiconic pattern matching features. The core features of Metatheory.jl are a powerful rewrite rule definition language, a vast library of functional combinators for classical term rewriting and an *e-graph rewriting*, a fresh approach to term rewriting achieved through an equality saturation algorithm. Metatheory.jl can manipulate any kind of -Julia symbolic expression type, as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). +**Metatheory.jl** is a general purpose term rewriting, metaprogramming and +algebraic computation library for the Julia programming language, designed to +take advantage of the powerful reflection capabilities to bridge the gap between +symbolic mathematics, abstract interpretation, equational reasoning, +optimization, composable compiler transforms, and advanced homoiconic pattern +matching features. The core features of Metatheory.jl are a powerful rewrite +rule definition language, a vast library of functional combinators for classical +term rewriting and an *[e-graph](https://en.wikipedia.org/wiki/E-graph) +rewriting*, a fresh approach to term rewriting achieved through an equality +saturation algorithm. Metatheory.jl can manipulate any kind of Julia symbolic +expression type, as long as it satisfies [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). Metatheory.jl provides: -- An eDSL (domain specific language) to define different kinds of symbolic rewrite rules. +- An eDSL (embedded domain specific language) to define different kinds of symbolic rewrite rules. - A classical rewriting backend, derived from the [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl) pattern matcher, supporting associative-commutative rules. It is based on the pattern matcher in the [SICM book](https://mitpress.mit.edu/sites/default/files/titles/content/sicm_edition_2/book.html). - A flexible library of rewriter combinators. -- An e-graph rewriting (equality saturation) backend and pattern matcher, based on the [egg](https://egraphs-good.github.io/) library, supporting backtracking and non-deterministic term rewriting by using a data structure called *e-graph*, efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning. +- An [e-graph](https://en.wikipedia.org/wiki/E-graph) rewriting (equality saturation) engine, based on the [egg](https://egraphs-good.github.io/) library, supporting a backtracking pattern matcher and non-deterministic term rewriting by using a data structure called [e-graph](https://en.wikipedia.org/wiki/E-graph), efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning. - `@capture` macro for flexible metaprogramming. Intuitively, Metatheory.jl transforms Julia expressions -in other Julia expressions and can achieve such at both compile and run time. This allows Metatheory.jl users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages. +in other Julia expressions at both compile and run time. + +This allows users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages. + Our library provides a simple, algebraically composable interface to help scientists in implementing and reasoning about semantics and all kinds of formal systems, by defining concise rewriting rules in pure, syntactically valid Julia on a high level of abstraction. Our implementation of equality saturation on e-graphs is based on the excellent, state-of-the-art technique implemented in the [egg](https://egraphs-good.github.io/) library, reimplemented in pure Julia. -## 2.0 is out! -Second stable version is out: +## 3.0 Alpha + +- [ ] Rewrite integration test files in [Literate.jl](https://github.com/fredrikekre/Literate.jl) format, becoming narrative tutorials available in the docs. +- [ ] Proof production algorithm: explanations. +- [x] Using new TermInterface. +- [x] Performance optimization: use vectors of UInt to internally represent terms in e-graphs. +- [x] Comprehensive suite of benchmarks that are run automatically on PR. +- [x] Complete overhaul of the rebuilding algorithm. + +--- + +## We need your help! - Practical and Research Contributions + +There's lot of room for improvement for Metatheory.jl, by making it more performant and by extending its features. +Any contribution is welcome! + +**Performance**: +- Improving the speed of the e-graph pattern matcher. [(Useful paper)](https://arxiv.org/abs/2108.02290) +- Reducing allocations used by Equality Saturation. +- [#50](https://github.com/JuliaSymbolics/Metatheory.jl/issues/50) - Goal-informed [rule schedulers](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/EGraphs/Schedulers.jl): develop heuristic algorithms that choose what rules to apply at each equality saturation iteration to prune space of possible rewrites. + +**Features**: +- [#111](https://github.com/JuliaSymbolics/Metatheory.jl/issues/111) Introduce proof production capabilities for e-graphs. This can be based on the [egg implementation](https://github.com/egraphs-good/egg/blob/main/src/explain.rs). +- Common Subexpression Elimination when extracting from an e-graph [#158](https://github.com/JuliaSymbolics/Metatheory.jl/issues/158) +- Integer Linear Programming extraction of expressions. +- Pattern matcher enhancements: [#43 Better parsing of blocks](https://github.com/JuliaSymbolics/Metatheory.jl/issues/43), [#3 Support `...` variables in e-graphs](https://github.com/JuliaSymbolics/Metatheory.jl/issues/3), [#89 syntax for vectors](https://github.com/JuliaSymbolics/Metatheory.jl/issues/89) +- [#75 E-Graph intersection algorithm](https://github.com/JuliaSymbolics/Metatheory.jl/issues/75) + +**Documentation**: +- Port more [integration tests](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/integration) to [tutorials](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/tutorials) that are rendered with [Literate.jl](https://github.com/fredrikekre/Literate.jl) +- Document [Functional Rewrite Combinators](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/Rewriters.jl) and add a tutorial. + +## Real World Applications + +Most importantly, there are many **practical real world applications** where Metatheory.jl could be used. Let's +work together to turn this list into some new Julia packages: + + +#### Integration with Symbolics.jl + +Many features of this package, such as the classical rewriting system, have been ported from [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl), and are technically the same. Integration between Metatheory.jl with Symbolics.jl **is currently +in-development**, as we recently released a new version of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). + +An integration between Metatheory.jl and [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) is possible and has previously been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. Once we reach consensus for a shared symbolic term interface, Metatheory.jl can be used to: + +- Rewrite Symbolics.jl expressions with **bi-directional equations** instead of simple directed rewrite rules. +- Search for the space of mathematically equivalent Symbolics.jl expressions for more computationally efficient forms to speed various packages like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) that numerically evaluate Symbolics.jl expressions. +- When proof production is introduced in Metatheory.jl, automatically search the space of a domain-specific equational theory to prove that Symbolics.jl expressions are equal in that theory. +- Other scientific domains extending Symbolics.jl for system modeling. + +#### Simplifying Quantum Algebras + +[QuantumCumulants.jl](https://github.com/qojulia/QuantumCumulants.jl/) automates +the symbolic derivation of mean-field equations in quantum mechanics, expanding +them in cumulants and generating numerical solutions using state-of-the-art +solvers like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) +and +[DifferentialEquations.jl](https://github.com/SciML/DifferentialEquations.jl). A +potential application for Metatheory.jl is domain-specific code optimization for +QuantumCumulants.jl, aiming to be the first symbolic simplification engine for +Fock algebras. + -- New e-graph pattern matching system, relies on functional programming and closures, and is much more extensible than 1.0's virtual machine. -- No longer dispatch against types, but instead dispatch against objects. -- Faster E-Graph Analysis -- Better library macros -- Updated TermInterface to 0.3.3 -- New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression` -- Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses. -- Remove duplicates in E-Graph analyses data. +#### Automatic Floating Point Error Fixer -Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. The introduction of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) has allowed for large potential in generalization of term rewriting and symbolic analysis and manipulation features. Integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. +[Herbie](https://herbie.uwplse.org/) is a tool using equality saturation to automatically rewrites mathematical expressions to enhance +floating-point accuracy. Recently, Herbie's core has been rewritten using +[egg](https://egraphs-good.github.io/), with the tool originally implemented in +a mix of Racket, Scheme, and Rust. While effective, its usage involves multiple +languages, making it impractical for non-experts. The text suggests the theoretical +possibility of porting this technique to a pure Julia solution, seamlessly +integrating with the language, in a single macro `@fp_optimize` that fixes +floating-point errors in expressions just before code compilation and execution. + +#### Automatic Theorem Proving in Julia + +Metatheory.jl can be used to make a pure Julia Automated Theorem Prover (ATP) +inspired by the use of E-graphs in existing ATP environments like +[Z3](https://github.com/Z3Prover/z3), [Simplify](https://dl.acm.org/doi/10.1145/1066100.1066102) and [CVC4](https://en.wikipedia.org/wiki/CVC4), +in the context of [Satisfiability Modulo Theories (SMT)](https://en.wikipedia.org/wiki/Satisfiability_modulo_theories). + +The two-language problem in program verification can be addressed by allowing users to define high-level +theories about their code, that are statically verified before executing the program. This holds potential for various applications in +software verification, offering a flexible and generic environment for proving +formulae in different logics, and statically verifying such constraints on Julia +code before it gets compiled (see +[Mixtape.jl](https://github.com/JuliaCompilerPlugins/Mixtape.jl)). + +To develop such a package, Metatheory.jl needs: + +- Introduction of Proof Production in equality saturation. +- SMT in conjunction with a SAT solver like [PicoSAT.jl](https://github.com/sisl/PicoSAT.jl) +- Experiments with various logic theories and software verification applications. + +#### Other potential applications + +Many projects that could potentially be ported to Julia are listed on the [egg website](https://egraphs-good.github.io/). +A simple search for ["equality saturation" on Google Scholar](https://scholar.google.com/scholar?hl=en&q="equality+saturation") shows many new articles that leverage the techniques used in this packages. + +PLDI is a premier academic forum in the field of programming languages and programming systems research, which organizes an [e-graph symposium](https://pldi23.sigplan.org/home/egraphs-2023) where many interesting research and projects have been presented. + +--- + +## Theoretical Developments + +There's also lots of room for theoretical improvements to the e-graph data structure and equality saturation rewriting. + +#### Associative-Commutative-Distributive e-matching + +In classical rewriting SymbolicUtils.jl offers a mechanism for matching expressions with associative and commutative operations: [`@acrule`](https://docs.sciml.ai/SymbolicUtils/stable/manual/rewrite/#Associative-Commutative-Rules) - a special kind of rule that considers all permutations and combinations of arguments. In e-graph rewriting in Metatheory.jl, associativity and commutativity have to be explicitly defined as rules. However, the presence of such rules, together with distributivity, will likely cause equality saturation to loop infinitely. See ["Why reasonable rules can create infinite loops"](https://github.com/egraphs-good/egg/discussions/60) for an explanation. + +Some workaround exists for ensuring termination of equality saturation: bounding the depth of search, or merge-only rewriting without introducing new terms (see ["Ensuring the Termination of EqSat over a Terminating Term Rewriting System"](https://effect.systems/blog/ta-completion.html)). + +There's a few theoretical questions left: + +- **What kind of rewrite systems terminate in equality saturation**? +- Can associative-commutative matching be applied efficiently to e-graphs while avoiding combinatory explosion? +- Can e-graphs be extended to include nodes with special algebraic properties, in order to mitigate the downsides of non-terminating systems? + +--- ## Recommended Readings - Selected Publications @@ -72,7 +190,7 @@ You can install the stable version: julia> using Pkg; Pkg.add("Metatheory") ``` -Or you can install the developer version (recommended by now for latest bugfixes) +Or you can install the development version (recommended by now for latest bugfixes) ```julia julia> using Pkg; Pkg.add(url="https://github.com/JuliaSymbolics/Metatheory.jl") ``` diff --git a/benchmark/tune.json b/benchmark/tune.json new file mode 100644 index 00000000..b4e5f699 --- /dev/null +++ b/benchmark/tune.json @@ -0,0 +1 @@ +[{"Julia":"1.9.4","BenchmarkTools":"1.0.0"},[["BenchmarkGroup",{"data":{"logic":["BenchmarkGroup",{"data":{"prove1":["Parameters",{"gctrial":true,"time_tolerance":0.05,"evals_set":false,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rewrite":["Parameters",{"gctrial":true,"time_tolerance":0.05,"evals_set":false,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":["egraph","logic"]}],"maths":["BenchmarkGroup",{"data":{"simpl1":["Parameters",{"gctrial":true,"time_tolerance":0.05,"evals_set":false,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":["egraphs"]}]},"tags":[]}]]] \ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml index 4866d7b6..e4f35144 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,7 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589" Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" diff --git a/docs/liveserver.jl b/docs/liveserver.jl new file mode 100755 index 00000000..dd845f97 --- /dev/null +++ b/docs/liveserver.jl @@ -0,0 +1,27 @@ +#!/usr/bin/env julia + +# Root of the repository +const repo_root = dirname(@__DIR__) + +# Make sure docs environment is active +import Pkg +Pkg.activate(@__DIR__) +using Metatheory + +# Communicate with docs/make.jl that we are running in live mode +push!(ARGS, "liveserver") + +# Run LiveServer.servedocs(...) +import LiveServer +LiveServer.servedocs(; + # Documentation root where make.jl and src/ are located + foldername = joinpath(repo_root, "docs"), + skip_dirs = [ + # exclude assets folder because it is modified by docs/make.jl + joinpath(repo_root, "docs", "src", "assets"), + # exclude tutorial .md files (auto-generated via Literate.jl) + abspath(joinpath(@__DIR__, "src", "tutorials")) + ], + # include tutorial .jl files (generate .md files) + include_dirs=[joinpath(dirname(pathof(Metatheory)), "..", "test", "tutorials")] +) diff --git a/docs/src/api.md b/docs/src/api.md index 4cc2fbd5..69a6458b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,6 +1,6 @@ # API Documentation - +``` ## Syntax ```@autodocs @@ -25,26 +25,28 @@ Modules = [Metatheory.Rules] --- -## Rules +## Rewriters ```@autodocs -Modules = [Metatheory.Rules] +Modules = [Metatheory.Rewriters] ``` --- -## Rewriters +## EGraphs ```@autodocs -Modules = [Metatheory.Rewriters] +Modules = [Metatheory.EGraphs] ``` ---- +## VecExprs (aka e-nodes) -## EGraphs +```@docs +Metatheory.VecExprModule.VecExpr +``` ```@autodocs -Modules = [Metatheory.EGraphs] +Modules = [Metatheory.VecExprModule] ``` --- @@ -53,4 +55,4 @@ Modules = [Metatheory.EGraphs] ```@autodocs Modules = [Metatheory.EGraphs.Schedulers] -``` \ No newline at end of file +``` diff --git a/docs/src/assets/graphviz.svg b/docs/src/assets/graphviz.svg index bf4a076c..00a9564f 100644 --- a/docs/src/assets/graphviz.svg +++ b/docs/src/assets/graphviz.svg @@ -1,240 +1,121 @@ - - - - -cluster_7 - -#7. Smallest: a - - -cluster_14 - -#14. Smallest: -1 + + + + +cluster_1 + +%1. Smallest: a -cluster_8 - -#8. Smallest: 1 +cluster_5 + +%5. Smallest: -1 - -cluster_4 - -#4. Smallest: a - a + +cluster_3 + +%3. Smallest: -a - -cluster_15 - -#15. Smallest: -a + +cluster_7 + +%7. Smallest: a - a - + -8.1 - -* - - - -8.1:sw->8.1 - - - - - -8.1:se->8.1 - - +5.1 + +-1 - + -8.2 - -1 +7.1 + ++ - - -14.1 - --1 + + +3.1 + +- - - -7.1 - -* + + +7.1:sw->3.1 + + - - -7.1:sw->8.1 - - + + +1.1 + +a - - -7.1:se->7.1 - - + + +7.1:se->1.1 + + - + 7.2 - -a + ++ + + + +7.2:se->3.1 + + + + + +7.2:sw->1.1 + + - + 7.3 - -* - - - -7.3:se->8.1 - - + +- - + -7.3:sw->7.3 - - +7.3:sw->1.1 + + - - -4.1 - -+ + + +7.3:se->1.1 + + - + -4.1:sw->7.1 - - +3.1->1.1 + + - - -15.1 - -* + + +3.2 + +* - + -4.1:se->15.1 - - - - - -4.2 - -- +3.2:sw->5.1 + + - + -4.2:sw->7.1 - - - - - -4.2:se->7.1 - - - - - -4.3 - -* - - - -4.3:sw->8.1 - - - - - -4.3:se->4.3 - - - - - -4.4 - -+ - - - -4.4:se->7.1 - - - - - -4.4:sw->15.1 - - - - - -4.5 - -* - - - -4.5:se->8.1 - - - - - -4.5:sw->4.5 - - - - - -15.1:sw->14.1 - - - - - -15.1:se->7.1 - - - - - -15.2 - -- - - - -15.2->7.1 - - - - - -15.3 - -* - - - -15.3:sw->8.1 - - - - - -15.3:se->15.3 - - +3.2:se->1.1 + + \ No newline at end of file diff --git a/docs/src/egraphs.md b/docs/src/egraphs.md index b9a458cd..9fe1d229 100644 --- a/docs/src/egraphs.md +++ b/docs/src/egraphs.md @@ -6,7 +6,7 @@ have very recently repurposed EGraphs to implement state-of-the-art, rewrite-driven compiler optimizations and program synthesizers using a technique known as equality saturation. Metatheory.jl provides a general purpose, customizable implementation of EGraphs and equality saturation, inspired from -the [egg](https://egraphs-good.github.io/) library for Rust. You can read more +the [egg](https://egraphs-good.github.io/) Rust library. You can read more about the design of the EGraph data structure and equality saturation algorithm in the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304). @@ -35,7 +35,7 @@ system governed by equational rules, about non obviously oriented equations, suc )`? E-Graphs come to our help. -EGraphs are bipartite graphs of [ENode](@ref)s and [EClass](@ref)es: +EGraphs are bipartite graphs of e-nodes (stored in [`VecExpr`](@ref)s) and [`EClass`](@ref)es: a data structure for efficiently represent and rewrite on many equivalent expressions at the same time. A sort of fast data structure for sets of trees. Subtrees and parents are shared if possible. This makes EGraphs similar to DAGs. Most importantly, with EGraph rewriting you can use **bidirectional rewrite rules**, such as **equalities** without worrying about the ordering and confluence of your rewrite system! @@ -83,20 +83,14 @@ commutativity and distributivity**, rules that are otherwise known of causing loops and require extensive user reasoning in classical rewriting. -```jldoctest +```@example basic_theory +using Metatheory + t = @theory a b c begin a * b == b * a a * 1 == a a * (b * c) == (a * b) * c end - -# output - -3-element Vector{EqualityRule}: - ~a * ~b == ~b * ~a - ~a * 1 == ~a - ~a * (~b * ~c) == (~a * ~b) * ~c - ``` @@ -109,19 +103,14 @@ customizable parameters include a `timeout` on the number of iterations, a `eclasslimit` on the number of e-classes in the EGraph, a `stopwhen` functions that stops saturation when it evaluates to true. -```@example +```@example basic_theory +using Metatheory g = EGraph(:((a * b) * (1 * (b + c)))); report = saturate!(g, t); ``` With the EGraph equality saturation backend, Metatheory.jl can prove **simple** -equalities very efficiently. The `@areequal` macro takes a theory and some -expressions and returns true iff the expressions are equal according to the -theory. The following example may return true with an appropriate example theory. - -```julia -julia> @areequal some_theory (x+y)*(a+b) ((a*(x+y))+b*(x+y)) ((x*(a+b))+y*(a+b)) -``` +equalities very efficiently. ## Configurable Parameters @@ -148,7 +137,7 @@ Given a starting e-graph `g`, a set of rewrite rules `t` and some parameters `p` * For each rule in `t`, search through the e-graph for l.h.s. * For each match produced, apply the rewrite * Do a bottom-up traversal of the e-graph to rebuild the congruence closure -* If the e-graph hasn’t changed from last iteration, it has saturated. If so, halt saturation. +* If the e-graph hasn't changed from last iteration, it has saturated. If so, halt saturation. * Loop at most n times. Note that knowing if an expression with a set of rules saturates an e-graph or never terminates @@ -222,43 +211,50 @@ A *cost function* for *EGraph extraction* is a function used to determine which *e-node* will be extracted from an *e-class*. It must return a positive, non-complex number value and, must accept 3 arguments. -1) The current [ENode](@ref) `n` that is being inspected. -2) The current [EGraph](@ref) `g`. -3) The current analysis name `an::Symbol`. +1) The current e-node [VecExpr](@ref) `n` that is being inspected. +2) The object corresponding to the e-node operation. +3) The cost of children as a `Vector`. From those 3 parameters, one can access all the data needed to compute the cost of an e-node recursively. -* One can use [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) methods to access the operation and child arguments of an e-node: `operation(n)`, `arity(n)` and `arguments(n)` -* Since e-node children always point to e-classes in the same e-graph, one can retrieve the [EClass](@ref) object for each child of the currently visited enode with `g[id] for id in arguments(n)` -* One can inspect the analysis data for a given eclass and a given analysis name `an`, by using [hasdata](@ref) and [getdata](@ref). -* Extraction analyses always associate a tuple of 2 values to a single e-class: which e-node is the one that minimizes the cost -and its cost. More details can be found in the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) in the *Analyses* section. - Here's an example: -```julia +```@example cost_function +using Metatheory # This is a cost function that behaves like `astsize` but increments the cost # of nodes containing the `^` operation. This results in a tendency to avoid # extraction of expressions containing '^'. -function cost_function(n::ENodeTerm, g::EGraph) - cost = 1 + arity(n) - - operation(n) == :^ && (cost += 2) - - for id in arguments(n) - eclass = g[id] - # if the child e-class has not yet been analyzed, return +Inf - !hasdata(eclass, cost_function) && (cost += Inf; break) - cost += last(getdata(eclass, cost_function)) - end - return cost +function cost_function(n::VecExpr, op, children_costs::Vector{Float64})::Float64 + # All literal expressions (e.g `a`, 123, 0.42, "hello") have cost 1 + v_isexpr(n) || return 1 + cost = op == :^ ? 2 : 1 + cost + sum(children_costs) end +``` -# All literal expressions (e.g `a`, 123, 0.42, "hello") have cost 1 -cost_function(n::ENodeLiteral, g::EGraph) = 1 +We can now compare how the two cost functions behave. + +```@example cost_function +t = @theory begin + ~a * ~a --> (~a)^2 + ~a --> (~a)^1 + (~a)^~n * (~a)^~m --> (~a)^(~n + ~m) + log((~a)^~n) == ~n * log(~a) + log(~x * ~y) --> log(~x) + log(~y) + log(1) --> 0 + log(:e) --> 1 + :e^(log(~x)) --> ~x +end +expr = :(log(x^2)) +g = EGraph(expr) +saturate!(g, t) +extract!(g, astsize), extract!(g, cost_function) ``` +We can see that our custom `cost_function` tends to avoid terms that +contain the `^` operator as it yields a higher cost for such terms. + ## EGraph Analyses An *EGraph Analysis* is an efficient and automated way of analyzing all the possible @@ -267,18 +263,15 @@ automate the process of EGraph Analysis. An *EGraph Analysis* defines a domain of values and associates a value from the domain to each [EClass](@ref) in the graph. Theoretically, the domain should form a [join semilattice](https://en.wikipedia.org/wiki/Semilattice). Rewrites can cooperate with e-class analyses by depending on analysis facts and adding equivalences that in turn establish additional facts. -In Metatheory.jl, **EGraph Analyses are uniquely identified** by either - -* An unique name of type `Symbol`. -* A function object `f`, used for cost function analysis. This will use built-in definitions of `make` and `join`. +In Metatheory.jl, **EGraph Analyses are uniquely identified** by a type: +The `EGraph{E,A}` type is parametrized by the expression type `E` and the +**analysis type** `A`. -If you are specifying a custom analysis by its `Symbol` name, -the following functions define an interface for analyses based on multiple dispatch -on `Val{analysis_name::Symbol}`: -* [islazy(an)](@ref) should return true if the analysis name `an` should NOT be computed on-the-fly during egraphs operation, but only when inspected. -* [make(an, egraph, n)](@ref) should take an ENode `n` and return a value from the analysis domain. -* [join(an, x,y)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?*). If `an` is a `Function`, it is treated as a cost function analysis, it is automatically defined to be the minimum analysis value between `x` and `y`. Typically, the domain value of cost functions are real numbers, but if you really do want to have your own cost type, make sure that `Base.isless` is defined. -* [modify!(an, egraph, eclassid)](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclassid]` on-the-fly during an e-graph saturation iteration, given its analysis value. +The following functions define an interface for analyses based on multiple dispatch: + +* [make(g::EGraph{ExprType, AnalysisType}, n)](@ref) should take an e-node `n::VecExpr` and return a value from the analysis domain. +* [join(x::AnalysisType, y::AnalysisType)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?* or *how should they be merged?*).`Base.isless` must be defined. +* [modify!(g::EGraph{ExprType, AnalysisType}, eclass::EClass{AnalysisType})](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclass.id]` on-the-fly during an e-graph saturation iteration, given its analysis value, typically by adding an e-node. ### Defining a custom analysis @@ -294,24 +287,24 @@ the symbolic expressions that will result in an even or an odd number. Defining an EGraph Analysis is similar to the process [Mathematical Induction](https://en.wikipedia.org/wiki/Mathematical_induction). To define a custom EGraph Analysis, one should start by defining a name of type `Symbol` that will be used to identify this specific analysis and to dispatch against the required methods. -```julia -using Metatheory -using Metatheory.EGraphs -``` - -The next step, the base case of induction, is to define a method for +The first step is to define a method for [make](@ref) dispatching against our `OddEvenAnalysis`. First, we want to -associate an analysis value only to the *literals* contained in the EGraph. To do this we -take advantage of multiple dispatch against `ENodeLiteral`. +associate an analysis value only to the *literals* contained in the EGraph (the base case of induction). -```julia -function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeLiteral) - if n.value isa Integer - return iseven(n.value) ? :even : :odd - else - return nothing - end +```@example custom_analysis +using Metatheory + +struct OddEvenAnalysis + s::Symbol # :odd or :even end + +# Should be called only if isexpr(n) is false +odd_even_base_case(val::Integer) = OddEvenAnalysis(iseven(val) ? :even : :odd) +# By default, literals that are not integers yield no analysis value. +# In this case you should return `nothing` +odd_even_base_case(val) = nothing + +# ... Rest of code defined below ``` Now we have to consider the *induction step*. @@ -325,67 +318,55 @@ And we know that * odd + even = odd * even + even = even -We can now define a method for `make` dispatching against -`OddEvenAnalysis` and `ENodeTerm`s to compute the analysis value for *nested* symbolic terms. +We can now extend the function defined above to compute the analysis value for *nested* symbolic terms. We take advantage of the methods in [TermInterface](https://github.com/JuliaSymbolics/TermInterface.jl) -to inspect the content of an `ENodeTerm`. -From the definition of an [ENode](@ref), we know that children of ENodes are always IDs pointing -to EClasses in the EGraph. - -```julia -function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeTerm) +to inspect the children of an e-node that is a tree-like expression and not a literal. +From the definition of an e-node, we know that children of e-nodes are always IDs pointing +to e-classes in the `EGraph`. + +```@example custom_analysis +function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr) where {ExpressionType} + v_isexpr(n) || return odd_even_base_case(op) + # The e-node is not a literal value, # Let's consider only binary function call terms. - if exprhead(n) == :call && arity(n) == 2 + if v_iscall(n) && arity(n) == 2 op = operation(n) # Get the left and right child eclasses child_eclasses = arguments(n) - l = g[child_eclasses[1]] - r = g[child_eclasses[2]] - - # Get the corresponding OddEvenAnalysis value of the children - # defaulting to nothing - ldata = getdata(l, :OddEvenAnalysis, nothing) - rdata = getdata(r, :OddEvenAnalysis, nothing) + l,r = g[child_eclasses[1]], g[child_eclasses[2]] - if ldata isa Symbol && rdata isa Symbol + if !isnothing(l.data) && !isnothing(r.data) if op == :* - if ldata == rdata - ldata - elseif (ldata == :even || rdata == :even) - :even - else - nothing + if l.data == r.data + l.data + elseif (l.data.s == :even || r.data.s == :even) + OddEvenAnalysis(:even) end elseif op == :+ - (ldata == rdata) ? :even : :odd + (l.data == r.data) ? OddEvenAnalysis(:even) : OddEvenAnalysis(:odd) end - elseif isnothing(ldata) && rdata isa Symbol && op == :* - rdata - elseif ldata isa Symbol && isnothing(rdata) && op == :* - ldata + elseif isnothing(l.data) && !isnothing(r.data) && op == :* + r.data + elseif !isnothing(l.data) && isnothing(r.data) && op == :* + l.data end end - - return nothing end ``` -We have now defined a way of tagging each ENode in the EGraph with `:odd` or `:even`, reasoning +We have now defined a way of tagging each e-node in the EGraph with `:odd` or `:even`, reasoning inductively on the analyses values. The [analyze!](@ref) function will do the dirty job of doing a recursive walk over the EGraph. The missing piece, is now telling Metatheory.jl how to merge together analysis values. Since EClasses represent many equal ENodes, we have to inform the automated analysis how to extract a single value out of the many analyses values contained in an EGraph. We do this by defining a method for [join](@ref). -```julia -function EGraphs.join(::Val{:OddEvenAnalysis}, a, b) - if a == b - return a - else - # an expression cannot be odd and even at the same time! - # this is contradictory, so we ignore the analysis value - return nothing - end +```@example custom_analysis +function EGraphs.join(a::OddEvenAnalysis, b::OddEvenAnalysis) + # an expression cannot be odd and even at the same time! + # this is contradictory, so we ignore the analysis value + a != b && error("contradiction") + a end ``` @@ -393,7 +374,7 @@ We do not care to modify the content of EClasses in consequence of our analysis. Therefore, we can skip the definition of [modify!](@ref). We are now ready to test our analysis. -```julia +```@example custom_analysis t = @theory a b c begin a * (b * c) == (a * b) * c a + (b + c) == (a + b) + c @@ -403,10 +384,9 @@ t = @theory a b c begin end function custom_analysis(expr) - g = EGraph(expr) + g = EGraph{Expr, OddEvenAnalysis}(expr) saturate!(g, t) - analyze!(g, OddEvenAnalysis) - return getdata(g[g.root], OddEvenAnalysis) + return g[g.root].data end custom_analysis(:(2*a)) # :even diff --git a/docs/src/index.md b/docs/src/index.md index 8ddf9009..efda3e8b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,10 +1,9 @@ -# Metatheory.jl 2.0 - ```@raw html

``` +# Metatheory.jl [![Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://juliasymbolics.github.io/Metatheory.jl/dev/) [![Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliasymbolics.github.io/Metatheory.jl/stable/) @@ -14,45 +13,164 @@ [![status](https://joss.theoj.org/papers/3266e8a08a75b9be2f194126a9c6f0e9/status.svg)](https://joss.theoj.org/papers/3266e8a08a75b9be2f194126a9c6f0e9) [![Zulip](https://img.shields.io/badge/Chat-Zulip-blue)](https://julialang.zulipchat.com/#narrow/stream/277860-metatheory.2Ejl) -**Metatheory.jl** is a general purpose term rewriting, metaprogramming and algebraic computation library for the Julia programming language, designed to take advantage of the powerful reflection capabilities to bridge the gap between symbolic mathematics, abstract interpretation, equational reasoning, optimization, composable compiler transforms, and advanced -homoiconic pattern matching features. The core features of Metatheory.jl are a powerful rewrite rule definition language, a vast library of functional combinators for classical term rewriting and an *e-graph rewriting*, a fresh approach to term rewriting achieved through an equality saturation algorithm. Metatheory.jl can manipulate any kind of -Julia symbolic expression type, as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). +**Metatheory.jl** is a general purpose term rewriting, metaprogramming and +algebraic computation library for the Julia programming language, designed to +take advantage of the powerful reflection capabilities to bridge the gap between +symbolic mathematics, abstract interpretation, equational reasoning, +optimization, composable compiler transforms, and advanced homoiconic pattern +matching features. The core features of Metatheory.jl are a powerful rewrite +rule definition language, a vast library of functional combinators for classical +term rewriting and an *[e-graph](https://en.wikipedia.org/wiki/E-graph) +rewriting*, a fresh approach to term rewriting achieved through an equality +saturation algorithm. Metatheory.jl can manipulate any kind of Julia symbolic +expression type, as long as it satisfies [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). Metatheory.jl provides: -- An eDSL (domain specific language) to define different kinds of symbolic rewrite rules. +- An eDSL (embedded domain specific language) to define different kinds of symbolic rewrite rules. - A classical rewriting backend, derived from the [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl) pattern matcher, supporting associative-commutative rules. It is based on the pattern matcher in the [SICM book](https://mitpress.mit.edu/sites/default/files/titles/content/sicm_edition_2/book.html). - A flexible library of rewriter combinators. -- An e-graph rewriting (equality saturation) backend and pattern matcher, based on the [egg](https://egraphs-good.github.io/) library, supporting backtracking and non-deterministic term rewriting by using a data structure called *e-graph*, efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning. +- An [e-graph](https://en.wikipedia.org/wiki/E-graph) rewriting (equality saturation) engine, based on the [egg](https://egraphs-good.github.io/) library, supporting a backtracking pattern matcher and non-deterministic term rewriting by using a data structure called [e-graph](https://en.wikipedia.org/wiki/E-graph), efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning. - `@capture` macro for flexible metaprogramming. Intuitively, Metatheory.jl transforms Julia expressions -in other Julia expressions and can achieve such at both compile and run time. This allows Metatheory.jl users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages. +in other Julia expressions at both compile and run time. + +This allows users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages. + Our library provides a simple, algebraically composable interface to help scientists in implementing and reasoning about semantics and all kinds of formal systems, by defining concise rewriting rules in pure, syntactically valid Julia on a high level of abstraction. Our implementation of equality saturation on e-graphs is based on the excellent, state-of-the-art technique implemented in the [egg](https://egraphs-good.github.io/) library, reimplemented in pure Julia. -## 2.0 is out! -Second stable version is out: +## 3.0 Alpha + +- [ ] Rewrite integration test files in [Literate.jl](https://github.com/fredrikekre/Literate.jl) format, becoming narrative tutorials available in the docs. +- [ ] Proof production algorithm: explanations. +- [x] Using new TermInterface. +- [x] Performance optimization: use vectors of UInt to internally represent terms in e-graphs. +- [x] Comprehensive suite of benchmarks that are run automatically on PR. +- [x] Complete overhaul of the rebuilding algorithm. + +--- + +## We need your help! - Practical and Research Contributions + +There's lot of room for improvement for Metatheory.jl, by making it more performant and by extending its features. +Any contribution is welcome! + +**Performance**: +- Improving the speed of the e-graph pattern matcher. [(Useful paper)](https://arxiv.org/abs/2108.02290) +- Reducing allocations used by Equality Saturation. +- [#50](https://github.com/JuliaSymbolics/Metatheory.jl/issues/50) - Goal-informed [rule schedulers](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/EGraphs/Schedulers.jl): develop heuristic algorithms that choose what rules to apply at each equality saturation iteration to prune space of possible rewrites. + +**Features**: +- [#111](https://github.com/JuliaSymbolics/Metatheory.jl/issues/111) Introduce proof production capabilities for e-graphs. This can be based on the [egg implementation](https://github.com/egraphs-good/egg/blob/main/src/explain.rs). +- Common Subexpression Elimination when extracting from an e-graph [#158](https://github.com/JuliaSymbolics/Metatheory.jl/issues/158) +- Integer Linear Programming extraction of expressions. +- Pattern matcher enhancements: [#43 Better parsing of blocks](https://github.com/JuliaSymbolics/Metatheory.jl/issues/43), [#3 Support `...` variables in e-graphs](https://github.com/JuliaSymbolics/Metatheory.jl/issues/3), [#89 syntax for vectors](https://github.com/JuliaSymbolics/Metatheory.jl/issues/89) +- [#75 E-Graph intersection algorithm](https://github.com/JuliaSymbolics/Metatheory.jl/issues/75) + +**Documentation**: +- Port more [integration tests](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/integration) to [tutorials](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/tutorials) that are rendered with [Literate.jl](https://github.com/fredrikekre/Literate.jl) +- Document [Functional Rewrite Combinators](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/Rewriters.jl) and add a tutorial. + +## Real World Applications + +Most importantly, there are many **practical real world applications** where Metatheory.jl could be used. Let's +work together to turn this list into some new Julia packages: + + +#### Integration with Symbolics.jl + +Many features of this package, such as the classical rewriting system, have been ported from [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl), and are technically the same. Integration between Metatheory.jl with Symbolics.jl **is currently +in-development**, as we recently released a new version of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). + +An integration between Metatheory.jl and [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) is possible and has previously been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. Once we reach consensus for a shared symbolic term interface, Metatheory.jl can be used to: + +- Rewrite Symbolics.jl expressions with **bi-directional equations** instead of simple directed rewrite rules. +- Search for the space of mathematically equivalent Symbolics.jl expressions for more computationally efficient forms to speed various packages like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) that numerically evaluate Symbolics.jl expressions. +- When proof production is introduced in Metatheory.jl, automatically search the space of a domain-specific equational theory to prove that Symbolics.jl expressions are equal in that theory. +- Other scientific domains extending Symbolics.jl for system modeling. + +#### Simplifying Quantum Algebras + +[QuantumCumulants.jl](https://github.com/qojulia/QuantumCumulants.jl/) automates +the symbolic derivation of mean-field equations in quantum mechanics, expanding +them in cumulants and generating numerical solutions using state-of-the-art +solvers like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) +and +[DifferentialEquations.jl](https://github.com/SciML/DifferentialEquations.jl). A +potential application for Metatheory.jl is domain-specific code optimization for +QuantumCumulants.jl, aiming to be the first symbolic simplification engine for +Fock algebras. + + +#### Automatic Floating Point Error Fixer + -- New e-graph pattern matching system, relies on functional programming and closures, and is much more extensible than 1.0's virtual machine. -- No longer dispatch against types, but instead dispatch against objects. -- Faster E-Graph Analysis -- Better library macros -- Updated TermInterface to 0.3.3 -- New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression` -- Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses. -- Remove duplicates in E-Graph analyses data. +[Herbie](https://herbie.uwplse.org/) is a tool using equality saturation to automatically rewrites mathematical expressions to enhance +floating-point accuracy. Recently, Herbie's core has been rewritten using +[egg](https://egraphs-good.github.io/), with the tool originally implemented in +a mix of Racket, Scheme, and Rust. While effective, its usage involves multiple +languages, making it impractical for non-experts. The text suggests the theoretical +possibility of porting this technique to a pure Julia solution, seamlessly +integrating with the language, in a single macro `@fp_optimize` that fixes +floating-point errors in expressions just before code compilation and execution. +#### Automatic Theorem Proving in Julia -Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. The introduction of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) has allowed for large potential in generalization of term rewriting and symbolic analysis and manipulation features. Integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. +Metatheory.jl can be used to make a pure Julia Automated Theorem Prover (ATP) +inspired by the use of E-graphs in existing ATP environments like +[Z3](https://github.com/Z3Prover/z3), [Simplify](https://dl.acm.org/doi/10.1145/1066100.1066102) and [CVC4](https://en.wikipedia.org/wiki/CVC4), +in the context of [Satisfiability Modulo Theories (SMT)](https://en.wikipedia.org/wiki/Satisfiability_modulo_theories). + +The two-language problem in program verification can be addressed by allowing users to define high-level +theories about their code, that are statically verified before executing the program. This holds potential for various applications in +software verification, offering a flexible and generic environment for proving +formulae in different logics, and statically verifying such constraints on Julia +code before it gets compiled (see +[Mixtape.jl](https://github.com/JuliaCompilerPlugins/Mixtape.jl)). + +To develop such a package, Metatheory.jl needs: + +- Introduction of Proof Production in equality saturation. +- SMT in conjunction with a SAT solver like [PicoSAT.jl](https://github.com/sisl/PicoSAT.jl) +- Experiments with various logic theories and software verification applications. + +#### Other potential applications + +Many projects that could potentially be ported to Julia are listed on the [egg website](https://egraphs-good.github.io/). +A simple search for ["equality saturation" on Google Scholar](https://scholar.google.com/scholar?hl=en&q="equality+saturation") shows many new articles that leverage the techniques used in this packages. + +PLDI is a premier academic forum in the field of programming languages and programming systems research, which organizes an [e-graph symposium](https://pldi23.sigplan.org/home/egraphs-2023) where many interesting research and projects have been presented. + +--- + +## Theoretical Developments + +There's also lots of room for theoretical improvements to the e-graph data structure and equality saturation rewriting. + +#### Associative-Commutative-Distributive e-matching + +In classical rewriting SymbolicUtils.jl offers a mechanism for matching expressions with associative and commutative operations: [`@acrule`](https://docs.sciml.ai/SymbolicUtils/stable/manual/rewrite/#Associative-Commutative-Rules) - a special kind of rule that considers all permutations and combinations of arguments. In e-graph rewriting in Metatheory.jl, associativity and commutativity have to be explicitly defined as rules. However, the presence of such rules, together with distributivity, will likely cause equality saturation to loop infinitely. See ["Why reasonable rules can create infinite loops"](https://github.com/egraphs-good/egg/discussions/60) for an explanation. + +Some workaround exists for ensuring termination of equality saturation: bounding the depth of search, or merge-only rewriting without introducing new terms (see ["Ensuring the Termination of EqSat over a Terminating Term Rewriting System"](https://effect.systems/blog/ta-completion.html)). + +There's a few theoretical questions left: + +- **What kind of rewrite systems terminate in equality saturation**? +- Can associative-commutative matching be applied efficiently to e-graphs while avoiding combinatory explosion? +- Can e-graphs be extended to include nodes with special algebraic properties, in order to mitigate the downsides of non-terminating systems? + +--- ## Recommended Readings - Selected Publications - The [Metatheory.jl manual](https://juliasymbolics.github.io/Metatheory.jl/stable/) -- The [Metatheory.jl introductory paper](https://joss.theoj.org/papers/10.21105/joss.03078#) gives a brief high level overview on the library and its functionalities. +- **OUT OF DATE**: The [Metatheory.jl introductory paper](https://joss.theoj.org/papers/10.21105/joss.03078#) gives a brief high level overview on the library and its functionalities. - The Julia Manual [metaprogramming section](https://docs.julialang.org/en/v1/manual/metaprogramming/) is fundamental to understand what homoiconic expression manipulation is and how it happens in Julia. - An [introductory blog post on SIGPLAN](https://blog.sigplan.org/2021/04/06/equality-saturation-with-egg/) about `egg` and e-graphs rewriting. - [egg: Fast and Extensible Equality Saturation](https://dl.acm.org/doi/pdf/10.1145/3434304) contains the definition of *E-Graphs* on which Metatheory.jl's equality saturation rewriting backend is based. This is a strongly recommended reading. - [High-performance symbolic-numerics via multiple dispatch](https://arxiv.org/abs/2105.03949): a paper about how we used Metatheory.jl to optimize code generation in [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) +- [Automated Code Optimization with E-Graphs](https://arxiv.org/abs/2112.14714). Alessandro Cheli's Thesis on Metatheory.jl ## Contributing @@ -60,8 +178,6 @@ If you'd like to give us a hand and contribute to this repository you can: - Find a high level description of the project architecture in [ARCHITECTURE.md](https://github.com/juliasymbolics/Metatheory.jl/blob/master/ARCHITECTURE.md) - Read the contribution guidelines in [CONTRIBUTING.md](https://github.com/juliasymbolics/Metatheory.jl/blob/master/CONTRIBUTING.md) -If you enjoyed Metatheory.jl and would like to help, please also consider a [tiny donation 💕](https://github.com/sponsors/0x0f0f0f/)! - ## Installation You can install the stable version: @@ -69,7 +185,7 @@ You can install the stable version: julia> using Pkg; Pkg.add("Metatheory") ``` -Or you can install the developer version (recommended by now for latest bugfixes) +Or you can install the development version (recommended by now for latest bugfixes) ```julia julia> using Pkg; Pkg.add(url="https://github.com/JuliaSymbolics/Metatheory.jl") ``` @@ -84,10 +200,12 @@ If you use Metatheory.jl in your research, please [cite](https://github.com/juli --- -```@raw html +# Sponsors + +If you enjoyed Metatheory.jl and would like to help, you can donate a coffee or choose place your logo and name in this page. [See 0x0f0f0f's Github Sponsors page](https://github.com/sponsors/0x0f0f0f/)! +

-``` \ No newline at end of file diff --git a/docs/src/rewrite.md b/docs/src/rewrite.md index a8923b60..cdb1c2da 100644 --- a/docs/src/rewrite.md +++ b/docs/src/rewrite.md @@ -27,9 +27,9 @@ The `@rule` macro takes a pair of patterns -- the _matcher_ and the _consequent **Rule operators**: - `LHS => RHS`: create a `DynamicRule`. The RHS is *evaluated* on rewrite. -- `LHS --> RHS`: create a `RewriteRule`. The RHS is **not** evaluated but *symbolically substituted* on rewrite. -- `LHS == RHS`: create a `EqualityRule`. In e-graph rewriting, this rule behaves like `RewriteRule` but can go in both directions. Doesn't work in classical rewriting. -- `LHS ≠ RHS`: create a `UnequalRule`. Can only be used in e-graphs, and is used to eagerly stop the process of rewriting if LHS is found to be equal to RHS. +- `LHS --> RHS`: create a `DirectedRule`. The RHS is **not** evaluated but *symbolically substituted* on rewrite. +- `LHS == RHS`: create a `EqualityRule`. In e-graph rewriting, this rule behaves like `DirectedRule` but can go in both directions. Doesn't work in classical rewriting. +- `LHS != RHS`: create a `UnequalRule`. Can only be used in e-graphs, and is used to eagerly stop the process of rewriting if LHS is found to be equal to RHS. You can use **dynamic rules**, defined with the `=>` diff --git a/docs/src/visualizing.md b/docs/src/visualizing.md index 10adde91..a300870c 100644 --- a/docs/src/visualizing.md +++ b/docs/src/visualizing.md @@ -1,14 +1,13 @@ # Visualizing E-Graphs -You can visualize e-graphs in VSCode by using [GraphViz.jl]() +You can visualize e-graphs in VSCode by using [GraphViz.jl](https://github.com/JuliaGraphs/GraphViz.jl) -All you need to do is to install GraphViz.jl and to evaluate an e-graph after including the extra script: +All you need to do is to install GraphViz.jl and load it: ```julia +using Metatheory using GraphViz -include(dirname(pathof(Metatheory)) * "/extras/graphviz.jl") - algebra_rules = @theory a b c begin a * (b * c) == (a * b) * c a + (b + c) == (a + b) + c @@ -38,4 +37,4 @@ g And you will see a nice e-graph drawing in the Julia Plots VSCode panel: -![E-Graph Drawing](/assets/graphviz.svg) \ No newline at end of file +![E-Graph Drawing](assets/graphviz.svg) diff --git a/examples/basic_maths_theory.jl b/examples/basic_maths_theory.jl index 7fd39df4..cdcb5949 100644 --- a/examples/basic_maths_theory.jl +++ b/examples/basic_maths_theory.jl @@ -40,8 +40,8 @@ function customlt(x, y) end end +# restores n-arity of binarized + and * expressions canonical_t = @theory x y xs ys begin - # restore n-arity (x + (+)(ys...)) --> +(x, ys...) ((+)(xs...) + y) --> +(xs..., y) (x * (*)(ys...)) --> *(x, ys...) diff --git a/examples/propositional_logic_theory.jl b/examples/propositional_logic_theory.jl index 32c1670c..41cde057 100644 --- a/examples/propositional_logic_theory.jl +++ b/examples/propositional_logic_theory.jl @@ -1,5 +1,8 @@ # # Rewriting +using Metatheory +using Metatheory.TermInterface + fold = @theory p q begin (p::Bool == q::Bool) => (p == q) (p::Bool || q::Bool) => (p || q) @@ -25,24 +28,20 @@ and_alg = @theory p q r begin end comb = @theory p q r begin - # DeMorgan - !(p || q) == (!p && !q) + !(p || q) == (!p && !q) # DeMorgan !(p && q) == (!p || !q) - # distrib - (p && (q || r)) == ((p && q) || (p && r)) + (p && (q || r)) == ((p && q) || (p && r)) # Distributivity (p || (q && r)) == ((p || q) && (p || r)) - # absorb - (p && (p || q)) --> p + (p && (p || q)) --> p # Absorb (p || (p && q)) --> p - # complement - (p && (!p || q)) --> p && q + (p && (!p || q)) --> p && q # Complement (p || (!p && q)) --> p || q end negt = @theory p begin (p && !p) --> false (p || !(p)) --> true - !(!p) == p + !(!p) --> p end impl = @theory p q begin @@ -53,41 +52,3 @@ impl = @theory p q begin end propositional_logic_theory = or_alg ∪ and_alg ∪ comb ∪ negt ∪ impl ∪ fold - - -# Sketch function for basic iterative saturation and extraction -function prove( - t, - ex, - steps = 1, - timeout = 10, - params = SaturationParams( - timeout = timeout, - scheduler = Schedulers.BackoffScheduler, - schedulerparams = (6000, 5), - timer = false, - ), -) - hist = UInt64[] - push!(hist, hash(ex)) - for i in 1:steps - g = EGraph(ex) - - exprs = [true, g[g.root]] - ids = [addexpr!(g, e) for e in exprs] - - goal = (g::EGraph) -> in_same_class(g, ids...) - params.goal = goal - saturate!(g, t, params) - ex = extract!(g, astsize) - if !Metatheory.istree(ex) - return ex - end - if hash(ex) ∈ hist - return ex - end - push!(hist, hash(ex)) - end - return ex -end - diff --git a/examples/prove.jl b/examples/prove.jl index cfc679cf..dfce791f 100644 --- a/examples/prove.jl +++ b/examples/prove.jl @@ -1,4 +1,3 @@ -# TODO: should this go in MT itself? # Sketch function for basic iterative saturation and extraction function prove( t, @@ -8,13 +7,11 @@ function prove( params = SaturationParams( timeout = timeout, scheduler = Schedulers.BackoffScheduler, - schedulerparams = (6000, 5), + schedulerparams = (match_limit = 6000, ban_length = 5), timer = false, ), ) - # hist = UInt64[] - # push!(hist, hash(ex)) - for i in 1:steps + for _ in 1:steps g = EGraph(ex) ids = [addexpr!(g, true), g.root] @@ -25,11 +22,21 @@ function prove( if !TermInterface.isexpr(ex) return ex end - # if hash(ex) ∈ hist - # return ex - # end - # push!(hist, hash(ex)) end return ex end +function test_equality(t, exprs...; params = SaturationParams(), g = EGraph()) + length(exprs) == 1 && return true + ids = [addexpr!(g, ex) for ex in exprs] + params = deepcopy(params) + params.goal = (g::EGraph) -> in_same_class(g, ids...) + + report = saturate!(g, t, params) + goal_reached = params.goal(g) + + if !(report.reason === :saturated) && !goal_reached + return false # failed to prove + end + return goal_reached +end diff --git a/src/extras/graphviz.jl b/ext/Plotting.jl similarity index 88% rename from src/extras/graphviz.jl rename to ext/Plotting.jl index 2316f97b..57fb0c3b 100644 --- a/src/extras/graphviz.jl +++ b/ext/Plotting.jl @@ -1,3 +1,5 @@ +module Plotting + using GraphViz using Metatheory using TermInterface @@ -24,7 +26,7 @@ function render_eclass!(io::IO, g::EGraph, eclass::EClass) """ subgraph cluster_$(eclass.id) { style="dotted,rounded"; rank=same; - label="#$(eclass.id). Smallest: $(extract!(g, astsize; root=eclass.id))" + label="%$(eclass.id). Smallest: $(extract!(g, astsize, eclass.id))" fontcolor = gray fontsize = 8 """, @@ -46,8 +48,8 @@ function render_eclass!(io::IO, g::EGraph, eclass::EClass) end -function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::AbstractENode) - label = operation(node) +function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::VecExpr) + label = get_constant(g, v_head(node)) # (mr, style) = if node in diff && get(report.cause, node, missing) !== missing # pair = get(report.cause, node, missing) # split(split("$(pair[1].rule) ", "=>")[1], "-->")[1], " color=\"red\"" @@ -58,11 +60,10 @@ function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::Abstract println(io, " $eclass_id.$i [label=<$label> shape=box style=rounded]") end -render_enode_edges!(::IO, ::EGraph, eclass_id, i, ::ENodeLiteral) = nothing - -function render_enode_edges!(io::IO, g::EGraph, eclass_id, i, node::ENodeTerm) - len = length(arguments(node)) - for (ite, child) in enumerate(arguments(node)) +function render_enode_edges!(io::IO, g::EGraph, eclass_id, i, node::VecExpr) + v_isexpr(node) || return nothing + len = length(v_children(node)) + for (ite, child) in enumerate(v_children(node)) cluster_id = find(g, child) # The limitation of graphviz is that it cannot point to the eclass outer frame, # so when pointing to the same e-class, the next best thing is to point to the same e-node. @@ -93,4 +94,6 @@ end function Base.show(io::IO, mime::MIME"image/svg+xml", g::EGraph) show(io, mime, convert(GraphViz.Graph, g)) -end \ No newline at end of file +end + +end diff --git a/scratch/Cargo.toml b/scratch/Cargo.toml deleted file mode 100644 index 078765aa..00000000 --- a/scratch/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "benchmarks" -version = "0.1.0" -authors = ["0x0f0f0f "] -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -egg = "0.6.0" diff --git a/scratch/Project.toml b/scratch/Project.toml deleted file mode 100644 index 2dfe1985..00000000 --- a/scratch/Project.toml +++ /dev/null @@ -1,6 +0,0 @@ -[deps] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" -SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" diff --git a/scratch/benchmark_logic.jl b/scratch/benchmark_logic.jl deleted file mode 100644 index 5746b608..00000000 --- a/scratch/benchmark_logic.jl +++ /dev/null @@ -1,6 +0,0 @@ -include("prop_logic_theory.jl") -include("prover.jl") - -ex = rewrite(:(((p => q) && (r => s) && (p || r)) => (q || s)), impl) -prove(t, ex, 1, 25) -@profview prove(t, ex, 2, 7) diff --git a/scratch/egg_logic.jl b/scratch/egg_logic.jl deleted file mode 100644 index c26e98fb..00000000 --- a/scratch/egg_logic.jl +++ /dev/null @@ -1,86 +0,0 @@ -include("eggify.jl") -using Metatheory.Library -using Metatheory.EGraphs.Schedulers - -or_alg = @theory begin - ((p || q) || r) == (p || (q || r)) - (p || q) == (q || p) - (p || p) => p - (p || true) => true - (p || false) => p -end - -and_alg = @theory begin - ((p && q) && r) == (p && (q && r)) - (p && q) == (q && p) - (p && p) => p - (p && true) => p - (p && false) => false -end - -comb = @theory begin - # DeMorgan - !(p || q) == (!p && !q) - !(p && q) == (!p || !q) - # distrib - (p && (q || r)) == ((p && q) || (p && r)) - (p || (q && r)) == ((p || q) && (p || r)) - # absorb - (p && (p || q)) => p - (p || (p && q)) => p - # complement - (p && (!p || q)) => p && q - (p || (!p && q)) => p || q -end - -negt = @theory begin - (p && !p) => false - (p || !(p)) => true - !(!p) == p -end - -impl = @theory begin - (p == !p) => false - (p == p) => true - (p == q) => (!p || q) && (!q || p) - (p => q) => (!p || q) -end - -fold = @theory begin - (true == false) => false - (false == true) => false - (true == true) => true - (false == false) => true - (true || false) => true - (false || true) => true - (true || true) => true - (false || false) => false - (true && true) => true - (false && true) => false - (true && false) => false - (false && false) => false - !(true) => false - !(false) => true -end - -theory = or_alg ∪ and_alg ∪ comb ∪ negt ∪ impl ∪ fold - - -query = :(!(((!p || q) && (!r || s)) && (p || r)) || (q || s)) - -########################################### - -params = SaturationParams(timeout = 22, eclasslimit = 3051, scheduler = ScoredScheduler)#, schedulerparams=(1000,5, Schedulers.exprsize)) - -for i in 1:2 - G = EGraph(query) - report = saturate!(G, theory, params) - ex = extract!(G, astsize) - println("Best found: $ex") - println(report) -end - - -open("src/main.rs", "w") do f - write(f, rust_code(theory, query, params)) -end diff --git a/scratch/egg_maths.jl b/scratch/egg_maths.jl deleted file mode 100644 index 0ee1c72c..00000000 --- a/scratch/egg_maths.jl +++ /dev/null @@ -1,88 +0,0 @@ -include("eggify.jl") -using Metatheory.Library -using Metatheory.EGraphs.Schedulers - -mult_t = commutative_monoid(:(*), 1) -plus_t = commutative_monoid(:(+), 0) - -minus_t = @theory begin - a - a => 0 - a + (-b) => a - b -end - -mulplus_t = @theory begin - 0 * a => 0 - a * 0 => 0 - a * (b + c) == ((a * b) + (a * c)) - a + (b * a) => ((b + 1) * a) -end - -pow_t = @theory begin - (y^n) * y => y^(n + 1) - x^n * x^m == x^(n + m) - (x * y)^z == x^z * y^z - (x^p)^q == x^(p * q) - x^0 => 1 - 0^x => 0 - 1^x => 1 - x^1 => x - inv(x) == x^(-1) -end - -function customlt(x, y) - if typeof(x) == Expr && Expr == typeof(y) - false - elseif typeof(x) == typeof(y) - isless(x, y) - elseif x isa Symbol && y isa Number - false - else - true - end -end - -canonical_t = @theory begin - # restore n-arity - (x + (+)(ys...)) => +(x, ys...) - ((+)(xs...) + y) => +(xs..., y) - (x * (*)(ys...)) => *(x, ys...) - ((*)(xs...) * y) => *(xs..., y) - - (*)(xs...) |> Expr(:call, :*, sort!(xs; lt = customlt)...) - (+)(xs...) |> Expr(:call, :+, sort!(xs; lt = customlt)...) -end - - -cas = mult_t ∪ plus_t ∪ minus_t ∪ mulplus_t ∪ pow_t -theory = cas - -query = cleanast(:(a + b + (0 * c) + d)) - - -function simplify(ex) - g = EGraph(ex) - params = SaturationParams( - scheduler = BackoffScheduler, - timeout = 20, - schedulerparams = (1000, 5), # fuel and bantime - ) - report = saturate!(g, cas, params) - println(report) - res = extract!(g, astsize) - res = rewrite(res, canonical_t; clean = false, m = @__MODULE__) # this just orders symbols and restores n-ary plus and mult - res -end - -########################################### - -params = SaturationParams(timeout = 20, schedulerparams = (1000, 5)) - -for i in 1:2 - ex = simplify(:(a + b + (0 * c) + d)) - println("Best found: $ex") -end - - -open("src/main.rs", "w") do f - write(f, rust_code(theory, query)) -end diff --git a/scratch/eggify.jl b/scratch/eggify.jl deleted file mode 100644 index 04e82b2c..00000000 --- a/scratch/eggify.jl +++ /dev/null @@ -1,54 +0,0 @@ -using Metatheory -using Metatheory.EGraphs - -to_sexpr_pattern(p::PatLiteral) = "$(p.val)" -to_sexpr_pattern(p::PatVar) = "?$(p.name)" -function to_sexpr_pattern(p::PatTerm) - e1 = join([p.head; to_sexpr_pattern.(p.args)], ' ') - "($e1)" -end - -to_sexpr(e::Symbol) = e -to_sexpr(e::Int64) = e -to_sexpr(e::Expr) = "($(join(to_sexpr.(e.args),' ')))" - -function eggify(rules) - egg_rules = [] - for rule in rules - l = to_sexpr_pattern(rule.left) - r = to_sexpr_pattern(rule.right) - if rule isa SymbolicRule - push!(egg_rules, "\tvec![rw!( \"$(rule.left) => $(rule.right)\" ; \"$l\" => \"$r\" )]") - elseif rule isa EqualityRule - push!(egg_rules, "\trw!( \"$(rule.left) == $(rule.right)\" ; \"$l\" <=> \"$r\" )") - else - println("Unsupported Rewrite Mode") - @assert false - end - - end - return join(egg_rules, ",\n") -end - -function rust_code(theory, query, params = SaturationParams()) - """ - use egg::{*, rewrite as rw}; - //use std::time::Duration; - fn main() { - let rules : &[Rewrite] = &vec![ - $(eggify(theory)) - ].concat(); - - let start = "$(to_sexpr(cleanast(query)))".parse().unwrap(); - let runner = Runner::default().with_expr(&start) - // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html - .with_iter_limit($(params.timeout)) - .with_node_limit($(params.enodelimit)) - .run(rules); - runner.print_report(); - let mut extractor = Extractor::new(&runner.egraph, AstSize); - let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); - println!("best cost: {}, best expr {}", best_cost, best_expr); - } - """ -end diff --git a/scratch/figures/fib.pdf b/scratch/figures/fib.pdf deleted file mode 100644 index 55874cf8..00000000 Binary files a/scratch/figures/fib.pdf and /dev/null differ diff --git a/scratch/gen_egg_instructions.md b/scratch/gen_egg_instructions.md deleted file mode 100644 index 2bf4a57d..00000000 --- a/scratch/gen_egg_instructions.md +++ /dev/null @@ -1,41 +0,0 @@ -This is a simple script to convert Metatheory.jl theories into an Egg query for comparison. - -Get a rust toolchain - -Make a new project - -``` -cargo new my_project -cd my_project -``` - -Add egg as a dependency to the Cargo.toml. Add the last line shown here. - -``` -[package] -name = "autoegg" -version = "0.1.0" -authors = ["Philip Zucker "] -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -egg = "0.6.0" -``` - -Copy and paste the Julia script in the project folder. Replace the example theory and query with yours in the script - -Run it - -``` -julia gen_egg.jl -``` - -Now you can run it in Egg - -``` -cargo run --release -``` - -Profit. diff --git a/scratch/src/main.rs b/scratch/src/main.rs deleted file mode 100644 index a885fae3..00000000 --- a/scratch/src/main.rs +++ /dev/null @@ -1,56 +0,0 @@ -use egg::{*, rewrite as rw}; -//use std::time::Duration; -fn main() { - let rules : &[Rewrite] = &vec![ - vec![rw!( "p || q || r => p || q || r" ; "(|| (|| ?p ?q) ?r)" => "(|| ?p (|| ?q ?r))" )], - vec![rw!( "p || q => q || p" ; "(|| ?p ?q)" => "(|| ?q ?p)" )], - vec![rw!( "p || p => p" ; "(|| ?p ?p)" => "?p" )], - vec![rw!( "p || true => true" ; "(|| ?p true)" => "true" )], - vec![rw!( "p || false => p" ; "(|| ?p false)" => "?p" )], - vec![rw!( "p && q && r => p && q && r" ; "(&& (&& ?p ?q) ?r)" => "(&& ?p (&& ?q ?r))" )], - vec![rw!( "p && q => q && p" ; "(&& ?p ?q)" => "(&& ?q ?p)" )], - vec![rw!( "p && p => p" ; "(&& ?p ?p)" => "?p" )], - vec![rw!( "p && true => p" ; "(&& ?p true)" => "?p" )], - vec![rw!( "p && false => false" ; "(&& ?p false)" => "false" )], - vec![rw!( "!p || q => !p && !q" ; "(! (|| ?p ?q))" => "(&& (! ?p) (! ?q))" )], - vec![rw!( "!p && q => !p || !q" ; "(! (&& ?p ?q))" => "(|| (! ?p) (! ?q))" )], - vec![rw!( "p && q || r => p && q || p && r" ; "(&& ?p (|| ?q ?r))" => "(|| (&& ?p ?q) (&& ?p ?r))" )], - vec![rw!( "p || q && r => p || q && p || r" ; "(|| ?p (&& ?q ?r))" => "(&& (|| ?p ?q) (|| ?p ?r))" )], - vec![rw!( "p && p || q => p" ; "(&& ?p (|| ?p ?q))" => "?p" )], - vec![rw!( "p || p && q => p" ; "(|| ?p (&& ?p ?q))" => "?p" )], - vec![rw!( "p && !p || q => p && q" ; "(&& ?p (|| (! ?p) ?q))" => "(&& ?p ?q)" )], - vec![rw!( "p || !p && q => p || q" ; "(|| ?p (&& (! ?p) ?q))" => "(|| ?p ?q)" )], - vec![rw!( "p && !p => false" ; "(&& ?p (! ?p))" => "false" )], - vec![rw!( "p || !p => true" ; "(|| ?p (! ?p))" => "true" )], - vec![rw!( "!!p => p" ; "(! (! ?p))" => "?p" )], - vec![rw!( "p == !p => false" ; "(== ?p (! ?p))" => "false" )], - vec![rw!( "p == p => true" ; "(== ?p ?p)" => "true" )], - vec![rw!( "p == q => !p || q && !q || p" ; "(== ?p ?q)" => "(&& (|| (! ?p) ?q) (|| (! ?q) ?p))" )], - vec![rw!( "p => q => !p || q" ; "(=> ?p ?q)" => "(|| (! ?p) ?q)" )], - vec![rw!( "true == false => false" ; "(== true false)" => "false" )], - vec![rw!( "false == true => false" ; "(== false true)" => "false" )], - vec![rw!( "true == true => true" ; "(== true true)" => "true" )], - vec![rw!( "false == false => true" ; "(== false false)" => "true" )], - vec![rw!( "true || false => true" ; "(|| true false)" => "true" )], - vec![rw!( "false || true => true" ; "(|| false true)" => "true" )], - vec![rw!( "true || true => true" ; "(|| true true)" => "true" )], - vec![rw!( "false || false => false" ; "(|| false false)" => "false" )], - vec![rw!( "true && true => true" ; "(&& true true)" => "true" )], - vec![rw!( "false && true => false" ; "(&& false true)" => "false" )], - vec![rw!( "true && false => false" ; "(&& true false)" => "false" )], - vec![rw!( "false && false => false" ; "(&& false false)" => "false" )], - vec![rw!( "!true => false" ; "(! true)" => "false" )], - vec![rw!( "!false => true" ; "(! false)" => "true" )] - ].concat(); - - let start = "(|| (! (&& (&& (|| (! p) q) (|| (! r) s)) (|| p r))) (|| q s))".parse().unwrap(); - let runner = Runner::default().with_expr(&start) - // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html - .with_iter_limit(22) - .with_node_limit(15000) - .run(rules); - runner.print_report(); - let mut extractor = Extractor::new(&runner.egraph, AstSize); - let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); - println!("best cost: {}, best expr {}", best_cost, best_expr); -} diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index 1a1bdc6a..b9e2a1bb 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -2,59 +2,34 @@ module EGraphs include("../docstrings.jl") -using DataStructures using TermInterface using TimerOutputs -using Metatheory: alwaystrue, cleanast, binarize using Metatheory.Patterns using Metatheory.Rules -using Metatheory.EMatchCompiler +using Metatheory.VecExprModule -include("intdisjointmap.jl") -export IntDisjointSet -export in_same_set +using Metatheory: alwaystrue, cleanast, UNDEF_ID_VEC, maybe_quote_operation, OptBuffer + +import Metatheory: to_expr + +include("unionfind.jl") +export UnionFind + +include("uniquequeue.jl") include("egraph.jl") -export AbstractENode -export ENodeLiteral -export ENodeTerm -export EClassId -export EClass -export hasdata -export getdata -export setdata! -export find -export lookup -export arity -export EGraph -export merge! -export in_same_class -export addexpr! -export rebuild! -export settermtype! -export gettermtype - -include("analysis.jl") -export analyze! -export extract! -export astsize -export astsize_inv -export getcost! - -export Sub +export Id, + EClass, find, lookup, arity, EGraph, merge!, in_same_class, addexpr!, rebuild!, has_constant, get_constant, lookup_pat + +include("extract.jl") +export extract!, astsize, astsize_inv + include("Schedulers.jl") export Schedulers using .Schedulers include("saturation.jl") -export SaturationGoal -export EqualityGoal -export reached -export SaturationParams -export saturate! -export areequal -export @areequal -export @areequalg +export SaturationParams, saturate! end diff --git a/src/EGraphs/Schedulers.jl b/src/EGraphs/Schedulers.jl index 6ca3d36b..a2178e41 100644 --- a/src/EGraphs/Schedulers.jl +++ b/src/EGraphs/Schedulers.jl @@ -7,14 +7,8 @@ using Metatheory.EGraphs using Metatheory.Patterns using DocStringExtensions -export AbstractScheduler -export SimpleScheduler -export BackoffScheduler -export ScoredScheduler -export cansaturate -export cansearch -export inform! -export setiter! +export AbstractScheduler, + SimpleScheduler, BackoffScheduler, FreezingScheduler, ScoredScheduler, cansaturate, cansearch, inform!, setiter! """ Represents a rule scheduler for the equality saturation process @@ -23,33 +17,61 @@ Represents a rule scheduler for the equality saturation process abstract type AbstractScheduler end """ + cansaturate(s::AbstractScheduler) + Should return `true` if the e-graph can be said to be saturated -``` -cansaturate(s::AbstractScheduler) -``` """ function cansaturate end """ -Should return `false` if the rule `r` should be skipped -``` -cansearch(s::AbstractScheduler, r::Rule) -``` + cansearch(s::AbstractScheduler, i::Int) + cansearch(s::AbstractScheduler, i::Int, eclass_id::Id) + +Given a theory `t` and a rule `r` with index `i` in the theory, +should return `false` if the search for rule with index `i` should be skipped +for the current iteration. An extra `eclass_id::Id` arguments can be passed +in order to filter out specific e-classes. """ function cansearch end """ -This function is called **after** pattern matching on the e-graph, -informs the scheduler about the yielded matches. -Returns `false` if the matches should not be yielded and ignored. -``` -inform!(s::AbstractScheduler, r::AbstractRule, n_matches) -``` + inform!(s::AbstractScheduler, i::Int, n_matches) + inform!(s::AbstractScheduler, i::Int, eclass_id::Id, n_matches) + + +Given a theory `t` and a rule `r` with index `i` in the theory, +This function is called **after** pattern matching (searching) the e-graph, +it informs the scheduler about the number of yielded matches. """ function inform! end +""" + setiter!(s::AbstractScheduler, i::Int) + +Inform a scheduler about the current iteration number. +""" function setiter! end +""" + rebuild!(s::AbstractScheduler, g::EGraph) + +Some schedulers may hold data that need to be re-canonicalized +after an iteration of equality saturation, such as references to e-class IDs. +This is called by equality saturation after e-graph `rebuild!` +""" +function rebuild! end + +# =========================================================================== +# Defaults +# =========================================================================== + +@inline inform!(::AbstractScheduler, ::Int, ::Int) = nothing +@inline inform!(::AbstractScheduler, ::Int, ::Id, ::Int) = nothing +@inline setiter!(::AbstractScheduler, ::Int) = nothing +@inline rebuild!(::AbstractScheduler) = nothing + + + # =========================================================================== # SimpleScheduler # =========================================================================== @@ -60,26 +82,16 @@ A simple Rewrite Scheduler that applies every rule every time """ struct SimpleScheduler <: AbstractScheduler end -cansaturate(s::SimpleScheduler) = true -cansearch(s::SimpleScheduler, r::AbstractRule) = true -function SimpleScheduler(G::EGraph, theory::Vector{<:AbstractRule}) - SimpleScheduler() -end -inform!(s::SimpleScheduler, r, n_matches) = true -setiter!(s::SimpleScheduler, iteration) = nothing +SimpleScheduler(::EGraph, ::Theory) = SimpleScheduler() +@inline cansaturate(s::SimpleScheduler) = true +@inline cansearch(s::SimpleScheduler, ::Int) = true +@inline cansearch(s::SimpleScheduler, ::Int, ::Id) = true # =========================================================================== # BackoffScheduler # =========================================================================== -mutable struct BackoffSchedulerEntry - match_limit::Int - ban_length::Int - times_banned::Int - banned_until::Int -end - """ A Rewrite Scheduler that implements exponential rule backoff. For each rewrite, there exists a configurable initial match limit. @@ -90,163 +102,100 @@ will be banned next time. This seems effective at preventing explosive rules like associativity from taking an unfair amount of resources. """ -mutable struct BackoffScheduler <: AbstractScheduler - data::IdDict{AbstractRule,BackoffSchedulerEntry} - G::EGraph - theory::Vector{<:AbstractRule} - curr_iter::Int +Base.@kwdef mutable struct BackoffScheduler <: AbstractScheduler + data::Vector{Tuple{Int,Int}} # TimesBanned ⊗ BannedUntil + g::EGraph + theory::Theory + curr_iter::Int = 1 + match_limit::Int = 1000 + ban_length::Int = 5 end -cansearch(s::BackoffScheduler, r::AbstractRule)::Bool = s.curr_iter > s.data[r].banned_until - - -function BackoffScheduler(g::EGraph, theory::Vector{<:AbstractRule}) - # BackoffScheduler(g, theory, 128, 4) - BackoffScheduler(g, theory, 1000, 5) -end +@inline cansearch(s::BackoffScheduler, rule_idx::Int)::Bool = s.curr_iter > last(s.data[rule_idx]) +@inline cansearch(s::BackoffScheduler, rule_idx::Int, eclass_id::Id) = true -function BackoffScheduler(G::EGraph, theory::Vector{<:AbstractRule}, match_limit::Int, ban_length::Int) - gsize = length(G.uf) - data = IdDict{AbstractRule,BackoffSchedulerEntry}() - - for rule in theory - data[rule] = BackoffSchedulerEntry(match_limit, ban_length, 0, 0) - end - - return BackoffScheduler(data, G, theory, 1) -end +BackoffScheduler(g::EGraph, theory::Theory; kwargs...) = + BackoffScheduler(; data = fill((0, 0), length(theory)), g, theory, kwargs...) # can saturate if there's no banned rule -cansaturate(s::BackoffScheduler)::Bool = all(kv -> s.curr_iter > last(kv).banned_until, s.data) +cansaturate(s::BackoffScheduler)::Bool = all((<)(s.curr_iter) ∘ last, s.data) -function inform!(s::BackoffScheduler, rule::AbstractRule, n_matches) - rd = s.data[rule] - treshold = rd.match_limit << rd.times_banned - if n_matches > treshold - ban_length = rd.ban_length << rd.times_banned - rd.times_banned += 1 - rd.banned_until = s.curr_iter + ban_length - return false +function inform!(s::BackoffScheduler, rule_idx::Int, n_matches::Int) + (times_banned, _) = s.data[rule_idx] + threshold = s.match_limit << times_banned + if n_matches > threshold + s.data[rule_idx] = (times_banned += 1, s.curr_iter + (s.ban_length << times_banned)) end - return true end -function setiter!(s::BackoffScheduler, curr_iter) +function setiter!(s::BackoffScheduler, curr_iter::Int) s.curr_iter = curr_iter end + # =========================================================================== -# ScoredScheduler +# FreezingScheduler # =========================================================================== - -mutable struct ScoredSchedulerEntry - match_limit::Int - ban_length::Int +struct FreezingSchedulerStat times_banned::Int banned_until::Int - weight::Int -end - -""" -A Rewrite Scheduler that implements exponential rule backoff. -For each rewrite, there exists a configurable initial match limit. -If a rewrite search yield more than this limit, then we ban this rule -for number of iterations, double its limit, and double the time it -will be banned next time. - -This seems effective at preventing explosive rules like -associativity from taking an unfair amount of resources. -""" -mutable struct ScoredScheduler <: AbstractScheduler - data::IdDict{AbstractRule,ScoredSchedulerEntry} - G::EGraph - theory::Vector{<:AbstractRule} - curr_iter::Int + size_limit::Int + ban_length::Int end -cansearch(s::ScoredScheduler, r::AbstractRule)::Bool = s.curr_iter > s.data[r].banned_until - -exprsize(a) = 1 - -function exprsize(e::PatTerm) - c = 1 + length(e.args) - for a in e.args - c += exprsize(a) - end - return c +Base.@kwdef mutable struct FreezingScheduler <: AbstractScheduler + data::Dict{Id,FreezingSchedulerStat} = Dict{Id,FreezingSchedulerStat}() + g::EGraph + theory::Theory + curr_iter::Int = 1 + default_eclass_size_limit::Int = 10 + default_eclass_size_increment::Int = 3 + default_eclass_ban_length::Int = 3 + default_eclass_ban_increment::Int = 2 end -function exprsize(e::Expr) - start = Meta.isexpr(e, :call) ? 2 : 1 +FreezingScheduler(g::EGraph, theory::Theory; kwargs...) = FreezingScheduler(; g, theory, kwargs...) - c = 1 + length(e.args[start:end]) - for a in e.args[start:end] - c += exprsize(a) - end +@inline cansearch(s::FreezingScheduler, rule_idx::Int)::Bool = true +@inline cansearch(s::FreezingScheduler, ::Int, eclass_id::Id) = s.curr_iter > s[eclass_id].banned_until - return c -end +function Base.getindex(s::FreezingScheduler, id::Id) + haskey(s.data, id) && return s.data[id] + nid = find(s.g, id) + haskey(s.data, nid) && return s.data[nid] -function ScoredScheduler(g::EGraph, theory::Vector{<:AbstractRule}) - # BackoffScheduler(g, theory, 128, 4) - ScoredScheduler(g, theory, 1000, 5, exprsize) + s.data[id] = FreezingSchedulerStat(0, 0, s.default_eclass_size_limit, s.default_eclass_ban_length) end -function ScoredScheduler( - G::EGraph, - theory::Vector{<:AbstractRule}, - match_limit::Int, - ban_length::Int, - complexity::Function, -) - gsize = length(G.uf) - data = IdDict{AbstractRule,ScoredSchedulerEntry}() - - for rule in theory - if rule isa DynamicRule - w = 2 - data[rule] = ScoredSchedulerEntry(match_limit, ban_length, 0, 0, w) - continue - end - (l, r) = rule.left, rule.right - - cl = complexity(l) - cr = complexity(r) - if cl > cr - w = 1 # reduces complexity - elseif cr > cl - w = 3 # augments complexity - else - w = 2 # complexity is equal - end - data[rule] = ScoredSchedulerEntry(match_limit, ban_length, 0, 0, w) - end +# can saturate if there's no banned rule +cansaturate(s::FreezingScheduler)::Bool = all(stat -> stat.banned_until < s.curr_iter, values(s.data)) - return ScoredScheduler(data, G, theory, 1) -end +function inform!(s::FreezingScheduler, rule_idx::Int, n_matches::Int, eclass_id::Id) + stats = s[eclass_id] + threshold = stats.size_limit + s.default_eclass_size_increment * stats.times_banned + len = length(s.g[eclass_id]) -# can saturate if there's no banned rule -cansaturate(s::ScoredScheduler)::Bool = all(kv -> s.curr_iter > last(kv).banned_until, s.data) - - -function inform!(s::ScoredScheduler, rule::AbstractRule, n_matches) - rd = s.data[rule] - treshold = rd.match_limit * (rd.weight^rd.times_banned) - if n_matches > treshold - ban_length = rd.ban_length * (rd.weight^rd.times_banned) - rd.times_banned += 1 - rd.banned_until = s.curr_iter + ban_length - # @info "banning rule $rule until $(rd.banned_until)!" - return false + if len > threshold + ban_length = stats.ban_length + s.default_eclass_ban_increment * stats.times_banned + stats.times_banned += 1 + stats.banned_until = s.curr_iter + ban_length end - return true end -function setiter!(s::ScoredScheduler, curr_iter) +function setiter!(s::FreezingScheduler, curr_iter::Int) s.curr_iter = curr_iter end +function rebuild!(s::FreezingScheduler) + new_data = Dict{Id,FreezingSchedulerStat}() + for (id, stats) in s.data + new_data[find(s.g, id)] = stats + end + finalize(s.data) + s.data = new_data + true +end end diff --git a/src/EGraphs/analysis.jl b/src/EGraphs/analysis.jl deleted file mode 100644 index 2510cd62..00000000 --- a/src/EGraphs/analysis.jl +++ /dev/null @@ -1,209 +0,0 @@ -analysis_reference(x::Symbol) = Val(x) -analysis_reference(x::Function) = x -analysis_reference(x) = error("$x is not a valid analysis reference") - -""" - islazy(::Val{analysis_name}) - -Should return `true` if the EGraph Analysis `an` is lazy -and false otherwise. A *lazy* EGraph Analysis is computed -only when [analyze!](@ref) is called. *Non-lazy* -analyses are instead computed on-the-fly every time ENodes are added to the EGraph or -EClasses are merged. -""" -islazy(::Val{analysis_name}) where {analysis_name} = false -islazy(analysis_name) = islazy(analysis_reference(analysis_name)) - -""" - modify!(::Val{analysis_name}, g, id) - -The `modify!` function for EGraph Analysis can optionally modify the eclass -`g[id]` after it has been analyzed, typically by adding an ENode. -It should be **idempotent** if no other changes occur to the EClass. -(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)). -""" -modify!(::Val{analysis_name}, g, id) where {analysis_name} = nothing -modify!(an, g, id) = modify!(analysis_reference(an), g, id) - - -""" - join(::Val{analysis_name}, a, b) - -Joins two analyses values into a single one, used by [analyze!](@ref) -when two eclasses are being merged or the analysis is being constructed. -""" -join(analysis::Val{analysis_name}, a, b) where {analysis_name} = - error("Analysis $analysis_name does not implement join") -join(an, a, b) = join(analysis_reference(an), a, b) - -""" - make(::Val{analysis_name}, g, n) - -Given an ENode `n`, `make` should return the corresponding analysis value. -""" -make(::Val{analysis_name}, g, n) where {analysis_name} = error("Analysis $analysis_name does not implement make") -make(an, g, n) = make(analysis_reference(an), g, n) - -analyze!(g::EGraph, analysis_ref, id::EClassId) = analyze!(g, analysis_ref, reachable(g, id)) -analyze!(g::EGraph, analysis_ref) = analyze!(g, analysis_ref, collect(keys(g.classes))) - - -""" - analyze!(egraph, analysis_name, [ECLASS_IDS]) - -Given an [EGraph](@ref) and an `analysis` identified by name `analysis_name`, -do an automated bottom up trasversal of the EGraph, associating a value from the -domain of analysis to each ENode in the egraph by the [make](@ref) function. -Then, for each [EClass](@ref), compute the [join](@ref) of the children ENodes analyses values. -After `analyze!` is called, an analysis value will be associated to each EClass in the EGraph. -One can inspect and retrieve analysis values by using [hasdata](@ref) and [getdata](@ref). -""" -function analyze!(g::EGraph, analysis_ref, ids::Vector{EClassId}) - addanalysis!(g, analysis_ref) - ids = sort(ids) - # @assert isempty(g.dirty) - - did_something = true - while did_something - did_something = false - - for id in ids - eclass = g[id] - id = eclass.id - pass = mapreduce(x -> make(analysis_ref, g, x), (x, y) -> join(analysis_ref, x, y), eclass) - - if !isequal(pass, getdata(eclass, analysis_ref, missing)) - setdata!(eclass, analysis_ref, pass) - did_something = true - push!(g.dirty, id) - end - end - end - - for id in ids - eclass = g[id] - id = eclass.id - if !hasdata(eclass, analysis_ref) - error("failed to compute analysis for eclass ", id) - end - end - - return true -end - -""" -A basic cost function, where the computed cost is the size -(number of children) of the current expression. -""" -function astsize(n::ENodeTerm, g::EGraph) - cost = 1 + arity(n) - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, astsize) && (cost += Inf; break) - cost += last(getdata(eclass, astsize)) - end - return cost -end - -astsize(n::ENodeLiteral, g::EGraph) = 1 - -""" -A basic cost function, where the computed cost is the size -(number of children) of the current expression, times -1. -Strives to get the largest expression -""" -function astsize_inv(n::ENodeTerm, g::EGraph) - cost = -(1 + arity(n)) # minus sign here is the only difference vs astsize - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, astsize_inv) && (cost += Inf; break) - cost += last(getdata(eclass, astsize_inv)) - end - return cost -end - -astsize_inv(n::ENodeLiteral, g::EGraph) = -1 - - -""" -When passing a function to analysis functions it is considered as a cost function -""" -make(f::Function, g::EGraph, n::AbstractENode) = (n, f(n, g)) - -join(f::Function, from, to) = last(from) <= last(to) ? from : to - -islazy(::Function) = true -modify!(::Function, g, id) = nothing - -function rec_extract(g::EGraph, costfun, id::EClassId; cse_env = nothing) - eclass = g[id] - if !isnothing(cse_env) && haskey(cse_env, id) - (sym, _) = cse_env[id] - return sym - end - (n, ck) = getdata(eclass, costfun, (nothing, Inf)) - ck == Inf && error("Infinite cost when extracting enode") - - if n isa ENodeLiteral - return n.value - elseif n isa ENodeTerm - children = map(arg -> rec_extract(g, costfun, arg; cse_env = cse_env), n.args) - meta = getdata(eclass, :metadata_analysis, nothing) - T = symtype(n) - egraph_reconstruct_expression(T, operation(n), collect(children); metadata = meta, exprhead = exprhead(n)) - else - error("Unknown ENode Type $(typeof(n))") - end -end - -""" -Given a cost function, extract the expression -with the smallest computed cost from an [`EGraph`](@ref) -""" -function extract!(g::EGraph, costfun::Function; root = -1, cse = false) - if root == -1 - root = g.root - end - analyze!(g, costfun, root) - if cse - # TODO make sure there is no assignments/stateful code!! - cse_env = OrderedDict{EClassId,Tuple{Symbol,Any}}() # - collect_cse!(g, costfun, root, cse_env, Set{EClassId}()) - - body = rec_extract(g, costfun, root; cse_env = cse_env) - - assignments = [Expr(:(=), name, val) for (id, (name, val)) in cse_env] - # return body - Expr(:let, Expr(:block, assignments...), body) - else - return rec_extract(g, costfun, root) - end -end - - -# Builds a dict e-class id => (symbol, extracted term) of common subexpressions in an e-graph -function collect_cse!(g::EGraph, costfun, id, cse_env, seen) - eclass = g[id] - (cn, ck) = getdata(eclass, costfun, (nothing, Inf)) - ck == Inf && error("Error when computing CSE") - if cn isa ENodeTerm - if id in seen - cse_env[id] = (gensym(), rec_extract(g, costfun, id))#, cse_env=cse_env)) # todo generalize symbol? - return - end - for child_id in arguments(cn) - collect_cse!(g, costfun, child_id, cse_env, seen) - end - push!(seen, id) - end -end - - -function getcost!(g::EGraph, costfun; root = -1) - if root == -1 - root = g.root - end - analyze!(g, costfun, root) - bestnode, cost = getdata(g[root], costfun) - return cost -end diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 4e92e539..ad5dc41a 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -2,205 +2,136 @@ # https://dl.acm.org/doi/10.1145/3434304 -abstract type AbstractENode end - -import Metatheory: maybelock! - -const AnalysisData = NamedTuple{N,T} where {N,T<:Tuple} -const EClassId = Int64 -const TermTypes = Dict{Tuple{Any,Int},Type} -# TODO document bindings -const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}} -const DEFAULT_BUFFER_SIZE = 1048576 - -struct ENodeLiteral <: AbstractENode - value - hash::Ref{UInt} - ENodeLiteral(a) = new(a, Ref{UInt}(0)) -end - -Base.:(==)(a::ENodeLiteral, b::ENodeLiteral) = hash(a) == hash(b) - -TermInterface.istree(n::ENodeLiteral) = false -TermInterface.exprhead(n::ENodeLiteral) = nothing -TermInterface.operation(n::ENodeLiteral) = n.value -TermInterface.arity(n::ENodeLiteral) = 0 +""" + modify!(eclass::EClass{Analysis}) -function Base.hash(t::ENodeLiteral, salt::UInt) - !iszero(salt) && return hash(hash(t, zero(UInt)), salt) - h = t.hash[] - !iszero(h) && return h - h′ = hash(t.value, salt) - t.hash[] = h′ - return h′ -end +The `modify!` function for EGraph Analysis can optionally modify the eclass +`eclass` after it has been analyzed, typically by adding an e-node. +It should be **idempotent** if no other changes occur to the EClass. +(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)). +""" +function modify! end -mutable struct ENodeTerm <: AbstractENode - exprhead::Union{Symbol,Nothing} - operation::Any - symtype::Type - args::Vector{EClassId} - hash::Ref{UInt} # hash cache - ENodeTerm(exprhead, operation, symtype, c_ids) = new(exprhead, operation, symtype, c_ids, Ref{UInt}(0)) -end +""" + join(a::AnalysisType, b::AnalysisType)::AnalysisType -function Base.:(==)(a::ENodeTerm, b::ENodeTerm) - hash(a) == hash(b) && a.operation == b.operation -end +Joins two analyses values into a single one, used by [analyze!](@ref) +when two eclasses are being merged or the analysis is being constructed. +""" +function join end +""" + make(g::EGraph{ExpressionType, AnalysisType}, n::VecExpr)::AnalysisType where {ExpressionType} -TermInterface.istree(n::ENodeTerm) = true -TermInterface.symtype(n::ENodeTerm) = n.symtype -TermInterface.exprhead(n::ENodeTerm) = n.exprhead -TermInterface.operation(n::ENodeTerm) = n.operation -TermInterface.arguments(n::ENodeTerm) = n.args -TermInterface.arity(n::ENodeTerm) = length(n.args) - -# This optimization comes from SymbolicUtils -# The hash of an enode is cached to avoid recomputing it. -# Shaves off a lot of time in accessing dictionaries with ENodes as keys. -function Base.hash(t::ENodeTerm, salt::UInt) - !iszero(salt) && return hash(hash(t, zero(UInt)), salt) - h = t.hash[] - !iszero(h) && return h - h′ = hash(t.args, hash(t.exprhead, hash(t.operation, salt))) - t.hash[] = h′ - return h′ -end +Given an e-node `n`, `make` should return the corresponding analysis value. +""" +function make end -# parametrize metadata by M -mutable struct EClass - g # EGraph - id::EClassId - nodes::Vector{AbstractENode} - parents::Vector{Pair{AbstractENode,EClassId}} - data::AnalysisData -end +""" + EClass{D} -function toexpr(n::ENodeTerm) - Expr(:call, :ENode, exprhead(n), operation(n), symtype(n), arguments(n)) -end +An `EClass` is an equivalence class of terms. -function Base.show(io::IO, x::ENodeTerm) - print(io, toexpr(x)) +The children and parent nodes are stored as [`VecExpr`](@ref)s for performance, which +means that without a reference to the [`EGraph`](@ref) object we cannot re-build human-readable terms +they represent. The [`EGraph`](@ref) itself comes with pretty printing for human-readable terms. +""" +mutable struct EClass{D} + const id::Id + const nodes::Vector{VecExpr} + const parents::Vector{Pair{VecExpr,Id}} + data::Union{D,Nothing} end -toexpr(n::ENodeLiteral) = operation(n) - -Base.show(io::IO, x::ENodeLiteral) = print(io, toexpr(x)) - -EClass(g, id) = EClass(g, id, AbstractENode[], Pair{AbstractENode,EClassId}[], nothing) -EClass(g, id, nodes, parents) = EClass(g, id, nodes, parents, NamedTuple()) - # Interface for indexing EClass Base.getindex(a::EClass, i) = a.nodes[i] -Base.setindex!(a::EClass, v, i) = setindex!(a.nodes, v, i) -Base.firstindex(a::EClass) = firstindex(a.nodes) -Base.lastindex(a::EClass) = lastindex(a.nodes) -Base.length(a::EClass) = length(a.nodes) # Interface for iterating EClass Base.iterate(a::EClass) = iterate(a.nodes) Base.iterate(a::EClass, state) = iterate(a.nodes, state) +Base.length(a::EClass) = length(a.nodes) + # Showing function Base.show(io::IO, a::EClass) - print(io, "EClass $(a.id) (") - - print(io, "[", Base.join(a.nodes, ", "), "], ") - # print(io, a.data) - print(io, ")") + println(io, "$(typeof(a)) %$(a.id) with $(length(a.nodes)) e-nodes:") + println(io, " data: $(a.data)") + println(io, " nodes:") + for n in a.nodes + println(io, " $n") + end end -function addparent!(a::EClass, n::AbstractENode, id::EClassId) +function addparent!(@nospecialize(a::EClass), n::VecExpr, id::Id) push!(a.parents, (n => id)) end -function Base.union!(to::EClass, from::EClass) - # TODO revisit - append!(to.nodes, from.nodes) - append!(to.parents, from.parents) - if !isnothing(to.data) && !isnothing(from.data) - to.data = join_analysis_data!(to.g, something(to.data), something(from.data)) - elseif to.data === nothing - to.data = from.data - end - return to -end -function join_analysis_data!(g, dst::AnalysisData, src::AnalysisData) - new_dst = merge(dst, src) - for analysis_name in keys(src) - analysis_ref = g.analyses[analysis_name] - if hasproperty(dst, analysis_name) - ref = getproperty(new_dst, analysis_name) - ref[] = join(analysis_ref, ref[], getproperty(src, analysis_name)[]) - end +function merge_analysis_data!(a::EClass{D}, b::EClass{D})::Tuple{Bool,Bool,Union{D,Nothing}} where {D} + if !isnothing(a.data) && !isnothing(b.data) + new_a_data = join(a.data, b.data) + (a.data != new_a_data, b.data != new_a_data, new_a_data) + elseif isnothing(a.data) && !isnothing(b.data) + # a merged, b not merged + (true, false, b.data) + elseif !isnothing(a.data) && isnothing(b.data) + (false, true, a.data) + else + (false, false, nothing) end - new_dst end -# Thanks to Shashi Gowda -hasdata(a::EClass, analysis_name::Symbol) = hasproperty(a.data, analysis_name) -hasdata(a::EClass, f::Function) = hasproperty(a.data, nameof(f)) -getdata(a::EClass, analysis_name::Symbol) = getproperty(a.data, analysis_name)[] -getdata(a::EClass, f::Function) = getproperty(a.data, nameof(f))[] -getdata(a::EClass, analysis_ref::Union{Symbol,Function}, default) = - hasdata(a, analysis_ref) ? getdata(a, analysis_ref) : default +""" +There's no need of computing hash for dictionaries where keys are UInt64. +Wrap them in an immutable struct that overrides `hash`. +TODO: this is rather hacky. We need a more performant dict implementation. -setdata!(a::EClass, f::Function, value) = setdata!(a, nameof(f), value) -function setdata!(a::EClass, analysis_name::Symbol, value) - if hasdata(a, analysis_name) - ref = getproperty(a.data, analysis_name) - ref[] = value - else - a.data = merge(a.data, NamedTuple{(analysis_name,)}((Ref{Any}(value),))) - end +Trick from: https://discourse.julialang.org/t/dictionary-with-custom-hash-function/49168 +""" +struct IdKey + val::Id end +Base.hash(a::IdKey, h::UInt) = xor(a.val, h) +Base.:(==)(a::IdKey, b::IdKey) = a.val == b.val -function funs(a::EClass) - map(operation, a.nodes) -end +""" + EGraph{ExpressionType,Analysis} -function funs_arity(a::EClass) - map(a.nodes) do x - (operation(x), arity(x)) - end -end +A concrete type representing an *e-graph*. + +An [`EGraph`](@ref) is a set of equivalence classes ([`EClass`](@ref)). +An `EClass` is in turn a set of e-nodes representing equivalent terms. +An e-node points to a set of children e-classes. +In Metatheory.jl, an e-node is implemented as a [`VecExpr`](@ref) for performance reasons. +The IDs stored in an e-node (i.e. `VecExpr`) or an `EClass` by themselves are +not necessarily very informative, but you can access the terms of each e-node +via `Metatheory.to_expr`. -""" -A concrete type representing an [`EGraph`]. See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) for implementation details. """ -mutable struct EGraph +mutable struct EGraph{ExpressionType,Analysis} "stores the equality relations over e-class ids" - uf::IntDisjointSet + uf::UnionFind "map from eclass id to eclasses" - classes::Dict{EClassId,EClass} - "hashcons" - memo::Dict{AbstractENode,EClassId} # memo - "worklist for ammortized upwards merging" - dirty::Vector{EClassId} - root::EClassId - "A vector of analyses associated to the EGraph" - analyses::Dict{Union{Symbol,Function},Union{Symbol,Function}} - "a cache mapping function symbols to e-classes that contain e-nodes with that function symbol." - symcache::Dict{Any,Vector{EClassId}} - default_termtype::Type - termtypes::TermTypes - numclasses::Int - numnodes::Int - "If we use global buffers we may need to lock. Defaults to true." + classes::Dict{IdKey,EClass{Analysis}} + "hashcons mapping e-nodes to their e-class id" + memo::Dict{VecExpr,Id} + "Hashcons the constants in the e-graph" + constants::Dict{UInt64,Any} + "Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass." + pending::Vector{Pair{VecExpr,Id}} + analysis_pending::UniqueQueue{Pair{VecExpr,Id}} + root::Id + "a cache mapping signatures (function symbols and their arity) to e-classes that contain e-nodes with that function symbol." + classes_by_op::Dict{IdKey,Vector{Id}} + clean::Bool + "If we use global buffers we may need to lock. Defaults to false." needslock::Bool - "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." - buffer::Vector{Bindings} - "Buffer for rule application which defaults to a global. Use a local buffer for generated functions." - merges_buffer::Vector{Tuple{Int,Int}} lock::ReentrantLock end @@ -209,139 +140,144 @@ end EGraph(expr) Construct an EGraph from a starting symbolic expression `expr`. """ -function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE) - EGraph( - IntDisjointSet(), - Dict{EClassId,EClass}(), - Dict{AbstractENode,EClassId}(), - EClassId[], - -1, - Dict{Union{Symbol,Function},Union{Symbol,Function}}(), - Dict{Any,Vector{EClassId}}(), - Expr, - TermTypes(), - 0, +function EGraph{ExpressionType,Analysis}(; needslock::Bool = false) where {ExpressionType,Analysis} + EGraph{ExpressionType,Analysis}( + UnionFind(), + Dict{IdKey,EClass{Analysis}}(), + Dict{VecExpr,Id}(), + Dict{UInt64,Any}(), + Pair{VecExpr,Id}[], + UniqueQueue{Pair{VecExpr,Id}}(), 0, + Dict{IdKey,Vector{Id}}(), + false, needslock, - Bindings[], - Tuple{Int,Int}[], ReentrantLock(), ) end +EGraph(; kwargs...) = EGraph{Expr,Nothing}(; kwargs...) +EGraph{ExpressionType}(; kwargs...) where {ExpressionType} = EGraph{ExpressionType,Nothing}(; kwargs...) -function maybelock!(f::Function, g::EGraph) - g.needslock ? lock(f, g.buffer_lock) : f() -end - -function EGraph(e; keepmeta = false, kwargs...) - g = EGraph(kwargs...) - keepmeta && addanalysis!(g, :metadata_analysis) - g.root = addexpr!(g, e; keepmeta = keepmeta) +function EGraph{ExpressionType,Analysis}(e; kwargs...) where {ExpressionType,Analysis} + g = EGraph{ExpressionType,Analysis}(; kwargs...) + g.root = addexpr!(g, e) g end -function addanalysis!(g::EGraph, costfun::Function) - g.analyses[nameof(costfun)] = costfun - g.analyses[costfun] = costfun +EGraph{ExpressionType}(e; kwargs...) where {ExpressionType} = EGraph{ExpressionType,Nothing}(e; kwargs...) +EGraph(e; kwargs...) = EGraph{typeof(e),Nothing}(e; kwargs...) + +# Fallback implementation for analysis methods make and modify +@inline make(::EGraph, ::VecExpr) = nothing +@inline modify!(::EGraph, ::EClass{Analysis}) where {Analysis} = nothing + +@inline get_constant(@nospecialize(g::EGraph), hash::UInt64) = g.constants[hash] +@inline has_constant(@nospecialize(g::EGraph), hash::UInt64)::Bool = haskey(g.constants, hash) + +@inline function add_constant!(@nospecialize(g::EGraph), @nospecialize(c))::Id + h = hash(c) + get!(g.constants, h, c) + h end -function addanalysis!(g::EGraph, analysis_name::Symbol) - g.analyses[analysis_name] = analysis_name +@inline function add_constant_hashed!(@nospecialize(g::EGraph), @nospecialize(c), h::UInt64)::Id + g.constants[h] = c + h end -function settermtype!(g::EGraph, f, ar, T) - g.termtypes[(f, ar)] = T + +function to_expr(g::EGraph, n::VecExpr) + v_isexpr(n) || return get_constant(g, v_head(n)) + h = get_constant(g, v_head(n)) + args = Core.SSAValue.(Int.(v_children(n))) + if v_iscall(n) + maketerm(Expr, :call, [h; args], nothing) + else + maketerm(Expr, h, args, nothing) + end end -function settermtype!(g::EGraph, T) - g.default_termtype = T +function pretty_dict(g::EGraph) + d = Dict{Int,Vector{Any}}() + for (class_id, eclass) in g.classes + d[class_id.val] = map(n -> to_expr(g, n), eclass.nodes) + end + d end +export pretty_dict -function gettermtype(g::EGraph, f, ar) - if haskey(g.termtypes, (f, ar)) - g.termtypes[(f, ar)] - else - g.default_termtype +function Base.show(io::IO, g::EGraph) + d = pretty_dict(g) + t = "$(typeof(g)) with $(length(d)) e-classes:" + cs = map(sort!(collect(d); by = first)) do (k, vect) + " $k => [$(Base.join(vect, ", "))]" end + print(io, Base.join([t; cs], "\n")) end """ Returns the canonical e-class id for a given e-class. """ -find(g::EGraph, a::EClassId)::EClassId = find_root(g.uf, a) -find(g::EGraph, a::EClass)::EClassId = find(g, a.id) - -Base.getindex(g::EGraph, i::EClassId) = g.classes[find(g, i)] - -### Definition 2.3: canonicalization -iscanonical(g::EGraph, n::ENodeTerm) = n == canonicalize(g, n) -iscanonical(g::EGraph, n::ENodeLiteral) = true -iscanonical(g::EGraph, e::EClass) = find(g, e.id) == e.id - -canonicalize(g::EGraph, n::ENodeLiteral) = n +@inline find(g::EGraph, a::Id)::Id = find(g.uf, a) +@inline find(@nospecialize(g::EGraph), @nospecialize(a::EClass))::Id = find(g, a.id) -function canonicalize(g::EGraph, n::ENodeTerm) - if arity(n) > 0 - new_args = map(x -> find(g, x), n.args) - return ENodeTerm(exprhead(n), operation(n), symtype(n), new_args) - end - return n -end +@inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))] -function canonicalize!(g::EGraph, n::ENodeTerm) - for (i, arg) in enumerate(n.args) - n.args[i] = find(g, arg) +function canonicalize!(g::EGraph, n::VecExpr) + v_isexpr(n) || @goto ret + for i in (VECEXPR_META_LENGTH + 1):length(n) + @inbounds n[i] = find(g, n[i]) end - n.hash[] = UInt(0) - return n + v_unset_hash!(n) + @label ret + v_hash!(n) + n end -canonicalize!(g::EGraph, n::ENodeLiteral) = n +function lookup(g::EGraph, n::VecExpr)::Id + canonicalize!(g, n) - -function canonicalize!(g::EGraph, e::EClass) - e.id = find(g, e.id) + id = get(g.memo, n, zero(Id)) + iszero(id) ? id : find(g, id) end -function lookup(g::EGraph, n::AbstractENode)::EClassId - cc = canonicalize(g, n) - haskey(g.memo, cc) ? find(g, g.memo[cc]) : -1 + +function add_class_by_op(g::EGraph, n, eclass_id) + key = IdKey(v_signature(n)) + vec = get!(g.classes_by_op, key, Vector{Id}()) + push!(vec, eclass_id) end """ Inserts an e-node in an [`EGraph`](@ref) """ -function add!(g::EGraph, n::AbstractENode)::EClassId - n = canonicalize(g, n) - haskey(g.memo, n) && return g.memo[n] +function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis} + canonicalize!(g, n) + + id = get(g.memo, n, zero(Id)) + iszero(id) || return id + + if should_copy + n = copy(n) + end id = push!(g.uf) # create new singleton eclass - if n isa ENodeTerm - for c_id in arguments(n) - addparent!(g.classes[c_id], n, id) + if v_isexpr(n) + for c_id in v_children(n) + addparent!(g.classes[IdKey(c_id)], n, id) end end g.memo[n] = id - if haskey(g.symcache, operation(n)) - push!(g.symcache[operation(n)], id) - else - g.symcache[operation(n)] = [id] - end + add_class_by_op(g, n, id) + eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n)) + g.classes[IdKey(id)] = eclass + modify!(g, eclass) + push!(g.pending, n => id) - classdata = EClass(g, id, AbstractENode[n], Pair{AbstractENode,EClassId}[]) - g.classes[id] = classdata - g.numclasses += 1 - - for an in values(g.analyses) - if !islazy(an) && an !== :metadata_analysis - setdata!(classdata, an, make(an, g, n)) - modify!(an, g, id) - end - end return id end @@ -362,204 +298,223 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int [`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly insert the literal into the [`EGraph`](@ref). """ -function addexpr!(g::EGraph, se; keepmeta = false)::EClassId +function addexpr!(g::EGraph, se)::Id + se isa EClass && return se.id e = preprocess(se) - id = add!(g, if istree(se) - class_ids::Vector{EClassId} = [addexpr!(g, arg; keepmeta = keepmeta) for arg in arguments(e)] - ENodeTerm(exprhead(e), operation(e), symtype(e), class_ids) - else - # constant enode - ENodeLiteral(e) - end) - if keepmeta - meta = TermInterface.metadata(e) - !isnothing(meta) && setdata!(g.classes[id], :metadata_analysis, meta) + isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false) + + args = iscall(e) ? arguments(e) : children(e) + ar = length(args) + n = v_new(ar) + v_set_flag!(n, VECEXPR_FLAG_ISTREE) + iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL) + h = iscall(e) ? operation(e) : head(e) + v_set_head!(n, add_constant!(g, h)) + # get the signature from op and arity + v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar))) + for i in v_children_range(n) + @inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH]) end - return id -end -function addexpr!(g::EGraph, ec::EClass; keepmeta = false) - @assert g == ec.g - find(g, ec.id) + add!(g, n, false) end """ Given an [`EGraph`](@ref) and two e-class ids, set the two e-classes as equal. """ -function Base.merge!(g::EGraph, a::EClassId, b::EClassId)::EClassId - id_a = find(g, a) - id_b = find(g, b) +function Base.union!( + g::EGraph{ExpressionType,AnalysisType}, + enode_id1::Id, + enode_id2::Id, +)::Bool where {ExpressionType,AnalysisType} + g.clean = false + id_1 = IdKey(find(g, enode_id1)) + id_2 = IdKey(find(g, enode_id2)) - id_a == id_b && return id_a - to = union!(g.uf, id_a, id_b) - from = (to == id_a) ? id_b : id_a + id_1 == id_2 && return false - push!(g.dirty, to) + # Make sure class 2 has fewer parents + if length(g.classes[id_1].parents) < length(g.classes[id_2].parents) + id_1, id_2 = id_2, id_1 + end - from_class = g.classes[from] - to_class = g.classes[to] - to_class.id = to + union!(g.uf, id_1.val, id_2.val) - # I (was) the troublesome line! - g.classes[to] = union!(to_class, from_class) - delete!(g.classes, from) - g.numclasses -= 1 + eclass_2 = pop!(g.classes, id_2)::EClass + eclass_1 = g.classes[id_1]::EClass - return to -end + append!(g.pending, eclass_2.parents) -function in_same_class(g::EGraph, a, b) - find(g, a) == find(g, b) -end + (merged_1, merged_2, new_data) = merge_analysis_data!(eclass_1, eclass_2) + merged_1 && append!(g.analysis_pending, eclass_1.parents) + merged_2 && append!(g.analysis_pending, eclass_2.parents) -# TODO new rebuilding from egg -""" -This function restores invariants and executes -upwards merging in an [`EGraph`](@ref). See -the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) -for more details. -""" -function rebuild!(g::EGraph) - # normalize!(g.uf) - - while !isempty(g.dirty) - # todo = unique([find(egraph, id) for id ∈ egraph.dirty]) - todo = unique(g.dirty) - empty!(g.dirty) - for x in todo - repair!(g, x) - end - end + # update eclass_1 + append!(eclass_1.nodes, eclass_2.nodes) + append!(eclass_1.parents, eclass_2.parents) + eclass_1.data = new_data - if g.root != -1 - g.root = find(g, g.root) - end + modify!(g, eclass_1) + + return true +end + +function in_same_class(g::EGraph, ids::Id...)::Bool + nids = length(ids) + nids == 1 && return true - normalize!(g.uf) + first_id = find(g, ids[1]) + for i in 2:nids + first_id == find(g, ids[i]) || return false + end + true end -function repair!(g::EGraph, id::EClassId) - id = find(g, id) - ecdata = g[id] - ecdata.id = id - new_parents = (length(ecdata.parents) > 30 ? OrderedDict : LittleDict){AbstractENode,EClassId}() +function rebuild_classes!(g::EGraph) + for v in values(g.classes_by_op) + empty!(v) + end - for (p_enode, p_eclass) in ecdata.parents - p_enode = canonicalize!(g, p_enode) - # deduplicate parents - if haskey(new_parents, p_enode) - merge!(g, p_eclass, new_parents[p_enode]) + for (eclass_id, eclass) in g.classes + # old_len = length(eclass.nodes) + for n in eclass.nodes + canonicalize!(g, n) end - n_id = find(g, p_eclass) - g.memo[p_enode] = n_id - new_parents[p_enode] = n_id - end + # Sort to go in order? + unique!(eclass.nodes) - ecdata.parents = collect(new_parents) + for n in eclass.nodes + add_class_by_op(g, n, eclass_id.val) + end + end - # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes) + for v in values(g.classes_by_op) + sort!(v) + unique!(v) + end +end - # Analysis invariant maintenance - for an in values(g.analyses) - hasdata(ecdata, an) && modify!(an, g, id) - for (p_enode, p_id) in ecdata.parents - # p_eclass = find(g, p_eclass) - p_eclass = g[p_id] - if !islazy(an) && !hasdata(p_eclass, an) - setdata!(p_eclass, an, make(an, g, p_enode)) +function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {ExpressionType,AnalysisType} + n_unions = 0 + + while !isempty(g.pending) || !isempty(g.analysis_pending) + while !isempty(g.pending) + (node::VecExpr, eclass_id::Id) = pop!(g.pending) + node = copy(node) + canonicalize!(g, node) + old_class_id = get!(g.memo, node, eclass_id) + if old_class_id != eclass_id + did_something = union!(g, old_class_id, eclass_id) + # TODO unique! can node dedup be moved here? compare performance + # did_something && unique!(g[eclass_id].nodes) + n_unions += did_something end - if hasdata(p_eclass, an) - p_data = getdata(p_eclass, an) - - if an !== :metadata_analysis - new_data = join(an, p_data, make(an, g, p_enode)) - if new_data != p_data - setdata!(p_eclass, an, new_data) - push!(g.dirty, p_id) + end + + while !isempty(g.analysis_pending) + (node::VecExpr, eclass_id::Id) = pop!(g.analysis_pending) + eclass_id = find(g, eclass_id) + eclass_id_key = IdKey(eclass_id) + eclass = g.classes[eclass_id_key] + + node_data = make(g, node) + if !isnothing(node_data) + if !isnothing(eclass.data) + joined_data = join(eclass.data, node_data) + + if joined_data != eclass.data + eclass.data = joined_data + modify!(g, eclass) + append!(g.analysis_pending, eclass.parents) end + else + eclass.data = node_data + modify!(g, eclass) + append!(g.analysis_pending, eclass.parents) end end end end - - unique!(ecdata.nodes) - - # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes) - + n_unions end - -""" -Recursive function that traverses an [`EGraph`](@ref) and -returns a vector of all reachable e-classes from a given e-class id. -""" -function reachable(g::EGraph, id::EClassId) - id = find(g, id) - hist = EClassId[id] - todo = EClassId[id] - - - function reachable_node(xn::ENodeTerm) - x = canonicalize(g, xn) - for c_id in arguments(x) - if c_id ∉ hist - push!(hist, c_id) - push!(todo, c_id) +function check_memo(g::EGraph)::Bool + test_memo = Dict{VecExpr,Id}() + for (id, class) in g.classes + @assert id.val == class.id + for node in class.nodes + old_id = get!(test_memo, node, id.val) + if old_id != id.val + @assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)" end end end - function reachable_node(x::ENodeLiteral) end - while !isempty(todo) - curr = find(g, pop!(todo)) - for n in g.classes[curr] - reachable_node(n) - end + for (node, id) in test_memo + @assert id == find(g, id) + @assert id == find(g, g.memo[node]) end - return hist + true end +function check_analysis(g) + for (id, eclass) in g.classes + isnothing(eclass.data) && continue + pass = mapreduce(x -> make(g, x), (x, y) -> join(x, y), eclass) + @assert eclass.data == pass + end + true +end """ -When extracting symbolic expressions from an e-graph, we need -to instruct the e-graph how to rebuild expressions of a certain type. -This function must be extended by the user to add new types of expressions that can be manipulated by e-graphs. +This function restores invariants and executes +upwards merging in an [`EGraph`](@ref). See +the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) +for more details. """ -function egraph_reconstruct_expression(T::Type{Expr}, op, args; metadata = nothing, exprhead = :call) - similarterm(Expr(:call, :_), op, args; metadata = metadata, exprhead = exprhead) +function rebuild!(g::EGraph; should_check_memo=false, should_check_analysis=false) + n_unions = process_unions!(g) + trimmed_nodes = rebuild_classes!(g) + @assert !should_check_memo || check_memo(g) + @assert !should_check_analysis || check_analysis(g) + g.clean = true + + @debug "REBUILT" n_unions trimmed_nodes end # Thanks to Max Willsey and Yihong Zhang -import Metatheory: lookup_pat -function lookup_pat(g::EGraph, p::PatTerm)::EClassId +function lookup_pat(g::EGraph{ExpressionType}, p::PatExpr)::Id where {ExpressionType} @assert isground(p) - eh = exprhead(p) - op = operation(p) - args = arguments(p) - ar = arity(p) + args = children(p) + h = v_head(p.n) - T = gettermtype(g, op, ar) + has_op = has_constant(g, h) || (h != p.quoted_head_hash && has_constant(g, p.quoted_head_hash)) + has_op || return 0 - ids = map(x -> lookup_pat(g, x), args) - !all((>)(0), ids) && return -1 + for i in v_children_range(p.n) + @inbounds p.n[i] = lookup_pat(g, args[i - VECEXPR_META_LENGTH]) + p.n[i] <= 0 && return 0 + end - if T == Expr && op isa Union{Function,DataType} - id = lookup(g, ENodeTerm(eh, op, T, ids)) - id < 0 && return lookup(g, ENodeTerm(eh, nameof(op), T, ids)) - return id - else - return lookup(g, ENodeTerm(eh, op, T, ids)) + id = lookup(g, p.n) + if id <= 0 && h != p.quoted_head_hash + v_set_head!(p.n, p.quoted_head_hash) + id = lookup(g, p.n) + v_set_head!(p.n, p.head_hash) end + id end -lookup_pat(g::EGraph, p::Any) = lookup(g, ENodeLiteral(p)) -lookup_pat(g::EGraph, p::AbstractPat) = throw(UnsupportedPatternException(p)) +function lookup_pat(g::EGraph, p::PatLiteral)::Id + h = last(p.n) + has_constant(g, h) ? lookup(g, p.n) : 0 +end diff --git a/src/EGraphs/extract.jl b/src/EGraphs/extract.jl new file mode 100644 index 00000000..85186b5c --- /dev/null +++ b/src/EGraphs/extract.jl @@ -0,0 +1,116 @@ +struct Extractor{CostFun,Cost} + g::EGraph + cost_function::CostFun + costs::Dict{IdKey,Tuple{Cost,Int64}} # Cost and index in eclass + Extractor{CF,C}(g::EGraph, cf::CF, d::Dict{IdKey,Tuple{C,Int64}}) where {CF,C} = new{CF,C}(g, cf, d) +end + +""" +Given a cost function, extract the expression +with the smallest computed cost from an [`EGraph`](@ref) +""" +function Extractor(g::EGraph, cost_function::Function, cost_type = Float64) + extractor = Extractor{typeof(cost_function),cost_type}(g, cost_function, Dict{IdKey,Tuple{cost_type,Int64}}()) + find_costs!(extractor) + extractor +end + +function extract_expr_recursive(g::EGraph{T}, n::VecExpr, get_node::Function) where {T} + h = get_constant(g, v_head(n)) + v_isexpr(n) || return h + children = map(c -> extract_expr_recursive(g, c, get_node), get_node.(v_children(n))) + # TODO metadata? + maketerm(T, h, children, nothing) +end + +function extract_expr_recursive(g::EGraph{Expr}, n::VecExpr, get_node::Function) + h = get_constant(g, v_head(n)) + v_isexpr(n) || return h + children = map(c -> extract_expr_recursive(g, c, get_node), get_node.(v_children(n))) + + if v_iscall(n) + maketerm(Expr, :call, [h; children], nothing) + else + maketerm(Expr, h, children, nothing) + end +end + + +function (extractor::Extractor)(root = extractor.g.root) + get_node(eclass_id::Id) = find_best_node(extractor, eclass_id) + # TODO check if infinite cost? + extract_expr_recursive(extractor.g, find_best_node(extractor, root), get_node) +end + +# costs dict stores index of enode. get this enode from the eclass +function find_best_node(extractor::Extractor, eclass_id::Id) + eclass = extractor.g[eclass_id] + (_, node_index) = extractor.costs[IdKey(eclass.id)] + eclass.nodes[node_index] +end + +function find_costs!(extractor::Extractor{CF,CT}) where {CF,CT} + did_something = true + while did_something + did_something = false + + for (id, eclass) in extractor.g.classes + min_cost = typemax(CT) + min_cost_node_idx = 0 + + for (idx, n) in enumerate(eclass.nodes) + has_all = true + for child_id in v_children(n) + has_all = has_all && haskey(extractor.costs, IdKey(child_id)) + has_all || break + end + if has_all + cost = extractor.cost_function( + n, + get_constant(extractor.g, v_head(n)), + map(child_id -> extractor.costs[IdKey(child_id)][1], v_children(n)), + ) + if cost < min_cost + min_cost = cost + min_cost_node_idx = idx + end + end + end + + if min_cost != typemax(CT) && (!haskey(extractor.costs, id) || (min_cost < extractor.costs[id][1])) + extractor.costs[id] = (min_cost, min_cost_node_idx) + did_something = true + end + end + end + + for (id, _) in extractor.g.classes + if !haskey(extractor.costs, id) + error("failed to compute extraction costs for eclass ", id.val) + end + end +end + +""" +A basic cost function, where the computed cost is the number +of expression tree nodes. +""" +function astsize(n::VecExpr, op, costs)::Float64 + v_isexpr(n) || return 1 + cost = 1 + sum(costs) +end + +""" +A basic cost function, where the computed cost is the number +of expression tree nodes times -1. +Strives to get the largest expression. This may lead to stack overflow for egraphs with loops. +""" +function astsize_inv(n::VecExpr, op, costs::Vector{Float64})::Float64 + v_isexpr(n) || return -1 + cost = -1 + sum(costs) +end + +function extract!(g::EGraph, costfun, root = g.root, cost_type = Float64) + Extractor(g, costfun, cost_type)(root) +end + diff --git a/src/EGraphs/intdisjointmap.jl b/src/EGraphs/intdisjointmap.jl deleted file mode 100644 index 2f475458..00000000 --- a/src/EGraphs/intdisjointmap.jl +++ /dev/null @@ -1,73 +0,0 @@ -struct IntDisjointSet - parents::Vector{Int} - normalized::Ref{Bool} -end - -IntDisjointSet() = IntDisjointSet(Int[], Ref(true)) -Base.length(x::IntDisjointSet) = length(x.parents) - -function Base.push!(x::IntDisjointSet)::Int - push!(x.parents, -1) - length(x) -end - -function find_root(x::IntDisjointSet, i::Int)::Int - while x.parents[i] >= 0 - i = x.parents[i] - end - return i -end - -function in_same_set(x::IntDisjointSet, a::Int, b::Int) - find_root(x, a) == find_root(x, b) -end - -function Base.union!(x::IntDisjointSet, i::Int, j::Int) - pi = find_root(x, i) - pj = find_root(x, j) - if pi != pj - x.normalized[] = false - isize = -x.parents[pi] - jsize = -x.parents[pj] - if isize > jsize # swap to make size of i less than j - pi, pj = pj, pi - isize, jsize = jsize, isize - end - x.parents[pj] -= isize # increase new size of pj - x.parents[pi] = pj # set parent of pi to pj - end - return pj -end - -function normalize!(x::IntDisjointSet) - for i in 1:length(x) - p_i = find_root(x, i) - if p_i != i - x.parents[i] = p_i - end - end - x.normalized[] = true -end - -# If normalized we don't even need a loop here. -function _find_root_normal(x::IntDisjointSet, i::Int) - p_i = x.parents[i] - if p_i < 0 # Is `i` a root? - return i - else - return p_i - end - # return pi -end - -function _in_same_set_normal(x::IntDisjointSet, a::Int64, b::Int64) - _find_root_normal(x, a) == _find_root_normal(x, b) -end - -function find_root_if_normal(x::IntDisjointSet, i::Int64) - if x.normalized[] - _find_root_normal(x, i) - else - find_root(x, i) - end -end diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 57d9d8c7..fcc3555f 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -1,26 +1,3 @@ -abstract type SaturationGoal end - -reached(g::EGraph, goal::Nothing) = false -reached(g::EGraph, goal::SaturationGoal) = false -reached(g::EGraph, goal::Function) = goal(g) - -""" -This goal is reached when the `exprs` list of expressions are in the -same equivalence class. -""" -struct EqualityGoal <: SaturationGoal - exprs::Vector{Any} - ids::Vector{EClassId} - function EqualityGoal(exprs, eclasses) - @assert length(exprs) == length(eclasses) && length(exprs) != 0 - new(exprs, eclasses) - end -end - -function reached(g::EGraph, goal::EqualityGoal) - all(x -> in_same_class(g, goal.ids[1], x), @view goal.ids[2:end]) -end - mutable struct SaturationReport reason::Union{Symbol,Nothing} egraph::EGraph @@ -40,7 +17,7 @@ function Base.show(io::IO, x::SaturationReport) println(io, "=================") println(io, "\tStop Reason: $(x.reason)") println(io, "\tIterations: $(x.iterations)") - println(io, "\tEGraph Size: $(g.numclasses) eclasses, $(length(g.memo)) nodes") + println(io, "\tEGraph Size: $(length(g.classes)) eclasses, $(length(g.memo)) nodes") print_timer(io, x.to) end @@ -54,80 +31,83 @@ Base.@kwdef mutable struct SaturationParams "Maximum number of eclasses allowed" eclasslimit::Int = 5000 enodelimit::Int = 15000 - goal::Union{Nothing,SaturationGoal,Function} = nothing - stopwhen::Function = () -> false + goal::Function = (g::EGraph) -> false scheduler::Type{<:AbstractScheduler} = BackoffScheduler - schedulerparams::Tuple = () + schedulerparams::NamedTuple = (;) threaded::Bool = false timer::Bool = true + "Activate check for memoization of nodes (hashcons) after rebuilding" + check_memo::Bool = false + "Activate check for join-semilattice invariant for semantic analysis values after rebuilding" + check_analysis::Bool = false end -# function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64} -# if isground(p) -# id = lookup_pat(g, p) -# !isnothing(id) && return [id] -# else -# return keys(g.classes) -# end -# return [] -# end - -function cached_ids(g::EGraph, p::AbstractPattern) # p is a literal - @warn "Pattern matching against the whole e-graph" - return keys(g.classes) +function cached_ids(g::EGraph, p::PatExpr)::Vector{Id} + if isground(p) + id = lookup_pat(g, p) + iszero(id) ? UNDEF_ID_VEC : [id] + else + get(g.classes_by_op, IdKey(v_signature(p.n)), UNDEF_ID_VEC) + end end -function cached_ids(g::EGraph, p) # p is a literal - id = lookup(g, ENodeLiteral(p)) +function cached_ids(g::EGraph, p::PatLiteral) # p is a literal + id = lookup_pat(g, p) id > 0 && return [id] - return [] -end - - -# function cached_ids(g::EGraph, p::PatTerm) -# arr = get(g.symcache, operation(p), EClassId[]) -# if operation(p) isa Union{Function,DataType} -# append!(arr, get(g.symcache, nameof(operation(p)), EClassId[])) -# end -# arr -# end - -function cached_ids(g::EGraph, p::PatTerm) - keys(g.classes) + return UNDEF_ID_VEC end +cached_ids(g::EGraph, p::PatVar) = Iterators.map(x -> x.val, keys(g.classes)) """ Returns an iterator of `Match`es. """ function eqsat_search!( g::EGraph, - theory::Vector{<:AbstractRule}, + theory::Theory, scheduler::AbstractScheduler, report::SaturationReport, + ematch_buffer::OptBuffer{UInt128}, )::Int n_matches = 0 - maybelock!(g) do - empty!(g.buffer) - end + g.needslock && lock(g.lock) + empty!(ematch_buffer) + g.needslock && unlock(g.lock) + @debug "SEARCHING" for (rule_idx, rule) in enumerate(theory) + prev_matches = n_matches @timeit report.to string(rule_idx) begin prev_matches = n_matches # don't apply banned rules - if !cansearch(scheduler, rule) + if !cansearch(scheduler, rule_idx) @debug "$rule is banned" continue end - ids = cached_ids(g, rule.left) - rule isa BidirRule && (ids = ids ∪ cached_ids(g, rule.right)) - for i in ids - n_matches += rule.ematcher!(g, rule_idx, i) + + ids_left = cached_ids(g, rule.left) + for i in ids_left + cansearch(scheduler, rule_idx, i) || continue + n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer) + inform!(scheduler, rule_idx, i, n_matches) + end + + if is_bidirectional(rule) + ids_right = cached_ids(g, rule.right) + for i in ids_right + cansearch(scheduler, rule_idx, i) || continue + n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer) + inform!(scheduler, rule_idx, i, n_matches) + end end + n_matches - prev_matches > 0 && @debug "Rule $rule_idx: $rule produced $(n_matches - prev_matches) matches" - inform!(scheduler, rule, n_matches) + # if n_matches - prev_matches > 2 && rule_idx == 2 + # @debug buffer_readable(g, old_len) + # end + inform!(scheduler, rule_idx, n_matches) end end @@ -135,109 +115,158 @@ function eqsat_search!( return n_matches end - -function drop_n!(D::CircularDeque, nn) - D.n -= nn - tmp = D.first + nn - D.first = tmp > D.capacity ? 1 : tmp +function instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatLiteral)::Id + add_constant_hashed!(g, p.value, v_head(p.n)) + add!(g, p.n, true) end -instantiate_enode!(bindings::Bindings, g::EGraph, p::Any)::EClassId = add!(g, ENodeLiteral(p)) -instantiate_enode!(bindings::Bindings, g::EGraph, p::PatVar)::EClassId = bindings[p.idx][1] -function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId - eh = exprhead(p) - op = operation(p) - ar = arity(p) - args = arguments(p) - T = gettermtype(g, op, ar) - # TODO add predicate check `quotes_operation` - new_op = T == Expr && op isa Union{Function,DataType} ? nameof(op) : op - add!(g, ENodeTerm(eh, new_op, T, map(arg -> instantiate_enode!(bindings, g, arg), args))) -end - -function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction) - push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right))) - nothing -end +instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatVar)::Id = v_pair_first(bindings[p.idx]) +function instantiate_enode!(bindings, g::EGraph{ExpressionType}, p::PatExpr)::Id where {ExpressionType} + add_constant_hashed!(g, p.head, p.head_hash) -function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int) - pat_to_inst = direction == 1 ? rule.right : rule.left - push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst))) - nothing + for i in v_children_range(p.n) + @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) + end + add!(g, p.n, true) end +function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id + add_constant_hashed!(g, p.quoted_head, p.quoted_head_hash) + v_set_head!(p.n, p.quoted_head_hash) -function apply_rule!(bindings::Bindings, g::EGraph, rule::UnequalRule, id::EClassId, direction::Int) - pat_to_inst = direction == 1 ? rule.right : rule.left - other_id = instantiate_enode!(bindings, g, pat_to_inst) - - if find(g, id) == find(g, other_id) - @debug "$rule produced a contradiction!" - return :contradiction + for i in v_children_range(p.n) + @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) end - nothing + add!(g, p.n, true) end """ Instantiate argument for dynamic rule application in e-graph """ -function instantiate_actual_param!(bindings::Bindings, g::EGraph, i) - ecid, literal_position = bindings[i] +function instantiate_actual_param!(bindings, g::EGraph, i) + ecid = v_pair_first(bindings[i]) + literal_position = reinterpret(Int, v_pair_last(bindings[i])) ecid <= 0 && error("unbound pattern variable") eclass = g[ecid] if literal_position > 0 - @assert eclass[literal_position] isa ENodeLiteral - return eclass[literal_position].value + @assert !v_isexpr(eclass[literal_position]) + return get_constant(g, v_head(eclass[literal_position])) end return eclass end -function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClassId, direction::Int) - f = rule.rhs_fun - r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) - isnothing(r) && return nothing - rcid = addexpr!(g, r) - push!(g.merges_buffer, (id, rcid)) - return nothing + +struct RuleApplicationResult + halt_reason::Symbol + l::Id + r::Id +end + +function apply_rule!( + bindings::SubArray{UInt128,1,Vector{UInt128},Tuple{UnitRange{Int64}},true}, + g::EGraph, + rule::RewriteRule, + id::Id, + direction::Int, +)::RuleApplicationResult + if rule.op === (-->) # DirectedRule + new_id::Id = instantiate_enode!(bindings, g, rule.right) + RuleApplicationResult(:nothing, new_id, id) + elseif rule.op === (==) # EqualityRule + pat_to_inst = direction == 1 ? rule.right : rule.left + new_id = instantiate_enode!(bindings, g, pat_to_inst) + RuleApplicationResult(:nothing, new_id, id) + elseif rule.op === (!=) # UnequalRule + pat_to_inst = direction == 1 ? rule.right : rule.left + other_id = instantiate_enode!(bindings, g, pat_to_inst) + + if find(g, id) == find(g, other_id) + @debug "$rule produced a contradiction!" + return RuleApplicationResult(:contradiction, 0, 0) + end + RuleApplicationResult(:nothing, 0, 0) + elseif rule.op === (|>) # DynamicRule + r = rule.right(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) + isnothing(r) && return RuleApplicationResult(:nothing, 0, 0) + rcid = addexpr!(g, r) + RuleApplicationResult(:nothing, rcid, id) + else + RuleApplicationResult(:error, 0, 0) + end end +const CHECK_GOAL_EVERY_N_MATCHES = 20 +function eqsat_apply!( + g::EGraph, + theory::Theory, + rep::SaturationReport, + params::SaturationParams, + ematch_buffer::OptBuffer{UInt128}, +) + n_matches = 0 + k = length(ematch_buffer) -function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::SaturationReport, params::SaturationParams) - i = 0 - @assert isempty(g.merges_buffer) + @debug "APPLYING $(count((==)(0xffffffffffffffffffffffffffffffff), ematch_buffer)) matches" + g.needslock && lock(g.lock) + while k > 0 - @debug "APPLYING $(length(g.buffer)) matches" - maybelock!(g) do - while !isempty(g.buffer) + if n_matches % CHECK_GOAL_EVERY_N_MATCHES == 0 && params.goal(g) + @debug "Goal reached" + rep.reason = :goalreached + return + end - if reached(g, params.goal) - @debug "Goal reached" - rep.reason = :goalreached - return + delimiter = ematch_buffer.v[k] + @assert delimiter == 0xffffffffffffffffffffffffffffffff + n = k - 1 + + next_delimiter_idx = 0 + n_elems = 0 + for i in n:-1:1 + n_elems += 1 + if ematch_buffer.v[i] == 0xffffffffffffffffffffffffffffffff + n_elems -= 1 + next_delimiter_idx = i + break end + end - bindings = pop!(g.buffer) - rule_idx, id = bindings[0] - direction = sign(rule_idx) - rule_idx = abs(rule_idx) - rule = theory[rule_idx] + n_matches += 1 + match_info = ematch_buffer.v[next_delimiter_idx + 1] + id = v_pair_first(match_info) + rule_idx = reinterpret(Int, v_pair_last(match_info)) + direction = sign(rule_idx) + rule_idx = abs(rule_idx) + rule = theory[rule_idx] + bindings = @view ematch_buffer.v[(next_delimiter_idx + 2):n] - halt_reason = apply_rule!(bindings, g, rule, id, direction) + res = apply_rule!(bindings, g, rule, id, direction) - if !isnothing(halt_reason) - rep.reason = halt_reason - return - end + k = next_delimiter_idx + if res.halt_reason !== :nothing + rep.reason = res.halt_reason + return end - end - maybelock!(g) do - while !isempty(g.merges_buffer) - (l, r) = pop!(g.merges_buffer) - merge!(g, l, r) + + if params.enodelimit > 0 && length(g.memo) > params.enodelimit + @debug "Too many enodes" + rep.reason = :enodelimit + break end + + !iszero(res.l) && !iszero(res.r) && union!(g, res.l, res.r) + end + if params.goal(g) + @debug "Goal reached" + rep.reason = :goalreached + return end + + empty!(ematch_buffer) + + g.needslock && unlock(g.lock) end @@ -246,25 +275,28 @@ Core algorithm of the library: the equality saturation step. """ function eqsat_step!( g::EGraph, - theory::Vector{<:AbstractRule}, - curr_iter, + theory::Theory, + curr_iter::Int, scheduler::AbstractScheduler, params::SaturationParams, - report, + report::SaturationReport, + ematch_buffer::OptBuffer{UInt128}, ) setiter!(scheduler, curr_iter) - @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report) + @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report, ematch_buffer) - @timeit report.to "Apply" eqsat_apply!(g, theory, report, params) + @timeit report.to "Apply" eqsat_apply!(g, theory, report, params, ematch_buffer) - if report.reason === nothing && cansaturate(scheduler) && isempty(g.dirty) + if report.reason === nothing && cansaturate(scheduler) && isempty(g.pending) report.reason = :saturated end - @timeit report.to "Rebuild" rebuild!(g) + @timeit report.to "Rebuild" rebuild!(g; should_check_memo = params.check_memo, should_check_analysis = params.check_analysis) + + Schedulers.rebuild!(scheduler) - @debug smallest_expr = extract!(g, astsize) + @debug "Smallest expression is" extract!(g, astsize) return report end @@ -273,47 +305,55 @@ end Given an [`EGraph`](@ref) and a collection of rewrite rules, execute the equality saturation algorithm. """ -function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = SaturationParams()) +function saturate!(g::EGraph, theory::Theory, params = SaturationParams()) curr_iter = 0 - sched = params.scheduler(g, theory, params.schedulerparams...) + sched = params.scheduler(g, theory; params.schedulerparams...) report = SaturationReport(g) start_time = time_ns() - !params.timer && disable_timer!(report.to) - timelimit = params.timelimit > 0 + params.timer || disable_timer!(report.to) + + # Buffer for e-matching. Use a local buffer for generated functions. + ematch_buffer = OptBuffer{UInt128}(64) while true curr_iter += 1 @debug "================ EQSAT ITERATION $curr_iter ================" + @debug g - report = eqsat_step!(g, theory, curr_iter, sched, params, report) + report = eqsat_step!(g, theory, curr_iter, sched, params, report, ematch_buffer) elapsed = time_ns() - start_time - if timelimit && params.timelimit <= elapsed - report.reason = :timelimit + if params.goal(g) + @debug "Goal reached" + report.reason = :goalreached break end - if !(report.reason isa Nothing) + if report.reason !== nothing + @debug "Reason" report.reason break end - if curr_iter >= params.timeout - report.reason = :timeout + if params.timelimit > 0 && params.timelimit <= elapsed + @debug "Time limit reached" + report.reason = :timelimit break end - if params.eclasslimit > 0 && g.numclasses > params.eclasslimit - report.reason = :eclasslimit + if curr_iter >= params.timeout + @debug "Too many iterations" + report.reason = :timeout break end - if reached(g, params.goal) - report.reason = :goalreached + if params.eclasslimit > 0 && length(g.classes) > params.eclasslimit + @debug "Too many eclasses" + report.reason = :eclasslimit break end end @@ -321,35 +361,3 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio return report end - -function areequal(theory::Vector, exprs...; params = SaturationParams()) - g = EGraph(exprs[1]) - areequal(g, theory, exprs...; params = params) -end - -function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params = SaturationParams()) - if length(exprs) == 1 - return true - end - - n = length(exprs) - ids = map(Base.Fix1(addexpr!, g), collect(exprs)) - goal = EqualityGoal(collect(exprs), ids) - - params.goal = goal - - report = saturate!(g, t, params) - - if !(report.reason === :saturated) && !reached(g, goal) - return missing # failed to prove - end - return reached(g, goal) -end - -macro areequal(theory, exprs...) - esc(:(areequal($theory, $exprs...))) -end - -macro areequalg(G, theory, exprs...) - esc(:(areequal($G, $theory, $exprs...))) -end diff --git a/src/EGraphs/unionfind.jl b/src/EGraphs/unionfind.jl new file mode 100644 index 00000000..f53c0f23 --- /dev/null +++ b/src/EGraphs/unionfind.jl @@ -0,0 +1,25 @@ +struct UnionFind + parents::Vector{Id} +end + +UnionFind() = UnionFind(Id[]) + +function Base.push!(uf::UnionFind)::Id + l = length(uf.parents) + 1 + push!(uf.parents, l) + l +end + +Base.length(uf::UnionFind) = length(uf.parents) + +function Base.union!(uf::UnionFind, i::Id, j::Id) + uf.parents[j] = i + i +end + +function find(uf::UnionFind, i::Id) + while i != uf.parents[i] + i = uf.parents[i] + end + i +end diff --git a/src/EGraphs/uniquequeue.jl b/src/EGraphs/uniquequeue.jl new file mode 100644 index 00000000..512bb61a --- /dev/null +++ b/src/EGraphs/uniquequeue.jl @@ -0,0 +1,33 @@ +""" +A data structure to maintain a queue of unique elements. +Notably, insert/pop operations have O(1) expected amortized runtime complexity. +""" + +struct UniqueQueue{T} + set::Set{T} + vec::Vector{T} +end + + +UniqueQueue{T}() where {T} = UniqueQueue{T}(Set{T}(), T[]) + +function Base.push!(uq::UniqueQueue{T}, x::T) where {T} + if !(x in uq.set) + push!(uq.set, x) + push!(uq.vec, x) + end +end + +function Base.append!(uq::UniqueQueue{T}, xs::Vector{T}) where {T} + for x in xs + push!(uq, x) + end +end + +function Base.pop!(uq::UniqueQueue{T}) where {T} + v = pop!(uq.vec) + delete!(uq.set, v) + v +end + +Base.isempty(uq::UniqueQueue) = isempty(uq.vec) diff --git a/src/Library.jl b/src/Library.jl index 6a3f7f18..4154c72f 100644 --- a/src/Library.jl +++ b/src/Library.jl @@ -6,59 +6,58 @@ include("docstrings.jl") module Library -using Metatheory.Patterns -using Metatheory.Rules - +using Metatheory macro commutativity(op) - RewriteRule(PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatTerm(:call, op, [PatVar(:b), PatVar(:a)])) + :(@rule $(op)(~a, ~b) --> $(op)(~b, ~a)) end macro right_associative(op) - RewriteRule( - PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]), - PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]), - ) + :(@rule a b c $(op)(a, $(op)(b, c)) --> $(op)($(op)(a, b), c)) end macro left_associative(op) - RewriteRule( - PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]), - PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]), - ) + :(@rule a b c $(op)($(op)(a, b), c) --> $(op)(a, $(op)(b, c))) end macro identity_left(op, id) - RewriteRule(PatTerm(:call, op, [id, PatVar(:a)]), PatVar(:a)) + :(@rule $(op)($id, ~a) --> ~a) end macro identity_right(op, id) - RewriteRule(PatTerm(:call, op, [PatVar(:a), id]), PatVar(:a)) + :(@rule $(op)(~a, $id) --> ~a) end macro inverse_left(op, id, invop) - RewriteRule(PatTerm(:call, op, [PatTerm(:call, invop, [PatVar(:a)]), PatVar(:a)]), id) + :(@rule $(op)($(invop)(~a), ~a) --> $id) end macro inverse_right(op, id, invop) - RewriteRule(PatTerm(:call, op, [PatVar(:a), PatTerm(:call, invop, [PatVar(:a)])]), id) + :(@rule $(op)(~a, $(invop)(~a)) --> $id) end macro associativity(op) esc(quote - [(@left_associative $op), (@right_associative $op)] + RewriteRule[(@left_associative $op), (@right_associative $op)] end) end macro monoid(op, id) - esc(quote - [(@left_associative($op)), (@right_associative($op)), (@identity_left($op, $id)), (@identity_right($op, $id))] - end) + esc( + quote + RewriteRule[ + (@left_associative($op)), + (@right_associative($op)), + (@identity_left($op, $id)), + (@identity_right($op, $id)), + ] + end, + ) end macro commutative_monoid(op, id) esc(quote - [(@commutativity $op), (@left_associative $op), (@right_associative $op), (@identity_left $op $id)] + RewriteRule[(@commutativity $op), (@left_associative $op), (@right_associative $op), (@identity_left $op $id)] end) end diff --git a/src/Metatheory.jl b/src/Metatheory.jl index 6ab2a811..7e10e579 100644 --- a/src/Metatheory.jl +++ b/src/Metatheory.jl @@ -1,30 +1,74 @@ module Metatheory -using DataStructures - -using Base.Meta +using TermInterface: isexpr using Reexport -using TermInterface -@inline alwaystrue(x) = true +@inline alwaystrue(x...) = true + +function to_expr end -function lookup_pat end -function maybelock! end +# TODO: document +Base.@inline maybe_quote_operation(x::Union{Function,DataType,UnionAll}) = nameof(x) +Base.@inline maybe_quote_operation(x) = x include("docstrings.jl") + +include("vecexpr.jl") +@reexport using .VecExprModule + +include("optbuffer.jl") +export OptBuffer + +const UNDEF_ID_VEC = Vector{Id}(undef, 0) + +@reexport using TermInterface + +""" + @matchable struct Foo fields... end [HeadType] + +Take a struct definition and automatically define `TermInterface` methods. +`iscall` of such type will default to `true`. +""" +macro matchable(expr) + @assert expr.head == :struct + name = expr.args[2] + if name isa Expr + name.head === :(<:) && (name = name.args[1]) + name isa Expr && name.head === :curly && (name = name.args[1]) + end + fields = filter(x -> x isa Symbol || (x isa Expr && x.head == :(::)), expr.args[3].args) + get_name(s::Symbol) = s + get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1]) + fields = map(get_name, fields) + + quote + $expr + TermInterface.isexpr(::$name) = true + TermInterface.iscall(::$name) = true + TermInterface.head(::$name) = $name + TermInterface.operation(::$name) = $name + TermInterface.children(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) + TermInterface.arguments(x::$name) = TermInterface.children(x) + TermInterface.arity(x::$name) = $(length(fields)) + Base.length(x::$name) = $(length(fields) + 1) + end |> esc +end +export @matchable + include("utils.jl") export @timer -export @iftimer -export @timerewrite -export @matchable + include("Patterns.jl") @reexport using .Patterns +include("match_compiler.jl") +export match_compile + + include("ematch_compiler.jl") -@reexport using .EMatchCompiler +export ematch_compile -include("matchers.jl") include("Rules.jl") @reexport using .Rules diff --git a/src/Patterns.jl b/src/Patterns.jl index be460bea..cf0a9ee4 100644 --- a/src/Patterns.jl +++ b/src/Patterns.jl @@ -1,9 +1,13 @@ module Patterns -using Metatheory: binarize, cleanast, alwaystrue +using Metatheory: cleanast, alwaystrue, maybe_quote_operation using AutoHashEquals using TermInterface +using Metatheory.VecExprModule +import Metatheory: to_expr + +export AbstractPat, PatLiteral, PatVar, PatExpr, PatSegment, patvars, setdebrujin!, isground, constants """ Abstract type representing a pattern used in all the various pattern matching backends. @@ -11,13 +15,6 @@ Abstract type representing a pattern used in all the various pattern matching ba abstract type AbstractPat end -struct UnsupportedPatternException <: Exception - p::AbstractPat -end - -Base.showerror(io::IO, e::UnsupportedPatternException) = print(io, "Pattern ", e.p, " is unsupported in this context") - - Base.:(==)(a::AbstractPat, b::AbstractPat) = false TermInterface.arity(p::AbstractPat) = 0 """ @@ -25,7 +22,17 @@ A ground pattern contains no pattern variables and only literal values to match. """ isground(p::AbstractPat) = false -isground(x) = true # literals + +struct PatLiteral <: AbstractPat + value + n::VecExpr + PatLiteral(val) = new(val, VecExpr(Id[0, 0, 0, hash(val)])) +end + +PatLiteral(p::AbstractPat) = throw(DomainError(p, "Cannot construct a pattern literal of another pattern object.")) + +isground(x::PatLiteral) = true # literals + # PatVar is equivalent to SymbolicUtils's Slot """ @@ -42,15 +49,14 @@ boolean value. Such a slot will be considered a match only if `f` returns true. type assertion. Type assertions on a `PatVar`, will match if and only if the type of the matched term for the pattern variable is a subtype of `T`. """ -mutable struct PatVar{P} <: AbstractPat +mutable struct PatVar{P<:Union{Function,Type}} <: AbstractPat name::Symbol idx::Int predicate::P - predicate_code end Base.:(==)(a::PatVar, b::PatVar) = a.idx == b.idx -PatVar(var) = PatVar(var, -1, alwaystrue, nothing) -PatVar(var, i) = PatVar(var, i, alwaystrue, nothing) +PatVar(var) = PatVar(var, -1, alwaystrue) +PatVar(var, i) = PatVar(var, i, alwaystrue) """ If you want to match a variable number of subexpressions at once, you will need @@ -59,92 +65,117 @@ A segment pattern represents a vector of subexpressions matched. You can attach a predicate `g` to a segment variable. In the case of segment variables `g` gets a vector of 0 or more expressions and must return a boolean value. """ -mutable struct PatSegment{P} <: AbstractPat +mutable struct PatSegment{P<:Union{Function,Type}} <: AbstractPat name::Symbol idx::Int predicate::P - predicate_code end -PatSegment(v) = PatSegment(v, -1, alwaystrue, nothing) -PatSegment(v, i) = PatSegment(v, i, alwaystrue, nothing) +PatSegment(v) = PatSegment(v, -1, alwaystrue) +PatSegment(v, i) = PatSegment(v, i, alwaystrue) """ -Term patterns will match -on terms of the same `arity` and with the same -function symbol `operation` and expression head `exprhead`. +Term patterns will match on terms of the same `arity` and with the same `operation`. """ -struct PatTerm <: AbstractPat - exprhead::Any - operation::Any - args::Vector - PatTerm(eh, op, args) = new(eh, op, args) #Ref{UInt}(0)) +struct PatExpr <: AbstractPat + head + head_hash::UInt + quoted_head + quoted_head_hash::UInt + children::Vector{AbstractPat} + isground::Bool + """ + Behaves like an e-node to not re-allocate memory when doing e-graph lookups and instantiation + in case of cache hits in the e-graph hashcons + """ + n::VecExpr + function PatExpr(iscall, op, qop, args::Vector) + op_hash = hash(op) + qop_hash = hash(qop) + ar = length(args) + signature = hash(qop, hash(ar)) + + n = v_new(ar) + v_set_flag!(n, VECEXPR_FLAG_ISTREE) + iscall && v_set_flag!(n, VECEXPR_FLAG_ISCALL) + v_set_head!(n, op_hash) + v_set_signature!(n, signature) + + for i in v_children_range(n) + @inbounds n[i] = 0 + end + + new(op, op_hash, qop, qop_hash, args, all(isground, args), n) + end end -TermInterface.istree(::PatTerm) = true -TermInterface.exprhead(e::PatTerm) = e.exprhead -TermInterface.operation(p::PatTerm) = p.operation -TermInterface.arguments(p::PatTerm) = p.args -TermInterface.arity(p::PatTerm) = length(arguments(p)) -TermInterface.metadata(p::PatTerm) = nothing - -function TermInterface.similarterm(x::PatTerm, head, args, symtype = nothing; metadata = nothing, exprhead = :call) - PatTerm(exprhead, head, args) + +# Should call `nameof` on op if Function or DataType. Identity otherwise +PatExpr(iscall, op, args::Vector) = PatExpr(iscall, op, maybe_quote_operation(op), args) + +isground(p::PatExpr)::Bool = p.isground + +function Base.isequal(x::PatExpr, y::PatExpr) + x.head_hash === y.head_hash && v_signature(x.n)===v_signature(y.n) && all(x.children .== y.children) end -isground(p::PatTerm) = all(isground, p.args) +TermInterface.isexpr(::PatExpr) = true +TermInterface.head(p::PatExpr) = p.head +TermInterface.operation(p::PatExpr) = p.head +TermInterface.children(p::PatExpr) = p.children +TermInterface.arguments(p::PatExpr) = p.children +TermInterface.iscall(p::PatExpr) = v_iscall(p.n) +TermInterface.arity(p::PatExpr) = length(p.children) + +function TermInterface.maketerm(::Type{PatExpr}, operation, arguments, metadata) + iscall = isnothing(metadata) ? true : metadata.iscall + PatExpr(iscall, operation, arguments...) +end -# ============================================== -# ================== PATTERN VARIABLES ========= -# ============================================== +# --------------------- +# # Pattern Variables. """ Collects pattern variables appearing in a pattern into a vector of symbols """ patvars(p::PatVar, s) = push!(s, p.name) patvars(p::PatSegment, s) = push!(s, p.name) -patvars(p::PatTerm, s) = (patvars(operation(p), s); foreach(x -> patvars(x, s), arguments(p)); s) -patvars(x, s) = s +patvars(p::PatExpr, s) = (patvars(operation(p), s); foreach(x -> patvars(x, s), arguments(p)); s) +patvars(::Any, s) = s patvars(p) = unique!(patvars(p, Symbol[])) -# ============================================== -# ================== DEBRUJIN INDEXING ========= -# ============================================== +# --------------------- +# # Debrujin Indexing. + function setdebrujin!(p::Union{PatVar,PatSegment}, pvars) p.idx = findfirst((==)(p.name), pvars) end # literal case -setdebrujin!(p, pvars) = nothing +setdebrujin!(::Any, pvars) = nothing -function setdebrujin!(p::PatTerm, pvars) +function setdebrujin!(p::PatExpr, pvars) setdebrujin!(operation(p), pvars) - foreach(x -> setdebrujin!(x, pvars), p.args) + foreach(x -> setdebrujin!(x, pvars), p.children) end - -to_expr(x) = x -to_expr(x::PatVar{T}) where {T} = Expr(:call, :~, Expr(:(::), x.name, x.predicate_code)) -to_expr(x::PatSegment{T}) where {T<:Function} = Expr(:..., Expr(:call, :~, Expr(:(::), x.name, x.predicate_code))) +to_expr(x::PatLiteral) = x.value +to_expr(x::PatVar{T}) where {T} = Expr(:call, :~, Expr(:(::), x.name, x.predicate)) +to_expr(x::PatSegment{T}) where {T<:Function} = Expr(:..., Expr(:call, :~, Expr(:(::), x.name, x.predicate))) to_expr(x::PatVar{typeof(alwaystrue)}) = Expr(:call, :~, x.name) to_expr(x::PatSegment{typeof(alwaystrue)}) = Expr(:..., Expr(:call, :~, x.name)) -to_expr(x::PatTerm) = similarterm(Expr(:call, :x), operation(x), map(to_expr, arguments(x)); exprhead = exprhead(x)) +function to_expr(x::PatExpr) + if iscall(x) + maketerm(Expr, :call, [x.quoted_head; to_expr.(arguments(x))], nothing) + else + maketerm(Expr, operation(x), to_expr.(arguments(x)), nothing) + end +end Base.show(io::IO, pat::AbstractPat) = print(io, to_expr(pat)) -# include("rules/patterns.jl") -export AbstractPat -export PatVar -export PatTerm -export PatSegment -export patvars -export setdebrujin! -export isground -export UnsupportedPatternException - - end diff --git a/src/Rewriters.jl b/src/Rewriters.jl index 94d1ab38..776c636e 100644 --- a/src/Rewriters.jl +++ b/src/Rewriters.jl @@ -13,7 +13,7 @@ rewriters. - `RestartedChain(itr)` like `Chain(itr)` but restarts from the first rewriter once on the first successful application of one of the chained rewriters. - `IfElse(cond, rw1, rw2)` runs the `cond` function on the input, applies `rw1` if cond - returns true, `rw2` if it retuns false + returns true, `rw2` if it returns false - `If(cond, rw)` is the same as `IfElse(cond, rw, Empty())` - `Prewalk(rw; threaded=false, thread_cutoff=100)` returns a rewriter which does a pre-order traversal of a given expression and applies the rewriter `rw`. Note that if @@ -160,22 +160,22 @@ end struct Walk{ord,C,F,threaded} rw::C thread_cutoff::Int - similarterm::F + maketerm::F end function instrument(x::Walk{ord,C,F,threaded}, f) where {ord,C,F,threaded} irw = instrument(x.rw, f) - Walk{ord,typeof(irw),typeof(x.similarterm),threaded}(irw, x.thread_cutoff, x.similarterm) + Walk{ord,typeof(irw),typeof(x.maketerm),threaded}(irw, x.thread_cutoff, x.maketerm) end using .Threads -function Postwalk(rw; threaded::Bool = false, thread_cutoff = 100, similarterm = similarterm) - Walk{:post,typeof(rw),typeof(similarterm),threaded}(rw, thread_cutoff, similarterm) +function Postwalk(rw; threaded::Bool = false, thread_cutoff = 100, maketerm = maketerm) + Walk{:post,typeof(rw),typeof(maketerm),threaded}(rw, thread_cutoff, maketerm) end -function Prewalk(rw; threaded::Bool = false, thread_cutoff = 100, similarterm = similarterm) - Walk{:pre,typeof(rw),typeof(similarterm),threaded}(rw, thread_cutoff, similarterm) +function Prewalk(rw; threaded::Bool = false, thread_cutoff = 100, maketerm = maketerm) + Walk{:pre,typeof(rw),typeof(maketerm),threaded}(rw, thread_cutoff, maketerm) end struct PassThrough{C} @@ -188,12 +188,12 @@ instrument(x::PassThrough, f) = PassThrough(instrument(x.rw, f)) passthrough(x, default) = x === nothing ? default : x function (p::Walk{ord,C,F,false})(x) where {ord,C,F} @assert ord === :pre || ord === :post - if istree(x) + if isexpr(x) if ord === :pre x = p.rw(x) end - if istree(x) - x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)); exprhead = exprhead(x)) + if isexpr(x) + x = p.maketerm(typeof(x), head(x), map(PassThrough(p), children(x)), nothing) end return ord === :post ? p.rw(x) : x else @@ -203,20 +203,20 @@ end function (p::Walk{ord,C,F,true})(x) where {ord,C,F} @assert ord === :pre || ord === :post - if istree(x) + if isexpr(x) if ord === :pre x = p.rw(x) end - if istree(x) - _args = map(arguments(x)) do arg + if isexpr(x) + _args = map(children(x)) do arg if node_count(arg) > p.thread_cutoff Threads.@spawn p(arg) else p(arg) end end - args = map((t, a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.similarterm(x, operation(x), args; exprhead = exprhead(x)) + ntail = map((t, a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, children(x)) + t = p.maketerm(typeof(x), head(x), ntail, nothing) end return ord === :post ? p.rw(t) : t else diff --git a/src/Rules.jl b/src/Rules.jl index d3c927c3..6ea51ef8 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -2,217 +2,187 @@ module Rules using TermInterface using AutoHashEquals -using Metatheory.EMatchCompiler using Metatheory.Patterns using Metatheory.Patterns: to_expr -using Metatheory: cleanast, binarize, matcher, instantiate +using Metatheory: OptBuffer + +export RewriteRule, + DirectedRule, + EqualityRule, + UnequalRule, + DynamicRule, + -->, + is_bidirectional, + Theory, + direct, + direct_left_to_right, + direct_right_to_left + +const STACK_SIZE = 512 -const EMPTY_DICT = Base.ImmutableDict{Int,Any}() - -abstract type AbstractRule end -# Must override -Base.:(==)(a::AbstractRule, b::AbstractRule) = false - -abstract type SymbolicRule <: AbstractRule end - -abstract type BidirRule <: SymbolicRule end +""" +Rules in Metatheory can be defined with the `@rule` macro. -struct RuleRewriteError - rule - expr -end +Rules defined as with the --> are +called *directed rewrite* rules. Application of a *directed rewrite* rule +is a replacement of the `left` pattern with +the `right` substitution, with the correct instantiation +of pattern variables. -getdepth(::Any) = typemax(Int) +```julia +@rule ~a * ~b --> ~b * ~a +``` -showraw(io, t) = Base.show(IOContext(io, :simplify => false), t) -showraw(t) = showraw(stdout, t) +An *equational rule* is a symbolic substitution rule with operator `==` that +can be rewritten bidirectionally. Therefore, it can only be used +with the EGraphs backend. -@noinline function Base.showerror(io::IO, err::RuleRewriteError) - msg = "Failed to apply rule $(err.rule) on expression " - msg *= sprint(io -> showraw(io, err.expr)) - print(io, msg) -end +```julia +@rule ~a * ~b == ~b * ~a +``` +Rules defined with the `!=` act as *anti*-rules for checking contradictions in e-graph +rewriting. If two terms, corresponding to the left and right hand side of an +*anti-rule* are found in an `EGraph`, saturation is halted immediately. -""" -Rules defined as `left_hand --> right_hand` are -called *symbolic rewrite* rules. Application of a *rewrite* Rule -is a replacement of the `left_hand` pattern with -the `right_hand` substitution, with the correct instantiation -of pattern variables. Function call symbols are not treated as pattern -variables, all other identifiers are treated as pattern variables. -Literals such as `5, :e, "hello"` are not treated as pattern -variables. +```julia +!a != a +```` +Rules defined with the `=>` operator are +called *dynamic rules*. Dynamic rules behave like anonymous functions. +Instead of a symbolic substitution, the right hand of +a dynamic `=>` rule is evaluated during rewriting: +matched values are bound to pattern variables as in a +regular function call. This allows for dynamic computation +of right hand sides. ```julia -@rule ~a * ~b --> ~b * ~a +@rule ~a::Number * ~b::Number => ~a*~b ``` """ -@auto_hash_equals fields = (left, right) struct RewriteRule <: SymbolicRule - left - right - matcher +Base.@kwdef struct RewriteRule{Op<:Function} + name::String = "" + op::Op + left::AbstractPat + right::Union{Function,AbstractPat} patvars::Vector{Symbol} - ematcher! + ematcher_left!::Function + ematcher_right!::Union{Nothing,Function} = nothing + matcher_left::Function + matcher_right::Union{Nothing,Function} = nothing + stack::OptBuffer{UInt16} = OptBuffer{UInt16}(STACK_SIZE) + lhs_original = nothing + rhs_original = nothing end -function RewriteRule(l, r) - pvars = patvars(l) ∪ patvars(r) - # sort!(pvars) - setdebrujin!(l, pvars) - setdebrujin!(r, pvars) - RewriteRule(l, r, matcher(l), pvars, ematcher_yield(l, length(pvars))) -end +function --> end +const DirectedRule = RewriteRule{typeof(-->)} +const EqualityRule = RewriteRule{typeof(==)} +const UnequalRule = RewriteRule{typeof(!=)} +# FIXME => is not a function we have to use |> +const DynamicRule = RewriteRule{typeof(|>)} -Base.show(io::IO, r::RewriteRule) = print(io, :($(r.left) --> $(r.right))) +is_bidirectional(r::RewriteRule) = r.op in (==, !=) -function (r::RewriteRule)(term) - # n == 1 means that exactly one term of the input (term,) was matched - success(bindings, n) = n == 1 ? instantiate(term, r.right, bindings) : nothing +# TODO equivalence up-to debruijn index +Base.:(==)(a::RewriteRule, b::RewriteRule) = a.op == b.op && a.left == b.left && a.right == b.right - try - r.matcher(success, (term,), EMPTY_DICT) - catch err - throw(RuleRewriteError(r, term)) - end +function Base.show(io::IO, r::RewriteRule) + print(io, r.left) + print(io, " ") + print(io, r.op == (|>) ? :(=>) : String(nameof(r.op))) + print(io, " ") + print(io, r.rhs_original) + isempty(r.name) || print(io, "\t#= $(r.name) =#") end -# ============================================================ -# EqualityRule -# ============================================================ -""" -An `EqualityRule` can is a symbolic substitution rule that -can be rewritten bidirectional. Therefore, it should only be used -with the EGraphs backend. +(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), r.stack) +(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), r.stack) -```julia -@rule ~a * ~b == ~b * ~a -``` -""" -@auto_hash_equals struct EqualityRule <: BidirRule - left - right - patvars::Vector{Symbol} - ematcher! -end +# --------------------- +# Theories -function EqualityRule(l, r) - pvars = patvars(l) ∪ patvars(r) - extravars = setdiff(pvars, patvars(l) ∩ patvars(r)) - if !isempty(extravars) - error("unbound pattern variables $extravars when creating bidirectional rule") - end - setdebrujin!(l, pvars) - setdebrujin!(r, pvars) - EqualityRule(l, r, pvars, ematcher_yield_bidir(l, r, length(pvars))) -end +const Theory = Vector{RewriteRule} +# struct Theory +# name::String +# rules::Vector{RewriteRule} +# end -Base.show(io::IO, r::EqualityRule) = print(io, :($(r.left) == $(r.right))) +# --------------------- +# Instantiation -function (r::EqualityRule)(x) - throw(RuleRewriteError(r, x)) +function instantiate(left, pat::PatExpr, bindings) + ntail = [] + for parg in arguments(pat) + instantiate_arg!(ntail, left, parg, bindings) + end + maketerm(typeof(left), operation(pat), ntail, nothing) end +function instantiate(left::Expr, pat::PatExpr, bindings) + ntail = [] + if iscall(pat) + for parg in arguments(pat) + instantiate_arg!(ntail, left, parg, bindings) + end + op = operation(pat) + op_name = op isa Union{Function,DataType,UnionAll} ? nameof(op) : op + maketerm(Expr, :call, [op_name; ntail], nothing) + else + for parg in children(pat) + instantiate_arg!(ntail, left, parg, bindings) + end + maketerm(Expr, head(pat), ntail, nothing) + end +end -# ============================================================ -# UnequalRule -# ============================================================ +instantiate_arg!(acc, left, parg::PatSegment, bindings) = append!(acc, instantiate(left, parg, bindings)) +instantiate_arg!(acc, left, parg::AbstractPat, bindings) = push!(acc, instantiate(left, parg, bindings)) + +instantiate(_, pat::PatLiteral, bindings) = pat.value +instantiate(_, pat::Union{PatVar,PatSegment}, bindings) = bindings[pat.idx] + +"Inverts the direction of a rewrite rule, swapping the LHS and the RHS" +function Base.inv(r::RewriteRule) + RewriteRule( + name = r.name, + op = r.op, + left = r.right, + right = r.left, + patvars = r.patvars, + ematcher_left! = r.ematcher_right!, + ematcher_right! = r.ematcher_left!, + matcher_left = r.matcher_right, + matcher_right = r.matcher_left, + lhs_original = r.rhs_original, + rhs_original = r.lhs_original, + ) +end """ -This type of *anti*-rules is used for checking contradictions in the EGraph -backend. If two terms, corresponding to the left and right hand side of an -*anti-rule* are found in an [`EGraph`], saturation is halted immediately. +Turns an EqualityRule into a DirectedRule. For example, ```julia -!a ≠ a +direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x) ``` - """ -@auto_hash_equals struct UnequalRule <: BidirRule - left - right - patvars::Vector{Symbol} - ematcher! -end - -function UnequalRule(l, r) - pvars = patvars(l) ∪ patvars(r) - extravars = setdiff(pvars, patvars(l) ∩ patvars(r)) - if !isempty(extravars) - error("unbound pattern variables $extravars when creating bidirectional rule") - end - # sort!(pvars) - setdebrujin!(l, pvars) - setdebrujin!(r, pvars) - UnequalRule(l, r, pvars, ematcher_yield_bidir(l, r, length(pvars))) +function direct(r::EqualityRule) + RewriteRule(r.name, -->, (getfield(r, k) for k in fieldnames(DirectedRule)[3:end])...) end -Base.show(io::IO, r::UnequalRule) = print(io, :($(r.left) ≠ $(r.right))) - -# ============================================================ -# DynamicRule -# ============================================================ """ -Rules defined as `left_hand => right_hand` are -called `dynamic` rules. Dynamic rules behave like anonymous functions. -Instead of a symbolic substitution, the right hand of -a dynamic `=>` rule is evaluated during rewriting: -matched values are bound to pattern variables as in a -regular function call. This allows for dynamic computation -of right hand sides. +Turns an EqualityRule into a DirectedRule, but right to left. For example, -Dynamic rule ```julia -@rule ~a::Number * ~b::Number => ~a*~b +direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x) ``` """ -@auto_hash_equals struct DynamicRule <: AbstractRule - left - rhs_fun::Function - rhs_code - matcher - patvars::Vector{Symbol} # useful set of pattern variables - ematcher! -end - -function DynamicRule(l, r::Function, rhs_code = nothing) - pvars = patvars(l) - setdebrujin!(l, pvars) - isnothing(rhs_code) && (rhs_code = repr(rhs_code)) - - DynamicRule(l, r, rhs_code, matcher(l), pvars, ematcher_yield(l, length(pvars))) -end - - -Base.show(io::IO, r::DynamicRule) = print(io, :($(r.left) => $(r.rhs_code))) - -function (r::DynamicRule)(term) - # n == 1 means that exactly one term of the input (term,) was matched - success(bindings, n) = - if n == 1 - bvals = [bindings[i] for i in 1:length(r.patvars)] - return r.rhs_fun(term, nothing, bvals...) - end - - try - return r.matcher(success, (term,), EMPTY_DICT) - catch err - rethrow(err) - throw(RuleRewriteError(r, term)) - end -end - -export SymbolicRule -export RewriteRule -export BidirRule -export EqualityRule -export UnequalRule -export DynamicRule -export AbstractRule +direct_right_to_left(r::EqualityRule) = inv(direct(r)) +direct_left_to_right(r::EqualityRule) = direct(r) end diff --git a/src/Syntax.jl b/src/Syntax.jl index 3f3d4760..f3babb56 100644 --- a/src/Syntax.jl +++ b/src/Syntax.jl @@ -2,8 +2,9 @@ module Syntax using Metatheory.Patterns using Metatheory.Rules using TermInterface +using Metatheory: Metatheory -using Metatheory: alwaystrue, cleanast, binarize +using Metatheory: alwaystrue, cleanast, ematch_compile, match_compile export @rule export @theory @@ -13,40 +14,54 @@ export @capture # FIXME this thing eats up macro calls! """ -Remove LineNumberNode from quoted blocks of code +Remove LineNumberNode from quoted blocks of code. Not on macros. """ -rmlines(e::Expr) = Expr(e.head, map(rmlines, filter(x -> !(x isa LineNumberNode), e.args))...) +function rmlines(e::Expr) + if e.head == :macrocall + Expr(e.head, e.args[1], map(rmlines, e.args[2:end])...) + else + Expr(e.head, map(rmlines, filter(x -> !(x isa LineNumberNode), e.args))...) + end +end rmlines(a) = a -function_object_or_quote(op::Symbol, mod)::Expr = :(isdefined($mod, $(QuoteNode(op))) ? $op : $(QuoteNode(op))) -function_object_or_quote(op, mod) = op - -function makesegment(s::Expr, pvars) - if !(exprhead(s) == :(::)) +function makesegment(s::Expr, pvars, mod) + if s.head != :(::) error("Syntax for specifying a segment is ~~x::\$predicate, where predicate is a boolean function or a type") end - name, predicate = arguments(s) + name, predicate = children(s) + if !(predicate isa Symbol) && isdefined(mod, predicate) + error("Invalid predicate in $s. Predicates must be names of functions or types defined in current module.") + end name ∉ pvars && push!(pvars, name) - return :($PatSegment($(QuoteNode(name)), -1, $predicate, $(QuoteNode(predicate)))) + return PatSegment(name, -1, getfield(mod, predicate)) end -function makesegment(name::Symbol, pvars) +function makesegment(name::Symbol, pvars, mod) name ∉ pvars && push!(pvars, name) PatSegment(name) end -function makevar(s::Expr, pvars) - if !(exprhead(s) == :(::)) - error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function or a type") +function makevar(s::Expr, pvars, mod) + if s.head != :(::) + throw( + DomainError( + s, + "Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function or a type", + ), + ) end - name, predicate = arguments(s) + name, predicate = children(s) + if !(predicate isa Symbol) && isdefined(mod, predicate) + error("Invalid predicate in $s. Predicates must be names of functions or types defined in current module.") + end name ∉ pvars && push!(pvars, name) - return :($PatVar($(QuoteNode(name)), -1, $predicate, $(QuoteNode(predicate)))) + return PatVar(name, -1, getfield(mod, predicate)) end -function makevar(name::Symbol, pvars) +function makevar(name::Symbol, pvars, mod) name ∉ pvars && push!(pvars, name) PatVar(name) end @@ -54,91 +69,115 @@ end # Make a dynamic rule right hand side function makeconsequent(expr::Expr) - head = exprhead(expr) - args = arguments(expr) - op = operation(expr) - if head === :call + if iscall(expr) + op = operation(expr) + args = arguments(expr) if op === :(~) - if args[1] isa Symbol - return args[1] - elseif args[1] isa Expr && operation(args[1]) == :(~) - n = arguments(args[1])[1] - @assert n isa Symbol - return n - else - error("Error when parsing right hand side") + let v = args[1] + if v isa Symbol + v + elseif v isa Expr && iscall(v) && operation(v) === :(~) + n = v.args[2] + @assert n isa Symbol + n + else + throw( + DomainError( + v, + "Wrong usage of `~` in patterns. Must be a pattern variable `~x` or a segment variable `~~x`", + ), + ) + end end else - return Expr(head, makeconsequent(op), map(makeconsequent, args)...) + Expr(expr.head, makeconsequent(op), map(makeconsequent, args)...) end else - return Expr(head, map(makeconsequent, args)...) + Expr(expr.head, map(makeconsequent, children(expr))...) end end makeconsequent(x) = x # treat as a literal -function makepattern(x, pvars, slots, mod = @__MODULE__, splat = false) - x in slots ? (splat ? makesegment(x, pvars) : makevar(x, pvars)) : x +function makepattern(x, pvars, slots, mod, splat = false) + if x in slots + splat ? makesegment(x, pvars, mod) : makevar(x, pvars, mod) + elseif x isa Symbol + PatLiteral(getfield(mod, x)) + elseif x isa QuoteNode + PatLiteral(x.value) + else + PatLiteral(x) + end end function makepattern(ex::Expr, pvars, slots, mod = @__MODULE__, splat = false) - head = exprhead(ex) - op = operation(ex) - # Retrieve the function object if available - # Optionally quote function objects - args = arguments(ex) - istree(op) && (op = makepattern(op, pvars, slots, mod)) - - if head === :call - if operation(ex) === :(~) # is a variable or segment + h = head(ex) + + if iscall(ex) + op = operation(ex) + # If operation is a pattern variable + iscall(op) && operation(op) == :(~) && (op = makepattern(op, pvars, slots, mod)) + # Optionally quote function objects + args = arguments(ex) + if op === :(~) # is a variable or segment let v = args[1] - if v isa Expr && operation(v) == :(~) + if v isa Expr && iscall(v) && operation(v) === :(~) # matches ~~x::predicate or ~~x::predicate... - makesegment(arguments(v)[1], pvars) + makesegment(v.args[2], pvars, mod) elseif splat # matches ~x::predicate... - makesegment(v, pvars) + makesegment(v, pvars, mod) else - makevar(v, pvars) + makevar(v, pvars, mod) end end - else # Matches a term + else# Matches a term patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse - :($PatTerm(:call, $(function_object_or_quote(op, mod)), [$(patargs...)])) + isdef = isdefined_nested(mod, op) + op_obj = isdef ? getfield_nested(mod, op) : op + + if isdef && op isa Expr || op isa Symbol + # Support fully qualified function symbols such as `Main.foo` + PatExpr(iscall(ex), op_obj, op, patargs) + else + PatExpr(iscall(ex), op_obj, patargs) + end end - elseif head === :... - makepattern(args[1], pvars, slots, mod, true) - elseif head == :(::) && args[1] in slots - splat ? makesegment(ex, pvars) : makevar(ex, pvars) - elseif head === :ref - # getindex - patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse - :($PatTerm(:ref, getindex, [$(patargs...)])) - elseif head === :$ - args[1] + elseif h === :... + makepattern(ex.args[1], pvars, slots, mod, true) + elseif h == :(::) && ex.args[1] in slots + splat ? makesegment(ex, pvars) : makevar(ex, pvars, mod) + elseif h === :$ + ex.args[1] else - patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse - :($PatTerm($(QuoteNode(head)), $(function_object_or_quote(op, mod)), [$(patargs...)])) + patargs = map(i -> makepattern(i, pvars, slots, mod), ex.args) # recurse + PatExpr(false, h, patargs) end end -function rule_sym_map(ex::Expr) - h = operation(ex) - if h == :(-->) || h == :(→) - RewriteRule - elseif h == :(=>) - DynamicRule - elseif h == :(==) - EqualityRule - elseif h == :(!=) || h == :(≠) - UnequalRule - else - error("Cannot parse rule with operator '$h'") - end +# If it's not a symbol or expr, it's defined. +isdefined_nested(mod, ex) = true +isdefined_nested(mod, ex::Symbol) = isdefined(mod, ex) +function isdefined_nested(mod, ex::Expr) + @assert ex.head === :. + r_unquoted = ex.args[2] + r = r_unquoted isa QuoteNode ? r_unquoted.value : r_unquoted + isdefined_nested(mod, ex.args[1]) || return false + + isdefined_nested(getfield_nested(mod, ex.args[1]), r) +end + +# If it's not a symbol or expr, it's defined. +getfield_nested(mod, ex) = ex +getfield_nested(mod, ex::Symbol) = getfield(mod, ex) +function getfield_nested(mod, ex::Expr) + @assert ex.head === :. + r_unquoted = ex.args[2] + r = r_unquoted isa QuoteNode ? r_unquoted.value : r_unquoted + getfield(getfield_nested(mod, ex.args[1]), r) end -rule_sym_map(ex) = error("Cannot parse rule from $ex") """ rewrite_rhs(expr::Expr) @@ -147,8 +186,8 @@ Rewrite the `expr` by dealing with `:where` if necessary. The `:where` is rewritten from, for example, `~x where f(~x)` to `f(~x) ? ~x : nothing`. """ function rewrite_rhs(ex::Expr) - if exprhead(ex) == :where - rhs, predicate = arguments(ex) + if ex.head == :where + rhs, predicate = children(ex) return :($predicate ? $rhs : nothing) end ex @@ -158,9 +197,13 @@ rewrite_rhs(x) = x function addslots(expr, slots) if expr isa Expr - if expr.head === :macrocall && - expr.args[1] in [Symbol("@rule"), Symbol("@capture"), Symbol("@slots"), Symbol("@theory")] - Expr(:macrocall, expr.args[1:2]..., slots..., expr.args[3:end]...) + if expr.head === :macrocall + if expr.args[1] == Symbol("@rule") + name = expr.args[3] isa String ? expr.args[3] : "" + Expr(:macrocall, expr.args[1:2]..., name, slots..., expr.args[3:end]...) + elseif expr.args[1] in [Symbol("@rule"), Symbol("@capture"), Symbol("@slots"), Symbol("@theory")] + Expr(:macrocall, expr.args[1:2]..., slots..., expr.args[3:end]...) + end else Expr(expr.head, addslots.(expr.args, (slots,))...) end @@ -169,7 +212,6 @@ function addslots(expr, slots) end end - """ @slots [SLOTS...] ex Declare SLOTS as slot variables for all `@rule` or `@capture` invocations in the expression `ex`. @@ -195,14 +237,14 @@ end """ @rule [SLOTS...] LHS operator RHS -Creates an `AbstractRule` object. A rule object is callable, and takes an +Creates a `RewriteRule` object. A rule object is callable, and takes an expression and rewrites it if it matches the LHS pattern to the RHS pattern, returns `nothing` otherwise. The rule language is described below. -LHS can be any possibly nested function call expression where any of the arugments can +LHS can be any possibly nested function call expression where any of the arguments can optionally be a Slot (`~x`) or a Segment (`~x...`) (described below). -SLOTS is an optional list of symbols to be interpeted as slots or segments +SLOTS is an optional list of symbols to be interpreted as slots or segments directly (without using `~`). To declare slots for several rules at once, see the `@slots` macro. @@ -213,9 +255,9 @@ matches found for these variables in the LHS. **Rule operators**: - `LHS => RHS`: create a `DynamicRule`. The RHS is *evaluated* on rewrite. -- `LHS --> RHS`: create a `RewriteRule`. The RHS is **not** evaluated but *symbolically substituted* on rewrite. -- `LHS == RHS`: create a `EqualityRule`. In e-graph rewriting, this rule behaves like `RewriteRule` but can go in both directions. Doesn't work in classical rewriting -- `LHS ≠ RHS`: create a `UnequalRule`. Can only be used in e-graphs, and is used to eagerly stop the process of rewriting if LHS is found to be equal to RHS. +- `LHS --> RHS`: create a `DirectedRule`. The RHS is **not** evaluated but *symbolically substituted* on rewrite. +- `LHS == RHS`: create a `EqualityRule`. In e-graph rewriting, this rule behaves like `DirectedRule` but can go in both directions. Doesn't work in classical rewriting +- `LHS != RHS`: create a `UnequalRule`. Can only be used in e-graphs, and is used to eagerly stop the process of rewriting if LHS is found to be equal to RHS. **Slot**: @@ -327,33 +369,82 @@ Segment variables may still be written as (`~~x`), and slot (`~x`) and segment ( See also: [`@capture`](@ref), [`@slots`](@ref) """ macro rule(args...) - length(args) >= 1 || ArgumentError("@rule requires at least one argument") + length(args) >= 1 || throw(ArgumentError("@rule requires at least one argument")) + rule_name = if args[1] isa String + str = args[1] + args = args[2:end] + str + else + "" + end + slots = args[1:(end - 1)] expr = args[end] - e = macroexpand(__module__, expr) - e = rmlines(e) - RuleType = rule_sym_map(e) + ex = macroexpand(__module__, expr) + ex = rmlines(ex) + + op = iscall(ex) ? operation(ex) : head(ex) - l, r = arguments(e) + @assert op in (:(==), :(=>), :(-->), :(!=)) + + l, r = iscall(ex) ? arguments(ex) : children(ex) pvars = Symbol[] - lhs = makepattern(l, pvars, slots, __module__) - rhs = RuleType <: SymbolicRule ? esc(makepattern(r, [], slots, __module__)) : r + lhs::AbstractPat = makepattern(l, pvars, slots, __module__) + ppvars = Patterns.patvars(lhs) + + @assert pvars == ppvars + + ematcher_right_expr = :nothing + matcher_right_expr = :nothing + + rhs = rhs_original = :(println("replace me")) - if RuleType == DynamicRule + if op == :(=>) # Dynamic Rule rhs_rewritten = rewrite_rhs(r) - rhs_consequent = makeconsequent(rhs_rewritten) + rhs_original = makeconsequent(rhs_rewritten) params = Expr(:tuple, :_lhs_expr, :_egraph, pvars...) - rhs = :($(esc(params)) -> $(esc(rhs_consequent))) - return quote - $(__source__) - DynamicRule($(esc(lhs)), $rhs, $(QuoteNode(rhs_consequent))) + rhs = :($(esc(params)) -> $(esc(rhs_original))) + else + rhs = makepattern(r, pvars, slots, __module__) + setdebrujin!(rhs, pvars) + rhs_original = r + end + + setdebrujin!(lhs, pvars) + + + ematcher_left_expr = esc(ematch_compile(lhs, pvars, 1)) + + if op in (:(==), :(!=)) # Bidirectional rule + ematcher_right_expr = esc(ematch_compile(rhs, pvars, -1)) + matcher_right_expr = esc(match_compile(rhs, pvars)) + extravars = setdiff(pvars, patvars(lhs) ∩ patvars(rhs)) + if !isempty(extravars) + error("unbound pattern variables $extravars when creating bidirectional rule") end end + matcher_left_expr = match_compile(lhs, pvars) + + # FIXME => is not a function we have to use |> + op = (op == :(=>)) ? :(|>) : op + quote $(__source__) - ($RuleType)($(esc(lhs)), $rhs) + RewriteRule(; + name = $rule_name, + op = $op, + left = $lhs, + right = $rhs, + patvars = $ppvars, + ematcher_left! = $ematcher_left_expr, + ematcher_right! = $ematcher_right_expr, + matcher_left = $matcher_left_expr, + matcher_right = $matcher_right_expr, + lhs_original = $(QuoteNode(l)), + rhs_original = $(QuoteNode(rhs_original)), + ) end end @@ -384,24 +475,32 @@ julia> v = [ ``` """ macro theory(args...) - length(args) >= 1 || ArgumentError("@rule requires at least one argument") + esc(_theory(args...)) +end + +function _theory(args...) + length(args) >= 1 || ArgumentError("@theory requires at least one argument") slots = args[1:(end - 1)] expr = args[end] - e = macroexpand(__module__, expr) - e = rmlines(e) + e = rmlines(expr) # e = interp_dollar(e, __module__) - if exprhead(e) == :block - ee = Expr(:vect, map(x -> addslots(:(@rule($x)), slots), arguments(e))...) - esc(ee) - else - error("theory is not in form begin a => b; ... end") + e.head == :block || error("theory is not in form begin a => b; ... end") + + rules = children(e) + rules = map(rules) do r + if r.head == :macrocall && r.args[1] == Symbol("@rule") && r.args[2] isa Union{LineNumberNode,Nothing} + addslots(r, slots) + else + addslots(:(@rule($r)), slots) + end end + # ee = Expr(:ref, RewriteRule, map(x -> addslots(:(@rule($x)), slots))...) + Expr(:ref, RewriteRule, rules...) end - """ @capture ex pattern Uses a `Rule` object to capture an expression if it matches the `pattern`. Returns `true` and injects @@ -423,28 +522,57 @@ macro capture(args...) length(args) >= 2 || ArgumentError("@capture requires at least two arguments") slots = args[1:(end - 2)] ex = args[end - 1] - lhs = args[end] - lhs = macroexpand(__module__, lhs) - lhs = rmlines(lhs) + l = args[end] + l = macroexpand(__module__, l) + l = rmlines(l) + pvars = Symbol[] - lhs_term = makepattern(lhs, pvars, slots, __module__) - bind = Expr( - :block, - map(key -> :($(esc(key)) = getindex(__MATCHES__, findfirst((==)($(QuoteNode(key))), $pvars))), pvars)..., - ) - quote + lhs = makepattern(l, pvars, slots, __module__) + bind_exprs = Expr[] + + for key in pvars + idx = findfirst((==)(key), pvars) + push!(bind_exprs, :($(esc(key)) = __MATCHES__[$idx])) + end + + setdebrujin!(lhs, pvars) + + matcher_left_expr = match_compile(lhs, pvars) + + + ret = quote $(__source__) - lhs_pattern = $(esc(lhs_term)) - __MATCHES__ = DynamicRule(lhs_pattern, (_lhs_expr, _egraph, pvars...) -> pvars, nothing)($(esc(ex))) - if __MATCHES__ !== nothing - $bind + rule = DynamicRule(; + op = (|>), + patvars = $pvars, + left = $lhs, + right = (_lhs_expr, _egraph, pvars...) -> pvars, + matcher_left = $matcher_left_expr, + ematcher_left! = () -> (), + ) + __MATCHES__ = rule($(esc(ex))) + if !isnothing(__MATCHES__) + $(bind_exprs...) true else false end end + ret end - +macro match(target, rules) + # @show target + # println(target) + # println(rules) + t = _theory(:_, rules) + quote + $(Metatheory.Rewriters.RestartedChain)($t)($(esc(target))) + end end + +export @match + + +end \ No newline at end of file diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index ea092dd3..ecb9b5ae 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -1,166 +1,339 @@ -module EMatchCompiler - -using TermInterface -using ..Patterns -using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, LL, maybelock! - -function ematcher(p::Any) - function literal_ematcher(next, g, data, bindings) - !islist(data) && return - ecid = lookup_pat(g, p) - if ecid > 0 && ecid == car(data) - next(bindings, 1) +Base.@kwdef mutable struct EMatchCompilerState + """ + As ground terms are matched at the beginning. + Store the index of the σ variable (address) that represents the first non-ground term. + """ + first_nonground::Int = 0 + + "Ground terms e-class IDs can be stored in a single σ variable" + ground_terms_to_addr::Dict{AbstractPat,Int} = Dict{AbstractPat,Int}() + + """ + Given a pattern variable with Debrujin index i + This vector stores the σ variable index (address) for that variable at position i + """ + patvar_to_addr::Vector{Int} = Int[] + + """ + Addresses of σ variables that should iterate e-nodes in an e-class, + used to generate `enode_idx` variables + """ + enode_idx_addresses::Vector{Int} = Int[] + + "List of actual e-matching instructions" + program::Vector{Expr} = Expr[] + + "How many σ variables are needed to e-match" + memsize = 1 +end + +function ematch_compile(p, pvars, direction) + # Create the compiler state with the right number of pattern variables + state = EMatchCompilerState(; patvar_to_addr = fill(-1, length(pvars))) + + ematch_compile_ground!(p, state, 1) + + state.first_nonground = state.memsize + state.memsize += 1 + + ematch_compile!(p, state, state.first_nonground) + + push!(state.program, yield_expr(state.patvar_to_addr, direction)) + + pat_constants_checks = check_constant_exprs!(Expr[], p) + + quote + function ($(gensym("ematcher")))( + g::$(Metatheory.EGraphs.EGraph), + rule_idx::Int, + root_id::$(Metatheory.Id), + stack::$(Metatheory.OptBuffer){UInt16}, + ematch_buffer::$(Metatheory.OptBuffer){UInt128}, + )::Int + # If the constants in the pattern are not all present in the e-graph, just return + $(pat_constants_checks...) + # Initialize σ variables (e-classes memory) and enode iteration indexes + $(make_memory(state.memsize, state.first_nonground)...) + $([:($(Symbol(:enode_idx, i)) = 1) for i in state.enode_idx_addresses]...) + + n_matches = 0 + # Backtracking stack + stack_idx = 0 + + # Instruction 0 is used to return when the backtracking stack is empty. + # We start from 1. + push!(stack, 0x0000) + pc = 0x0001 + + # We goto this label when: + # 1) After backtracking, the pc is popped from the stack. + # 2) When an instruction succeeds, the pc is incremented. + @label compute + # Instruction 0 is used to return when the backtracking stack is empty. + pc === 0x0000 && return n_matches + + # For each instruction in the program, create an if statement, + # Checking if the current value + $([:( + if pc === $(UInt16(i)) + $code + end + ) for (i, code) in enumerate(state.program)]...) + + error("unreachable code!") + + @label backtrack + pc = pop!(stack) + + @goto compute + + return -1 end end end -checktype(n, T) = istree(n) ? symtype(n) <: T : false -function predicate_ematcher(p::PatVar, pred::Type) - function type_ematcher(next, g, data, bindings) - !islist(data) && return - id = car(data) - eclass = g[id] - for (enode_idx, n) in enumerate(eclass) - if !istree(n) && operation(n) isa pred - next(assoc(bindings, p.idx, (id, enode_idx)), 1) - end - end +check_constant_exprs!(buf, p::PatLiteral) = push!(buf, :(has_constant(g, $(last(p.n))) || return 0)) +check_constant_exprs!(buf, ::AbstractPat) = buf +function check_constant_exprs!(buf, p::PatExpr) + if !(p.head isa AbstractPat) + push!(buf, :(has_constant(g, $(p.head_hash)) || has_constant(g, $(p.quoted_head_hash)) || return 0)) + end + for child in children(p) + check_constant_exprs!(buf, child) end + buf end -function predicate_ematcher(p::PatVar, pred) - function predicate_ematcher(next, g, data, bindings) - !islist(data) && return - id::Int = car(data) - eclass = g[id] - if pred(eclass) - enode_idx = 0 - # Is this for cycle needed? - for (j, n) in enumerate(eclass) - # Find first literal if available - if !istree(n) - enode_idx = j - break - end - end - next(assoc(bindings, p.idx, (id, enode_idx)), 1) +""" +Create a vector of assignment expressions in the form of +`σi = 0x0000000000000000` where `i`` is a number from 1 to n. +If `i == first_nonground`, create an expression `σi = root_id`, +where root_id is a parameter of the ematching function, defined +in scope. +""" +make_memory(n, first_nonground) = [:($(Symbol(:σ, i)) = $(i == first_nonground ? :root_id : Id(0))) for i in 1:n] + +# ============================================================== +# Ground Term E-Matchers +# TODO explain what is a ground term +# ============================================================== + +"Don't compile non-ground terms as ground terms" +ematch_compile_ground!(::AbstractPat, ::EMatchCompilerState, ::Int) = nothing + +# Ground e-matchers +function ematch_compile_ground!(p::Union{PatExpr,PatLiteral}, state::EMatchCompilerState, addr::Int) + haskey(state.ground_terms_to_addr, p) && return nothing + + if isground(p) + # Remember that it has been searched and its stored in σaddr + state.ground_terms_to_addr[p] = addr + # Add the lookup instruction to the program + push!(state.program, lookup_expr(addr, p)) + # Memory needs one more register + state.memsize += 1 + else + # Search for ground patterns in the children. + for child_p in children(p) + ematch_compile_ground!(child_p, state, state.memsize) end end end -function ematcher(p::PatVar) - pred_matcher = predicate_ematcher(p, p.predicate) +# ============================================================== +# Term E-Matchers +# ============================================================== - function var_ematcher(next, g, data, bindings) - id = car(data) - ecid = get(bindings, p.idx, 0)[1] - if ecid > 0 - ecid == id ? next(bindings, 1) : nothing - else - # Variable is not bound, check predicate and bind - pred_matcher(next, g, data, bindings) - end +function ematch_compile!(p::PatExpr, state::EMatchCompilerState, addr::Int) + if haskey(state.ground_terms_to_addr, p) + push!(state.program, check_eq_expr(addr, state.ground_terms_to_addr[p])) + return + end + + c = state.memsize + nargs = arity(p) + memrange = c:(c + nargs - 1) + state.memsize += nargs + + push!(state.enode_idx_addresses, addr) + push!(state.program, bind_expr(addr, p, memrange)) + for (i, child_p) in enumerate(arguments(p)) + ematch_compile!(child_p, state, memrange[i]) end end -Base.@pure @inline checkop(x::Union{Function,DataType}, op) = isequal(x, op) || isequal(nameof(x), op) -Base.@pure @inline checkop(x, op) = isequal(x, op) -function canbind(p::PatTerm) - eh = exprhead(p) - op = operation(p) - ar = arity(p) - function canbind(n) - istree(n) && exprhead(n) == eh && checkop(op, operation(n)) && arity(n) == ar +function ematch_compile!(p::PatVar, state::EMatchCompilerState, addr::Int) + instruction = if state.patvar_to_addr[p.idx] != -1 + # Pattern variable with the same Debrujin index has appeared in the + # pattern before this. Just check if the current e-class id matches the one + # That was already encountered. + check_eq_expr(addr, state.patvar_to_addr[p.idx]) + else + # Variable has not been seen before. Store its memory address + state.patvar_to_addr[p.idx] = addr + # insert instruction for checking predicates or type. + push!(state.enode_idx_addresses, addr) + check_var_expr(addr, p.predicate) end + push!(state.program, instruction) end +# Pattern not supported. +function ematch_compile!(p::AbstractPat, state::EMatchCompilerState, ::Int) + push!( + state.program, + :(throw(DomainError(p, "Pattern type $(typeof(p)) not supported in e-graph pattern matching")); return 0), + ) +end -function ematcher(p::PatTerm) - ematchers = map(ematcher, arguments(p)) - if isground(p) - return function ground_term_ematcher(next, g, data, bindings) - !islist(data) && return - ecid = lookup_pat(g, p) - if ecid > 0 && ecid == car(data) - next(bindings, 1) - end + +function ematch_compile!(p::PatLiteral, state::EMatchCompilerState, addr::Int) + push!(state.program, check_eq_expr(addr, state.ground_terms_to_addr[p])) +end + + +# ============================================================== +# Actual Instructions +# ============================================================== + +function bind_expr(addr, p::PatExpr, memrange) + quote + eclass = g[$(Symbol(:σ, addr))] + eclass_length = length(eclass.nodes) + if $(Symbol(:enode_idx, addr)) <= eclass_length + push!(stack, pc) + + n = eclass.nodes[$(Symbol(:enode_idx, addr))] + + v_flags(n) === $(v_flags(p.n)) || @goto $(Symbol(:skip_node, addr)) + v_signature(n) === $(v_signature(p.n)) || @goto $(Symbol(:skip_node, addr)) + v_head(n) === $(v_head(p.n)) || (v_head(n) === $(p.quoted_head_hash) || @goto $(Symbol(:skip_node, addr))) + + # Node has matched. + $([:($(Symbol(:σ, j)) = n[$i + $VECEXPR_META_LENGTH]) for (i, j) in enumerate(memrange)]...) + pc += 0x0001 + $(Symbol(:enode_idx, addr)) += 1 + @goto compute + + @label $(Symbol(:skip_node, addr)) + # This node did not match. Try next node and backtrack. + $(Symbol(:enode_idx, addr)) += 1 + @goto backtrack end + + + # # Restart from first option + $(Symbol(:enode_idx, addr)) = 1 + @goto backtrack end +end - canbindtop = canbind(p) - function term_ematcher(success, g, data, bindings) - !islist(data) && return nothing +function check_var_expr(addr, predicate::typeof(alwaystrue)) + quote + # eclass = g[$(Symbol(:σ, addr))] + # for (j, n) in enumerate(eclass.nodes) + # if !v_isexpr(n) + # $(Symbol(:enode_idx, addr)) = j + 1 + # break + # end + # end + pc += 0x0001 + @goto compute + end +end - function loop(children_eclass_ids, bindings′, ematchers′) - if !islist(ematchers′) - # term is empty - if !islist(children_eclass_ids) - # we have correctly matched the term - return success(bindings′, 1) +function check_var_expr(addr, predicate::Function) + quote + eclass = g[$(Symbol(:σ, addr))] + if ($predicate)(g, eclass) + for (j, n) in enumerate(eclass.nodes) + if !v_isexpr(n) + $(Symbol(:enode_idx, addr)) = j + 1 + break end - return nothing - end - car(ematchers′)(g, children_eclass_ids, bindings′) do b, n_of_matched # next - # recursion case: - # take the first matcher, on success, - # keep looping by matching the rest - # by removing the first n matched elements - # from the term, with the bindings, - loop(drop_n(children_eclass_ids, n_of_matched), b, cdr(ematchers′)) end + pc += 0x0001 + @goto compute end + @goto backtrack + end +end + - for n in g[car(data)] - if canbindtop(n) - loop(LL(arguments(n), 1), bindings, ematchers) +function check_var_expr(addr, T::Type) + quote + eclass = g[$(Symbol(:σ, addr))] + eclass_length = length(eclass.nodes) + if $(Symbol(:enode_idx, addr)) <= eclass_length + push!(stack, pc) + n = eclass.nodes[$(Symbol(:enode_idx, addr))] + + if !v_isexpr(n) + hn = Metatheory.EGraphs.get_constant(g, v_head(n)) + if hn isa $T + $(Symbol(:enode_idx, addr)) += 1 + pc += 0x0001 + @goto compute + end end + + # This node did not match. Try next node and backtrack. + $(Symbol(:enode_idx, addr)) += 1 + @goto backtrack end + + # Restart from first option + $(Symbol(:enode_idx, addr)) = 1 + @goto backtrack end end -const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}() - """ -Substitutions are efficiently represented in memory as vector of tuples of two integers. -This should allow for static allocation of matches and use of LoopVectorization.jl -The buffer has to be fairly big when e-matching. -The size of the buffer should double when there's too many matches. -The format is as follows -* The first pair denotes the index of the rule in the theory and the e-class id - of the node of the e-graph that is being substituted. The rule number should be negative if it's a bidirectional - the direction is right-to-left. -* From the second pair on, it represents (e-class id, literal position) at the position of the pattern variable -* The end of a substitution is delimited by (0,0) +Constructs an e-matcher instruction `Expr` that checks if 2 e-class IDs +contained in memory addresses `addr_a` and `addr_b` are equal, +backtracks otherwise. """ -function ematcher_yield(p, npvars::Int, direction::Int) - em = ematcher(p) - function ematcher_yield(g, rule_idx, id)::Int - n_matches = 0 - em(g, (id,), EMPTY_ECLASS_DICT) do b, n - maybelock!(g) do - push!(g.buffer, assoc(b, 0, (rule_idx * direction, id))) - n_matches += 1 - end +function check_eq_expr(addr_a, addr_b) + quote + if $(Symbol(:σ, addr_a)) == $(Symbol(:σ, addr_b)) + pc += 0x0001 + @goto compute + else + @goto backtrack end - n_matches end end -ematcher_yield(p, npvars) = ematcher_yield(p, npvars, 1) - -function ematcher_yield_bidir(l, r, npvars::Int) - eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1) - function ematcher_yield_bidir(g, rule_idx, id)::Int - eml(g, rule_idx, id) + emr(g, rule_idx, id) +function lookup_expr(addr, p::AbstractPat) + quote + ecid = lookup_pat(g, $p) + if ecid > 0 + $(Symbol(:σ, addr)) = ecid + pc += 0x0001 + @goto compute + end + @goto backtrack end end -ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p") - -export ematcher_yield, ematcher_yield_bidir - +function yield_expr(patvar_to_addr, direction) + push_exprs = [ + :(push!(ematch_buffer, v_pair($(Symbol(:σ, addr)), reinterpret(UInt64, $(Symbol(:enode_idx, addr)) - 1)))) for + addr in patvar_to_addr + ] + quote + g.needslock && lock(g.lock) + push!(ematch_buffer, v_pair(root_id, reinterpret(UInt64, rule_idx * $direction))) + $(push_exprs...) + # Add delimiter to buffer. + push!(ematch_buffer, 0xffffffffffffffffffffffffffffffff) + n_matches += 1 + g.needslock && unlock(g.lock) + @goto backtrack + end end + diff --git a/src/match_compiler.jl b/src/match_compiler.jl new file mode 100644 index 00000000..a50ec961 --- /dev/null +++ b/src/match_compiler.jl @@ -0,0 +1,377 @@ +using Metatheory: alwaystrue +using TermInterface + +Base.@kwdef mutable struct MatchCompilerState + "For each pattern variable, store if it has already been encountered or not" + pvars_bound::Vector{Bool} + "List of actual instructions" + program::Vector{Expr} = Expr[] + "Pair of variables needed by the pattern matcher and their initial value" + variables = Pair{Symbol,Any}[] + """ + For each segment pattern variable, store the reference to the vector + that will be used to construct the view. + """ + segments::Vector{Pair{Symbol,Symbol}} = Pair{Symbol,Symbol}[] + """ + When matching segment variables, we can count how many non-segment terms + are remaining in the tail of the pattern term, to avoid matching extra terms + """ + current_term_n_remaining::Int = 0 +end + +function match_compile(p::AbstractPat, pvars) + npvars = length(pvars) + + state = MatchCompilerState(; pvars_bound = fill(false, npvars)) + + # Tree coordinates are a vector of integers. + # Each index `i` in the vector corresponds to the depth of the term + # Each value `n` at index `i` selects the `n`-th children of the term at depth i + # Example: in f(x, g(y, k, h(z))), to get z the coordinate is [2,3,1] + coordinate = Int[] + + match_compile!(p, state, coordinate, Symbol[]) + + push!(state.program, match_yield_expr(state, pvars)) + + quote + function ($(gensym("matcher")))(_term_being_matched, _callback::Function, stack::$(OptBuffer{UInt16})) + # Assign and empty the variables for patterns + $([:($(varname(var)) = nothing) for var in setdiff(pvars, first.(state.segments))]...) + + # Initialize the variables needed in the outermost scope (accessible by instruction blocks) + $([:(local $(Symbol(k)) = $v) for (k, v) in state.variables]...) + + # Backtracking stack + local stack_idx = 0 + + # Instruction 0 is used to return when the backtracking stack is empty. + # We start from 1. + push!(stack, 0x0000) + local pc = 0x0001 + + # We goto this label when: + # 1) After backtracking, the pc is popped from the stack. + # 2) When an instruction succeeds, the pc is incremented. + @label compute + # Instruction 0 is used to fail the backtracking stack is empty. + pc === 0x0000 && return nothing + + # For each instruction in the program, create an if statement, + # Checking if the current value + $([:( + if pc === $(UInt16(i)) + $code + end + ) for (i, code) in enumerate(state.program)]...) + + error("unreachable code!") + + @label backtrack + pc = pop!(stack) + @goto compute + end + end +end + +function match_yield_expr(state::MatchCompilerState, pvars) + steps = Expr[] + for (pvar, local_args) in state.segments + start_idx = Symbol(varname(pvar), :_start) + end_idx = Symbol(varname(pvar), :_end) + push!(steps, :($(varname(pvar)) = view($local_args, ($start_idx):($end_idx)))) + end + push!(steps, :(return _callback($(map(varname, pvars)...)))) + Expr(:block, steps...) +end + +# ============================================================== +# Term Matchers +# ============================================================== + +function make_coord_symbol(coordinate) + isempty(coordinate) && return :_term_being_matched + Symbol("_term_being_matched_", join(coordinate, "_")) +end + +offset_so_far(segments) = foldl( + (x, y) -> :($x + $y), + map(n -> :(length(($(Symbol(varname(n), :_start))):($(Symbol(varname(n), :_end)))) - 1), segments); + init = 0, +) + +function get_coord(coordinate, segments_so_far) + isempty(coordinate) && return :_term_being_matched + coord_obj = get_coord_obj(coordinate) + coord = get_idx(coordinate, segments_so_far) + quote + $coord <= length($coord_obj) || @goto backtrack + $(coord_obj)[$coord] + end +end + +function get_coord_obj(coordinate) + tsym = make_coord_symbol(coordinate[1:(end - 1)]) + Symbol(tsym, :_args) +end + +get_idx(coordinate, segments_so_far) = :($(last(coordinate)) + $(offset_so_far(segments_so_far))) + +# TODO FIXME Report on Julialang ? +# This workaround is needed because otherwise pattern variables named `val` +# Are going to clash with @inbounds generated val. +# See this: +# julia> @macroexpand @inbounds v[i:j] +# quote +# $(Expr(:inbounds, true)) +# local var"#11517#val" = v[i:j] +# $(Expr(:inbounds, :pop)) +# var"#11517#val" +# end +varname(patvarname::Symbol) = Symbol(:_pvar_, patvarname) + +function match_compile!(pattern::PatExpr, state::MatchCompilerState, coordinate::Vector{Int}, parent_segments) + tsym = make_coord_symbol(coordinate) + !isempty(coordinate) && push!(state.variables, tsym => nothing) + push!(state.variables, Symbol(tsym, :_op) => nothing) + push!(state.variables, Symbol(tsym, :_args) => nothing) + + pat_op = operation(pattern) + if pat_op isa PatVar + match_compile!(pat_op, state, coordinate, parent_segments, true) + end + push!(state.program, match_term_expr(pattern, coordinate, parent_segments)) + + p_args = arguments(pattern) + p_arity = length(p_args) + state.current_term_n_remaining = 0 + + segments_so_far = Symbol[] + + for (i, child_pattern) in enumerate(p_args) + state.current_term_n_remaining = p_arity - i - count(x -> (x isa PatSegment), @view(p_args[(i + 1):end])) + match_compile!(child_pattern, state, [coordinate; i], segments_so_far) + end + + push!(state.program, match_term_expr_closing(pattern, state, [coordinate; p_arity], segments_so_far)) +end + +function match_compile!( + patvar::Union{PatVar,PatSegment}, + state::MatchCompilerState, + coordinate::Vector{Int}, + parent_segments, + is_term_operation_patvar = false, +) + tsym = make_coord_symbol(coordinate[1:(end - 1)]) + tsym_args = Symbol(tsym, :_args) + + to_compare = if is_term_operation_patvar && patvar isa PatVar + :(operation($tsym)) + else + get_coord(coordinate, parent_segments) + end + instruction = if state.pvars_bound[patvar.idx] + # Pattern variable with the same Debrujin index has appeared in the + # pattern before this (is bound). Just check for equality. + match_eq_expr(patvar, state, to_compare, coordinate, parent_segments) + else + # Variable has not been seen before. Store it + state.pvars_bound[patvar.idx] = true + # insert instruction for checking predicates or type. + match_var_expr(patvar, state, to_compare, coordinate, parent_segments) + end + + + if patvar isa PatSegment + push!(parent_segments, patvar.name) + push!(state.segments, patvar.name => tsym_args) + push!(state.variables, Symbol(varname(patvar.name), :_start) => -1) + push!(state.variables, Symbol(varname(patvar.name), :_end) => -2) + push!(state.variables, Symbol(varname(patvar.name), :_n_dropped) => 0) + end + push!(state.program, instruction) +end + + +function match_compile!(p::PatLiteral, state::MatchCompilerState, coordinate::Vector{Int}, segments_so_far) + to_compare = get_coord(coordinate, segments_so_far) + push!(state.program, match_eq_expr(p, state, to_compare, coordinate, segments_so_far)) +end + +# ============================================================== +# Actual Instructions +# ============================================================== + +function match_term_op(pattern, tsym, ::Union{Function,DataType,UnionAll}) + t_op = Symbol(tsym, :_op) + :($t_op == $(pattern.head) || $t_op == $(QuoteNode(pattern.quoted_head)) || @goto backtrack) +end + +match_term_op(pattern, tsym, ::Union{Symbol,Expr}) = + :($(Symbol(tsym, :_op)) == $(QuoteNode(pattern.head)) || @goto backtrack) + +match_term_op(::AbstractPat, tsym, patvar::PatVar) = + :($(Symbol(tsym, :_op)) == $(varname(patvar.name)) || @goto backtrack) + + +function match_term_expr(pattern::PatExpr, coordinate, segments_so_far) + tsym = make_coord_symbol(coordinate) + op_fun = iscall(pattern) ? :operation : :head + args_fun = iscall(pattern) ? :arguments : :children + + op_guard = match_term_op(pattern, tsym, operation(pattern)) + + quote + $tsym = $(get_coord(coordinate, segments_so_far)) + + isexpr($tsym) || @goto backtrack + iscall($tsym) === $(iscall(pattern)) || @goto backtrack + + $(Symbol(tsym, :_op)) = $(op_fun)($tsym) + $(Symbol(tsym, :_args)) = $(args_fun)($tsym) + + $op_guard + + pc += 0x0001 + @goto compute + end +end + +function match_term_expr_closing(pattern, state, coordinate, segments_so_far) + tsym = make_coord_symbol(coordinate[1:(end - 1)]) + tsym_args = Symbol(tsym, :_args) + + quote + if ($(get_idx(coordinate, segments_so_far))) == length($tsym_args) + pc += 0x0001 + @goto compute + end + @goto backtrack + end +end + +match_var_expr_if_guard(patvar::Union{PatVar,PatSegment}, predicate::Function) = + :($(predicate)($(varname(patvar.name)))) +match_var_expr_if_guard(patvar::Union{PatVar,PatSegment}, predicate::typeof(alwaystrue)) = true +match_var_expr_if_guard(patvar::Union{PatVar,PatSegment}, T::Type) = :($(varname(patvar.name)) isa $T) + + +function match_var_expr(patvar::PatVar, state::MatchCompilerState, to_compare, coordinate, segments_so_far) + quote + $(varname(patvar.name)) = $to_compare + if $(match_var_expr_if_guard(patvar, patvar.predicate)) + pc += 0x0001 + @goto compute + end + @goto backtrack + end +end + + +function match_var_expr(patvar::PatSegment, state::MatchCompilerState, to_compare, coordinate, segments_so_far) + tsym = make_coord_symbol(coordinate[1:(end - 1)]) + tsym_args = Symbol(tsym, :_args) + n_dropped_sym = Symbol(varname(patvar.name), :_n_dropped) + + + quote + start_idx = $(get_idx(coordinate, segments_so_far)) + end_idx = length($tsym_args) - $(state.current_term_n_remaining) + + if end_idx - $n_dropped_sym >= start_idx - 1 + push!(stack, pc) + + # $(patvar.name) = view($tsym_args, start_idx:(end_idx - $n_dropped_sym)) + $(Symbol(varname(patvar.name), :_start)) = start_idx + $(Symbol(varname(patvar.name), :_end)) = end_idx - $n_dropped_sym + + + $n_dropped_sym += 1 + + if $(match_var_expr_if_guard(patvar, patvar.predicate)) + pc += 0x0001 + @goto compute + end + + @goto backtrack + end + + # Restart + $n_dropped_sym = 0 + @goto backtrack + end +end + + + +function match_eq_expr(patvar::PatVar, state::MatchCompilerState, to_compare, coordinate, segments_so_far) + quote + if isequal($(varname(patvar.name)), $to_compare) + pc += 0x0001 + @goto compute + else + @goto backtrack + end + end +end + + +function match_eq_expr(patvar::PatSegment, state::MatchCompilerState, to_compare, coordinate, segments_so_far) + # This method should be called only when a PatSegment is already bound. + # Get parent term variable name + # TODO reuse in function, duplicate from get_coord + tsym = make_coord_symbol(coordinate[1:(end - 1)]) + tsym_args = Symbol(tsym, :_args) + + start_idx = get_idx(coordinate, segments_so_far) + + previous_local_args = nothing + for (p, args_sym) in state.segments + if patvar.name == p + previous_local_args = args_sym + end + end + @assert !isnothing(previous_local_args) + # Start and end indexes in the vector of term arguments that + # matched on the previous occurrence of the segment variable. + previous_start_idx = Symbol(varname(patvar.name), :_start) + previous_end_idx = Symbol(varname(patvar.name), :_end) + + quote + len = length(($previous_start_idx):($previous_end_idx)) + if $start_idx > length($tsym_args) + # We're checking a segment variable that was previously bound. + # We start checking from arguments of term at index `start_idx`. + # `tsym_args` are the arguments of the term. + # If `start_idx` is > than the length of the terms, we mean that + # we have no more space to match. + # This means that if the previously bound segment variable was empty, + # and contains no matches, then we can safely proceed. + # Otherwise we need to fail. + len == 0 || @goto backtrack + end + $start_idx + len - 1 <= length($tsym_args) || @goto backtrack + + for i in 1:len + # ($tsym_args)[$start_idx + i - 1] == $(patvar.name)[i] || @goto backtrack + isequal(($tsym_args)[$start_idx + i - 1], $previous_local_args[$previous_start_idx + i - 1]) || @goto backtrack + end + + + pc += 0x0001 + @goto compute + end +end + +function match_eq_expr(pat::PatLiteral, state::MatchCompilerState, to_compare, coordinate, segments_so_far) + quote + if isequal($(pat.value isa Union{Symbol,Expr} ? QuoteNode(pat.value) : pat.value), $to_compare) + pc += 0x0001 + @goto compute + else + @goto backtrack + end + end +end \ No newline at end of file diff --git a/src/matchers.jl b/src/matchers.jl deleted file mode 100644 index e93dbd14..00000000 --- a/src/matchers.jl +++ /dev/null @@ -1,172 +0,0 @@ -#### Pattern matching -### Matching procedures -# A matcher is a function which takes 3 arguments -# 1. Callback: takes arguments Dictionary × Number of elements matched -# 2. Expression -# 3. Vector of matches debrujin-indexed by pattern variables -# - -using Metatheory: islist, car, cdr, assoc, drop_n, take_n - -function matcher(val::Any) - function literal_matcher(next, data, bindings) - islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing - end -end - -function matcher(slot::PatVar) - pred = slot.predicate - if slot.predicate isa Type - pred = x -> typeof(x) <: slot.predicate - end - function slot_matcher(next, data, bindings) - !islist(data) && return - val = get(bindings, slot.idx, nothing) - if val !== nothing - if isequal(val, car(data)) - return next(bindings, 1) - end - else - # Variable is not bound, first time it is found - # check the predicate - if pred(car(data)) - next(assoc(bindings, slot.idx, car(data)), 1) - end - end - end -end - -# returns n == offset, 0 if failed -function trymatchexpr(data, value, n) - if !islist(value) - return n - elseif islist(value) && islist(data) - if !islist(data) - # didn't fully match - return nothing - end - - while isequal(car(value), car(data)) - n += 1 - value = cdr(value) - data = cdr(data) - - if !islist(value) - return n - elseif !islist(data) - return nothing - end - end - - return !islist(value) ? n : nothing - elseif isequal(value, data) - return n + 1 - end -end - -function matcher(segment::PatSegment) - function segment_matcher(success, data, bindings) - val = get(bindings, segment.idx, nothing) - if val !== nothing - n = trymatchexpr(data, val, 0) - if !isnothing(n) - success(bindings, n) - end - else - res = nothing - - for i in length(data):-1:0 - subexpr = take_n(data, i) - - if segment.predicate(subexpr) - res = success(assoc(bindings, segment.idx, subexpr), i) - !isnothing(res) && break - end - end - - return res - end - end -end - -# Try to match both against a function symbol or a function object at the same time. -# Slows compile time down a bit but lets this matcher work at the same time on both purely symbolic Expr-like object. -# Execution time should not be affected. -# and SymbolicUtils-like objects that store function references as operations. -function head_matcher(f::Union{Function,DataType,UnionAll}) - checkhead(x) = isequal(x, f) || isequal(x, nameof(f)) - function head_matcher(next, data, bindings) - h = car(data) - if islist(data) && checkhead(h) - next(bindings, 1) - else - nothing - end - end -end - -head_matcher(x) = matcher(x) - -function matcher(term::PatTerm) - op = operation(term) - matchers = (head_matcher(op), map(matcher, arguments(term))...) - function term_matcher(success, data, bindings) - !islist(data) && return nothing - !istree(car(data)) && return nothing - - function loop(term, bindings′, matchers′) # Get it to compile faster - # Base case, no more matchers - if !islist(matchers′) - # term is empty - if !islist(term) - # we have correctly matched the term - return success(bindings′, 1) - end - return nothing - end - car(matchers′)(term, bindings′) do b, n - # recursion case: - # take the first matcher, on success, - # keep looping by matching the rest - # by removing the first n matched elements - # from the term, with the bindings, - loop(drop_n(term, n), b, cdr(matchers′)) - end - end - - loop(car(data), bindings, matchers) # Try to eat exactly one term - end -end - -function TermInterface.similarterm( - x::Expr, - head::Union{Function,DataType}, - args, - symtype = nothing; - metadata = nothing, - exprhead = exprhead(x), -) - similarterm(x, nameof(head), args, symtype; metadata, exprhead) -end - -function instantiate(left, pat::PatTerm, mem) - args = [] - for parg in arguments(pat) - enqueue = parg isa PatSegment ? append! : push! - enqueue(args, instantiate(left, parg, mem)) - end - reference = istree(left) ? left : Expr(:call, :_) - similarterm(reference, operation(pat), args; exprhead = exprhead(pat)) -end - -instantiate(left, pat::Any, mem) = pat - -instantiate(left, pat::AbstractPat, mem) = error("Unsupported pattern ", pat) - -function instantiate(left, pat::PatVar, mem) - mem[pat.idx] -end - -function instantiate(left, pat::PatSegment, mem) - mem[pat.idx] -end diff --git a/src/optbuffer.jl b/src/optbuffer.jl new file mode 100644 index 00000000..94a1d3f6 --- /dev/null +++ b/src/optbuffer.jl @@ -0,0 +1,37 @@ + +"Optimized, unsafe, infinite-growing byte buffer implementation" +mutable struct OptBuffer{T<:Unsigned} + v::Vector{T} + i::Int + cap::Int + growth::Float64 +end + +function OptBuffer{T}(cap::Int, growth = 0.4) where {T<:Unsigned} + v = Vector{T}(undef, cap) + OptBuffer{T}(v, 0, cap, growth) +end + +Base.@inline function Base.push!(b::OptBuffer{T}, el::T) where {T} + b.i += 1 + if b.i === b.cap + delta = ceil(Int, b.cap * b.growth) + 1 + Base._growend!(b.v, delta) + b.cap += delta + end + @inbounds b.v[b.i] = el + b +end + +Base.@inline function Base.pop!(b::OptBuffer{T})::T where {T} + # THIS IS UNSAFE! ASSUMES ALWAYS THAT b.i is > 1 + val = @inbounds b.v[b.i] + b.i -= 1 + val +end + + +Base.isempty(b::OptBuffer{T}) where {T} = b.i === 0 +Base.empty!(b::OptBuffer{T}) where {T} = (b.i = 0) +@inline Base.length(b::OptBuffer{T}) where {T} = b.i +Base.iterate(b::OptBuffer{T}, i=1) where {T} = iterate(b.v[1:b.i], i) \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 8e627165..377132e5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,48 +1,16 @@ using Base: ImmutableDict - -function binarize(e::T) where {T} - !istree(e) && return e - head = exprhead(e) - if head == :call - op = operation(e) - args = arguments(e) - meta = metadata(e) - if op ∈ binarize_ops && arity(e) > 2 - return foldl((x, y) -> similarterm(e, op, [x, y], symtype(e); metadata = meta, exprhead = head), args) - end - end - return e -end - -""" -Recursive version of binarize -""" -function binarize_rec(e::T) where {T} - !istree(e) && return e - head = exprhead(e) - op = operation(e) - args = map(binarize_rec, arguments(e)) - meta = metadata(e) - if head == :call - if op ∈ binarize_ops && arity(e) > 2 - return foldl((x, y) -> similarterm(e, op, [x, y], symtype(e); metadata = meta, exprhead = head), args) - end - end - return similarterm(e, op, args, symtype(e); metadata = meta, exprhead = head) -end - - +using TimerOutputs const binarize_ops = [:(+), :(*), (+), (*)] function cleanast(e::Expr) # TODO better line removal - if isexpr(e, :block) + if e.head === :block return Expr(e.head, filter(x -> !(x isa LineNumberNode), e.args)...) end # Binarize - if isexpr(e, :call) + if iscall(e) op = e.args[1] if op ∈ binarize_ops && length(e.args) > 3 return foldl((x, y) -> Expr(:call, op, x, y), @view e.args[2:end]) @@ -51,127 +19,6 @@ function cleanast(e::Expr) return e end -# Linked List interface -@inline assoc(d::ImmutableDict, k, v) = ImmutableDict(d, k => v) - -struct LL{V} - v::V - i::Int -end - -islist(x) = istree(x) || !isempty(x) - -Base.empty(l::LL) = empty(l.v) -Base.isempty(l::LL) = l.i > length(l.v) - -Base.length(l::LL) = length(l.v) - l.i + 1 -@inline car(l::LL) = l.v[l.i] -@inline cdr(l::LL) = isempty(l) ? empty(l) : LL(l.v, l.i + 1) - -# Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY -# Base.isempty(t::Term) = false -# @inline car(t::Term) = operation(t) -# @inline cdr(t::Term) = arguments(t) - -@inline car(v) = istree(v) ? operation(v) : first(v) -@inline function cdr(v) - if istree(v) - arguments(v) - else - islist(v) ? LL(v, 2) : error("asked cdr of empty") - end -end - -@inline take_n(ll::LL, n) = isempty(ll) || n == 0 ? empty(ll) : @views ll.v[(ll.i):(n + ll.i - 1)] # @views handles Tuple -@inline take_n(ll, n) = @views ll[1:n] - -@inline function drop_n(ll, n) - if n === 0 - return ll - else - istree(ll) ? drop_n(arguments(ll), n - 1) : drop_n(cdr(ll), n - 1) - end -end -@inline drop_n(ll::Union{Tuple,AbstractArray}, n) = drop_n(LL(ll, 1), n) -@inline drop_n(ll::LL, n) = LL(ll.v, ll.i + n) - - - -isliteral(::Type{T}) where {T} = x -> x isa T -is_literal_number(x) = isliteral(Number)(x) - -# are there nested ⋆ terms? -function isnotflat(⋆) - function (x) - args = arguments(x) - for t in args - if istree(t) && operation(t) === (⋆) - return true - end - end - return false - end -end - -function hasrepeats(x) - length(x) <= 1 && return false - for i in 1:(length(x) - 1) - if isequal(x[i], x[i + 1]) - return true - end - end - return false -end - -function merge_repeats(merge, xs) - length(xs) <= 1 && return false - merged = Any[] - i = 1 - - while i <= length(xs) - l = 1 - for j in (i + 1):length(xs) - if isequal(xs[i], xs[j]) - l += 1 - else - break - end - end - if l > 1 - push!(merged, merge(xs[i], l)) - else - push!(merged, xs[i]) - end - i += l - end - return merged -end - -# Take a struct definition and make it be able to match in `@rule` -macro matchable(expr) - @assert expr.head == :struct - name = expr.args[2] - if name isa Expr - name.head === :(<:) && (name = name.args[1]) - name isa Expr && name.head === :curly && (name = name.args[1]) - end - fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args) - get_name(s::Symbol) = s - get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1]) - fields = map(get_name, fields) - quote - $expr - TermInterface.istree(::$name) = true - TermInterface.operation(::$name) = $name - TermInterface.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) - TermInterface.arity(x::$name) = $(length(fields)) - Base.length(x::$name) = $(length(fields) + 1) - end |> esc -end - - -using TimerOutputs - const being_timed = Ref{Bool}(false) macro timer(name, expr) @@ -184,54 +31,36 @@ macro timer(name, expr) ) end -macro iftimer(expr) - esc(expr) -end - -function timerewrite(f) - reset_timer!() - being_timed[] = true - x = f() - being_timed[] = false - print_timer() - println() - x -end - -""" - @timerewrite expr - -If `expr` calls `simplify` or a `RuleSet` object, track the amount of time -it spent on applying each rule and pretty print the timing. - -This uses [TimerOutputs.jl](https://github.com/KristofferC/TimerOutputs.jl). - -## Example: +"Useful for debugging: prints the content of the e-graph match buffer in readable format." +function buffer_readable(g, limit, ematch_buffer) + k = length(ematch_buffer) + + while k > limit + delimiter = ematch_buffer.v[k] + @assert delimiter == 0xffffffffffffffffffffffffffffffff + n = k - 1 + + next_delimiter_idx = 0 + n_elems = 0 + for i in n:-1:1 + n_elems += 1 + if ematch_buffer.v[i] == 0xffffffffffffffffffffffffffffffff + n_elems -= 1 + next_delimiter_idx = i + break + end + end -```julia + match_info = ematch_buffer.v[next_delimiter_idx + 1] + id = v_pair_first(match_info) + rule_idx = reinterpret(Int, v_pair_last(match_info)) + rule_idx = abs(rule_idx) -julia> expr = foldr(*, rand([a,b,c,d], 100)) -(a ^ 26) * (b ^ 30) * (c ^ 16) * (d ^ 28) + bindings = @view ematch_buffer.v[(next_delimiter_idx + 2):n] -julia> @timerewrite simplify(expr) - ──────────────────────────────────────────────────────────────────────────────────────────────── - Time Allocations - ────────────────────── ─────────────────────── - Tot / % measured: 340ms / 15.3% 92.2MiB / 10.8% + print("$id E-Classes: ", map(x -> reinterpret(Int, v_pair_first(x)), bindings)) + print(" Nodes: ", map(x -> reinterpret(Int, v_pair_last(x)), bindings), "\n") - Section ncalls time %tot avg alloc %tot avg - ──────────────────────────────────────────────────────────────────────────────────────────────── - Rule((~y) ^ ~n * ~y => (~y) ^ (~n ... 667 11.1ms 21.3% 16.7μs 2.66MiB 26.8% 4.08KiB - RHS 92 277μs 0.53% 3.01μs 14.4KiB 0.14% 160B - Rule((~x) ^ ~n * (~x) ^ ~m => (~x)... 575 7.63ms 14.6% 13.3μs 1.83MiB 18.4% 3.26KiB - (*)(~(~(x::!issortedₑ))) => sort_arg... 831 6.31ms 12.1% 7.59μs 738KiB 7.26% 910B - RHS 164 3.03ms 5.81% 18.5μs 250KiB 2.46% 1.52KiB - ... - ... - ──────────────────────────────────────────────────────────────────────────────────────────────── -(a ^ 26) * (b ^ 30) * (c ^ 16) * (d ^ 28) -``` -""" -macro timerewrite(expr) - :(timerewrite(() -> $(esc(expr)))) -end + k = next_delimiter_idx + end +end \ No newline at end of file diff --git a/src/vecexpr.jl b/src/vecexpr.jl new file mode 100644 index 00000000..c18059f9 --- /dev/null +++ b/src/vecexpr.jl @@ -0,0 +1,124 @@ +module VecExprModule + +export Id, + VecExpr, + VECEXPR_FLAG_ISTREE, + VECEXPR_FLAG_ISCALL, + VECEXPR_META_LENGTH, + v_new, + v_flags, + v_unset_flags!, + v_check_flags, + v_set_flag!, + v_isexpr, + v_iscall, + v_head, + v_set_head!, + v_children, + v_children_range, + v_arity, + v_hash!, + v_hash, + v_unset_hash!, + v_signature, + v_set_signature!, + v_pair, + v_pair_first, + v_pair_last + +const Id = UInt64 + +""" + struct VecExpr + data::Vector{Id} + end + +An e-node is represented by `Vector{Id}` where: +* Position 1 stores the hash of the `VecExpr`. +* Position 2 stores the bit flags (`isexpr` or `iscall`). +* Position 3 stores the signature +* Position 4 stores the hash of the `head` (if `isexpr`) or node value in the e-graph constants. +* The rest of the positions store the e-class ids of the children nodes. + +The expression is represented as an array of integers to improve performance. +The hash value for the VecExpr is cached in the first position for faster lookup performance in dictionaries. +""" +struct VecExpr + data::Vector{Id} +end + +const VECEXPR_FLAG_ISTREE = 0x01 +const VECEXPR_FLAG_ISCALL = 0x10 +const VECEXPR_META_LENGTH = 4 + +@inline v_flags(n::VecExpr)::Id = @inbounds n.data[2] +@inline v_unset_flags!(n::VecExpr) = @inbounds (n.data[2] = 0) +@inline v_check_flags(n::VecExpr, flag::Id)::Bool = !iszero(v_flags(n) & flags) +@inline v_set_flag!(n::VecExpr, flag)::Id = @inbounds (n.data[2] = n.data[2] | flag) + +"""Returns `true` if the e-node ID points to a an expression tree.""" +@inline v_isexpr(n::VecExpr)::Bool = !iszero(v_flags(n) & VECEXPR_FLAG_ISTREE) + +"""Returns `true` if the e-node ID points to a function call.""" +@inline v_iscall(n::VecExpr)::Bool = !iszero(v_flags(n) & VECEXPR_FLAG_ISCALL) + +"""Number of children in the e-node.""" +@inline v_arity(n::VecExpr)::Int = length(n.data) - VECEXPR_META_LENGTH + +""" +Compute the hash of a `VecExpr` and store it as the first element. +""" +@inline function v_hash!(n::VecExpr)::Id + if iszero(n.data[1]) + n.data[1] = hash(@view n.data[2:end]) + else + # h = hash(@view n[2:end]) + # @assert h == n[1] + n.data[1] + end +end + +"""The hash of the e-node.""" +@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1] +Base.hash(n::VecExpr, h::UInt) = hash(v_hash(n), h) # IdKey not necessary here +Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end]) + +"""Set e-node hash to zero.""" +@inline v_unset_hash!(n::VecExpr)::Id = @inbounds (n.data[1] = Id(0)) + +"""E-class IDs of the children of the e-node.""" +@inline v_children(n::VecExpr) = @view n.data[(VECEXPR_META_LENGTH + 1):end] + +@inline v_signature(n::VecExpr)::Id = @inbounds n.data[3] + +@inline v_set_signature!(n::VecExpr, sig::Id) = @inbounds (n.data[3] = sig) + +"The constant ID of the operation of the e-node, or the e-node ." +@inline v_head(n::VecExpr)::Id = @inbounds n.data[VECEXPR_META_LENGTH] + +"Update the E-Node operation ID." +@inline v_set_head!(n::VecExpr, h::Id) = @inbounds (n.data[VECEXPR_META_LENGTH] = h) + +"""Construct a new, empty `VecExpr` with `len` children.""" +@inline function v_new(len::Int)::VecExpr + n = VecExpr(Vector{Id}(undef, len + VECEXPR_META_LENGTH)) + v_unset_hash!(n) + v_unset_flags!(n) + n +end + +@inline v_children_range(n::VecExpr) = ((VECEXPR_META_LENGTH + 1):length(n.data)) + + +v_pair(a::UInt64, b::UInt64) = UInt128(a) << 64 | b +v_pair_first(p::UInt128)::UInt64 = UInt64(p >> 64) +v_pair_last(p::UInt128)::UInt64 = UInt64(p & 0xffffffffffffffff) + +@inline Base.length(n::VecExpr) = length(n.data) +@inline Base.getindex(n::VecExpr, i) = n.data[i] +@inline Base.setindex!(n::VecExpr, val, i) = n.data[i] = val +@inline Base.copy(n::VecExpr) = VecExpr(copy(n.data)) +@inline Base.lastindex(n::VecExpr) = lastindex(n.data) +@inline Base.firstindex(n::VecExpr) = firstindex(n.data) + +end diff --git a/test/classic/reductions.jl b/test/classic/reductions.jl index 1ceab4d6..919609af 100644 --- a/test/classic/reductions.jl +++ b/test/classic/reductions.jl @@ -1,4 +1,5 @@ -using Metatheory +using Metatheory, TermInterface +using Test @testset "Reduction Basics" begin t = @theory begin @@ -110,29 +111,142 @@ end @test df == "doesnt_fly" end -@testset "Segment Variables" begin - t = @theory begin - f(~x, ~~y) => Expr(:call, :ok, (~~y)...) - end - sf = rewrite(:(f(1, 2, 3, 4)), t) - @test sf == :(ok(2, 3, 4)) +@testset "New compiled pattern matcher" begin + r = @rule f(1, 2) --> ok() + @test isnothing(r(:(f(1, 2, 3)))) + @test r(:(f(1, 2))) == :(ok()) +end - t = @theory x y begin - f(x, y...) => Expr(:call, :ok, y...) - end - sf = rewrite(:(f(1, 2, 3, 4)), t) +@testset "PatSegment as tail" begin + r = @rule f(~x, ~~y) => Expr(:call, :ok, (~~y)...) + sf = r(:(f(1, 2, 3, 4))) @test sf == :(ok(2, 3, 4)) - t = @theory x y begin - f(x, y...) --> ok(y...) - end - sf = rewrite(:(f(1, 2, 3, 4)), t) - @test sf == :(ok(2, 3, 4)) + r = @rule x y f(x, 2, y...) => Expr(:call, :ok, y...) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok(3, 4)) + + sf = r(:(f(1, 2, 3))) + @test sf == :(ok(3)) + + # Empty vector + r = @rule x y f(x, 2, 3, 4, y...) --> ok(y...) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok()) + + # Entire vector + r = @rule x f(x...) --> ok(x...) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok(1, 2, 3, 4)) + + # Nested inside + r = @rule x y g(1, f(x, 2, y...), 3) => Expr(:call, :ok, x, y...) + sf = r(:(g(1, f(1, 2, 3, 4), 3))) + @test sf == :(ok(1, 3, 4)) + + sf = r(:(g(1, f(1, 2, 3), 3))) + @test sf == :(ok(1, 3)) + + sf = r(:(g(1, f(1, 2, 3, h(4, 5), 6), 3))) + @test sf == :(ok(1, 3, h(4, 5), 6)) +end + +@testset "PatSegment as head" begin + r = @rule f(~~x, ~y) => Expr(:call, :ok, (~~x)...) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok(1, 2, 3)) + + r = @rule x y f(x..., 3, 4) => Expr(:call, :ok, x...) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok(1, 2)) + + # Single element + r = @rule x y f(x, 2, 3, y...) --> ok(y...) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok(4)) + + # Empty vector + r = @rule x y f(x, 2, 3, 4, y...) --> ok(y...) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok()) +end + +@testset "Multiple PatSegments" begin + r = @rule f(~~x, ~~y) --> ok(~~x, yeah(~~y)) + sf = r(:(f(1, 2, 3, 4))) + @test sf == :(ok(1, 2, 3, 4, yeah())) + + r = @rule f(~~x, 3, ~~y) --> ok(~~x, yeah(~~y)) + sf = r(:(f(1, 2, 3, 4, 5))) + @test sf == :(ok(1, 2, yeah(4, 5))) + + r = @rule f(~~x, 3, ~~y, 5, ~~z) --> ok(~~x, yeah(~~y), ~~z) + sf = r(:(f(1, 2, 3, 4, 5, 6))) + @test sf == :(ok(1, 2, yeah(4), 6)) + + r = @rule f(~~x, 3, ~~y, 5, ~~z, 7) --> ok(~~x, yeah(~~y), ~~z) + sf = r(:(f(1, 2, 2, 3, 4, 4, 5, 6, 7, 7))) + @test sf == :(ok(1, 2, 2, yeah(4, 4), 6, 7)) end +@testset "Multiple Repeated PatSegments" begin + r = @rule f(~~x, ~~x, 4) --> ok(~~x) + sf = r(:(f(1, 2, 1, 2, 4))) + @test sf == :(ok(1, 2)) + + sf = r(:(f(1, 2, 3, 4))) + @test isnothing(sf) + + sf = r(:(f(4))) + @test sf == :(ok()) + + + r = @rule f(~~x, ~~x) --> ok(~~x) + sf = r(:(f(1, 2, 1, 2))) + @test sf == :(ok(1, 2)) + + sf = r(:(f(1, 2, 3, 4))) + @test isnothing(sf) + + r = @rule f(~~x, 3, ~~x) --> ok(~~x) + sf = r(:(f(1, 2, 3, 1, 2))) + @test sf == :(ok(1, 2)) + + r = @rule f(~~x, 3, ~~x) --> ok(~~x) + sf = r(:(f(3))) + @test sf == :(ok()) + + sf = r(:(f(1, 2, 3, 4, 5))) + @test isnothing(sf) + + # Appears 3 times, doesn't work because of `offset_so_far` not counting how many times + # a variable appears + r = @rule f(~~x, 3, ~~x, 5, ~~x) --> ok(~~x) + sf = r(:(f(1, 2, 3, 1, 2, 5, 1, 2))) + @test sf == :(ok(1, 2)) + + sf = r(:(f(1, 2, 3, 3, 1, 2, 5, 1, 2))) + @test isnothing(sf) + + + r = @rule f(~~x, 3, ~~y, 5, ~~x, ~~z, 7, ~~y) --> ok(~~x, yeah(~~y), ~~z) + sf = r(:(f(1, 2, 2, 3, 4, 4, 5, 1, 2, 2, 6, 7, 7, 4, 4))) + @test sf == :(ok(1, 2, 2, yeah(4, 4), 6, 7)) +end + +@testset "Correctly checking bounds" begin + expr = :(-a - b) + r = @rule a b c (a - b) - c --> a - (b + c) + @test isnothing(r(expr)) + + + expr = :(f(g(a, a), b)) + r = @rule a b c f(g(a), c) --> a + @test isnothing(r(expr)) +end module NonCall -using Metatheory +using Metatheory, TermInterface t = [@rule a b (a, b) --> ok(a, b)] test() = rewrite(:(x, y), t) @@ -144,14 +258,13 @@ end @testset "Pattern matcher can match on both function object references and name symbols" begin - ex = :($(+)($(sin)(x)^2, $(cos)(x)^2)) r = @rule(sin(~x)^2 + cos(~x)^2 --> 1) + ex = :($(+)($(sin)(x)^2, $(cos)(x)^2)) @test r(ex) == 1 end - @testset "Pattern variable as pattern term head" begin foo(x) = x + 2 ex = :(($foo)(bar, 2, pazz)) @@ -160,7 +273,6 @@ end @test r(ex) == 4 end -using TermInterface using Metatheory.Syntax: @capture @testset "Capture form" begin @@ -168,7 +280,7 @@ using Metatheory.Syntax: @capture #note that @test inserts a soft local scope (try-catch) that would gobble #the matches from assignment statements in @capture macro, so we call it - #outside the test macro + #outside the test macro ret = @capture ex (~x)^(~x) @test ret @test @isdefined x @@ -195,26 +307,74 @@ using Metatheory.Syntax: @capture @test isnothing(f(:(b + b))) x = 1 - r = (@capture x x) + r = (@capture x ~x) @test r == true end +module QuxTest +using Metatheory, Test, TermInterface +struct Qux + args + Qux(args...) = new(args) +end +TermInterface.iscall(::Qux) = true +TermInterface.isexpr(::Qux) = true +TermInterface.head(::Qux) = Qux +TermInterface.operation(::Qux) = Qux +TermInterface.children(x::Qux) = [x.args...] +TermInterface.arguments(x::Qux) = [x.args...] + +function test() + @test (@rule Qux(1, 2) => "hello")(Qux(1, 2)) == "hello" + @test (@rule Qux(1, 2) => "hello")(Qux(3, 4)) === nothing + @test (@rule Qux(1, 2) => "hello")(1) === nothing + @test (@rule 1 => "hello")(1) == "hello" + @test (@rule 1 => "hello")(Qux(1, 2)) === nothing + @test (@capture Qux(1, 2) Qux(1, 2)) + @test false == (@capture Qux(1, 2) Qux(3, 4)) +end +end -using TermInterface -@testset "Matchable struct" begin - struct qux - args - qux(args...) = new(args) - end - TermInterface.operation(::qux) = qux - TermInterface.istree(::qux) = true - TermInterface.arguments(x::qux) = [x.args...] - @capture qux(1, 2) qux(1, 2) +module LuxTest +using Metatheory, Test, TermInterface +using Metatheory: @matchable + +@matchable struct Lux + a + b +end - @test (@rule qux(1, 2) => "hello")(qux(1, 2)) == "hello" - @test (@rule qux(1, 2) => "hello")(1) === nothing +function test() + @test (@rule Lux(1, 2) => "hello")(Lux(1, 2)) == "hello" + @test (@rule Qux(1, 2) => "hello")(Lux(3, 4)) === nothing + @test (@rule Qux(1, 2) => "hello")(1) === nothing @test (@rule 1 => "hello")(1) == "hello" - @test (@rule 1 => "hello")(qux(1, 2)) === nothing - @test (@capture qux(1, 2) qux(1, 2)) - @test false == (@capture qux(1, 2) qux(3, 4)) + @test (@rule 1 => "hello")(Lux(1, 2)) === nothing + @test (@capture Lux(1, 2) Lux(1, 2)) + @test false == (@capture Lux(1, 2) Lux(3, 4)) +end end + +@testset "Matchable struct" begin + QuxTest.test() + LuxTest.test() +end + + +## Parametric Data Types. TODO: the pattern matcher should support type parameters +@testset "Parametric Data Types are valid pattern operations" begin + abstract type Dim end + + @matchable struct Lit <: Dim + value::Int64 + end + + @matchable struct Plus{T<:Dim,U<:Dim} <: Dim + dim1::T + dim2::U + end + + r = @rule Plus(Lit(0), ~dim1) --> ~dim1 + + @test r(Plus(Lit(0), Lit(1))) == Lit(1) +end \ No newline at end of file diff --git a/test/egraphs/analysis.jl b/test/egraphs/analysis.jl index 7a8ae892..092031f6 100644 --- a/test/egraphs/analysis.jl +++ b/test/egraphs/analysis.jl @@ -4,54 +4,48 @@ using Metatheory using Metatheory.Library -using TermInterface -EGraphs.make(::Val{:numberfold}, g::EGraph, n::ENodeLiteral) = n.value +include("../../examples/prove.jl") +struct NumberFoldAnalysis + n::Number +end -# This should be auto-generated by a macro -function EGraphs.make(::Val{:numberfold}, g::EGraph, n::ENodeTerm) - if exprhead(n) == :call && arity(n) == 2 - op = operation(n) - args = arguments(n) - l = g[args[1]] - r = g[args[2]] - ldata = getdata(l, :numberfold, nothing) - rdata = getdata(r, :numberfold, nothing) +Base.:(*)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n * b.n) +Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n + b.n) - if ldata isa Number && rdata isa Number - if op == :* - return ldata * rdata - elseif op == :+ - return ldata + rdata +# This should be auto-generated by a macro +function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr) where {ExpressionType} + h = get_constant(g, v_head(n)) + v_isexpr(n) || return h isa Number ? NumberFoldAnalysis(h) : nothing + if v_iscall(n) && v_arity(n) == 2 + args = v_children(n) + l, r = g[args[1]], g[args[2]] + + if l.data isa NumberFoldAnalysis && r.data isa NumberFoldAnalysis + if h == :* + return l.data * r.data + elseif h == :+ + return l.data + r.data end end end - - return nothing + # Could not analyze, returns nothing end -function EGraphs.join(an::Val{:numberfold}, from, to) - if from isa Number - if to isa Number - @assert from == to - else - return from - end - end - return to +function EGraphs.join(from::NumberFoldAnalysis, to::NumberFoldAnalysis) + @assert from == to + from end -function EGraphs.modify!(::Val{:numberfold}, g::EGraph, id::Int64) - eclass = g.classes[id] - d = getdata(eclass, :numberfold, nothing) - if d isa Number - merge!(g, addexpr!(g, d), id) - end +# Add the number to the eclass. +function EGraphs.modify!( + g::EGraph{ExpressionType,NumberFoldAnalysis}, + eclass::EClass{NumberFoldAnalysis}, +) where {ExpressionType} + isnothing(eclass.data) || union!(g, addexpr!(g, eclass.data.n), find(g, eclass.id)) end -EGraphs.islazy(::Val{:numberfold}) = false - comm_monoid = @theory begin ~a * ~b --> ~b * ~a @@ -59,41 +53,35 @@ comm_monoid = @theory begin ~a * (~b * ~c) --> (~a * ~b) * ~c end -G = EGraph(:(3 * 4)) -analyze!(G, :numberfold) +g = EGraph{Expr,NumberFoldAnalysis}(:(3 * 4)) -# exit(0) @testset "Basic Constant Folding Example - Commutative Monoid" begin - @test (true == @areequalg G comm_monoid 3 * 4 12) - - @test (true == @areequalg G comm_monoid 3 * 4 12 4 * 3 6 * 2) + @test test_equality(comm_monoid, :(3 * 4), 12; g) + @test test_equality(comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2); g) end @testset "Basic Constant Folding Example 2 - Commutative Monoid" begin ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, :numberfold) - addexpr!(G, :(12 * a)) - @test (true == @areequalg G comm_monoid (12 * a) * b ((6 * 2) * b) * a) - @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) + g = EGraph{Expr,NumberFoldAnalysis}(ex) + addexpr!(g, :(12 * a)) + @test test_equality(comm_monoid, :((12 * a) * b), :(((6 * 2) * b) * a); g) + @test test_equality(comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), :(((6 * 2) * b) * a); g) end @testset "Basic Constant Folding Example - Adding analysis after saturation" begin - G = EGraph(:(3 * 4)) - # addexpr!(G, 12) - saturate!(G, comm_monoid) - addexpr!(G, :(a * 2)) - analyze!(G, :numberfold) - saturate!(G, comm_monoid) + g = EGraph{Expr,NumberFoldAnalysis}(:(3 * 4)) + # addexpr!(g, 12) + saturate!(g, comm_monoid) + addexpr!(g, :(a * 2)) + saturate!(g, comm_monoid) - @test (true == areequal(G, comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2))) + @test test_equality(comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2); g) ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, :numberfold) + g = EGraph{Expr,NumberFoldAnalysis}(ex) params = SaturationParams(timeout = 15) - @test areequal(G, comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), :(((6 * 2) * b) * a); params = params) + @test test_equality(comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), :(((6 * 2) * b) * a); params, g) end @testset "Infinite Loops analysis" begin @@ -102,10 +90,10 @@ end end - G = EGraph(:(1 * x)) + g = EGraph(:(1 * x)) params = SaturationParams(timeout = 100) - saturate!(G, boson, params) - ex = extract!(G, astsize) + saturate!(g, boson, params) + ex = extract!(g, astsize) boson = @theory begin @@ -123,229 +111,15 @@ end end -@testset "Extraction" begin - comm_monoid = @commutative_monoid (*) 1 - - fold_mul = @theory begin - ~a::Number * ~b::Number => ~a * ~b - end - - t = comm_monoid ∪ fold_mul - - - @testset "Extraction 1 - Commutative Monoid" begin - G = EGraph(:(3 * 4)) - saturate!(G, t) - @test (12 == extract!(G, astsize)) - - ex = :(a * 3 * b * 4) - G = EGraph(ex) - params = SaturationParams(timeout = 15) - saturate!(G, t, params) - extr = extract!(G, astsize) - @test extr == :((12 * a) * b) || - extr == :(12 * (a * b)) || - extr == :(a * (b * 12)) || - extr == :((a * b) * 12) || - extr == :((12a) * b) || - extr == :(a * (12b)) || - extr == :((b * (12a))) || - extr == :((b * 12) * a) || - extr == :((b * a) * 12) || - extr == :(b * (a * 12)) || - extr == :((12b) * a) - end - - fold_add = @theory begin - ~a::Number + ~b::Number => ~a + ~b - end - - @testset "Extraction 2" begin - comm_group = @commutative_group (+) 0 inv - - - t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add - - # for i ∈ 1:20 - # sleep(0.3) - ex = :((x * (a + b)) + (y * (a + b))) - G = EGraph(ex) - saturate!(G, t) - # end - - extract!(G, astsize) == :((y + x) * (b + a)) - end - - @testset "Extraction - Adding analysis after saturation" begin - G = EGraph(:(3 * 4)) - addexpr!(G, 12) - saturate!(G, t) - addexpr!(G, :(a * 2)) - saturate!(G, t) - - saturate!(G, t) - - @test (12 == extract!(G, astsize)) - - # for i ∈ 1:100 - ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, :numberfold) - params = SaturationParams(timeout = 15) - saturate!(G, comm_monoid, params) - - extr = extract!(G, astsize) - - @test extr == :((12 * a) * b) || - extr == :(12 * (a * b)) || - extr == :(a * (b * 12)) || - extr == :((a * b) * 12) || - extr == :((12a) * b) || - extr == :(a * (12b)) || - extr == :((b * (12a))) || - extr == :((b * 12) * a) || - extr == :((b * a) * 12) || - extr == :(b * (a * 12)) - end - - - comm_monoid = @commutative_monoid (*) 1 - - comm_group = @commutative_group (+) 0 inv - - powers = @theory begin - ~a * ~a → (~a)^2 - ~a → (~a)^1 - (~a)^~n * (~a)^~m → (~a)^(~n + ~m) - end - logids = @theory begin - log((~a)^~n) --> ~n * log(~a) - log(~x * ~y) --> log(~x) + log(~y) - log(1) --> 0 - log(:e) --> 1 - :e^(log(~x)) --> ~x - end - - G = EGraph(:(log(e))) - params = SaturationParams(timeout = 9) - saturate!(G, logids, params) - @test extract!(G, astsize) == 1 - - t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add - @testset "Complex Extraction" begin - G = EGraph(:(log(e) * log(e))) - params = SaturationParams(timeout = 9) - saturate!(G, t, params) - @test extract!(G, astsize) == 1 +@testset "Conditional Dynamic Rule" begin + g = EGraph{Expr,NumberFoldAnalysis}() - G = EGraph(:(log(e) * (log(e) * e^(log(3))))) - params = SaturationParams(timeout = 7) - saturate!(G, t, params) - @test extract!(G, astsize) == 3 - - - G = EGraph(:(a^3 * a^2)) - saturate!(G, t) - ex = extract!(G, astsize) - @test ex == :(a^5) - - G = EGraph(:(a^3 * a^2)) - saturate!(G, t) - ex = extract!(G, astsize) - @test ex == :(a^5) - - function cust_astsize(n::ENodeTerm, g::EGraph) - cost = 1 + arity(n) - - if operation(n) == :^ - cost += 2 - end - - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, cust_astsize) && (cost += Inf; break) - cost += last(getdata(eclass, cust_astsize)) - end - return cost - end - - - cust_astsize(n::ENodeLiteral, g::EGraph) = 1 - - G = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2)))) - saturate!(G, t) - ex = extract!(G, cust_astsize) - @test ex == :(5 * log(a)) || ex == :(log(a) * 5) - end - - function costfun(n::ENodeTerm, g::EGraph) - arity(n) != 2 && (return 1) - left = arguments(n)[1] - left_class = g[left] - ENodeLiteral(:a) ∈ left_class.nodes ? 1 : 100 - end - - costfun(n::ENodeLiteral, g::EGraph) = 1 - - - moveright = @theory begin - (:b * (:a * ~c)) --> (:a * (:b * ~c)) + theo_dyn_cond = @theory a begin + a => !isnothing(a.data) ? a.data.n : nothing # awkward rule to trigger a certain branch in saturation.jl end - expr = :(a * (a * (b * (a * b)))) - res = rewrite(expr, moveright) - - g = EGraph(expr) - saturate!(g, moveright) - resg = extract!(g, costfun) - - @testset "Symbols in Right hand" begin - @test resg == res == :(a * (a * (a * (b * b)))) - end - - function ⋅ end - co = @theory begin - sum(~x ⋅ :bazoo ⋅ :woo) --> sum(:n * ~x) - end - @testset "Consistency with classical backend" begin - ex = :(sum(wa(rio) ⋅ bazoo ⋅ woo)) - g = EGraph(ex) - saturate!(g, co) - - res = extract!(g, astsize) - - resclassic = rewrite(ex, co) - - @test res == resclassic - end - - - @testset "No arguments" begin - ex = :(f()) - g = EGraph(ex) - @test :(f()) == extract!(g, astsize) - - ex = :(sin() + cos()) - - t = @theory begin - sin() + cos() --> tan() - end - - gg = EGraph(ex) - saturate!(gg, t) - res = extract!(gg, astsize) - - @test res == :(tan()) - end - - - @testset "Symbol or function object operators in expressions in EGraphs" begin - ex = :(($+)(x, y)) - t = [@rule a b a + b => 2] - g = EGraph(ex) - saturate!(g, t) - @test extract!(g, astsize) == 2 - end -end + @test !test_equality(theo_dyn_cond, :x, :y, :z; g) + @test !test_equality(theo_dyn_cond, 0, 1, 2; g) +end \ No newline at end of file diff --git a/test/egraphs/egraphs.jl b/test/egraphs/egraphs.jl index d58ad0bf..4bc4b51e 100644 --- a/test/egraphs/egraphs.jl +++ b/test/egraphs/egraphs.jl @@ -1,38 +1,34 @@ -# ENV["JULIA_DEBUG"] = Metatheory -using Metatheory -using Metatheory.EGraphs -using Metatheory.EGraphs: in_same_set, find_root +using Test, Metatheory @testset "Merging" begin testexpr = :((a * 2) / 2) testmatch = :(a << 1) - G = EGraph(testexpr) - t2 = addexpr!(G, testmatch) - merge!(G, t2, EClassId(3)) - @test in_same_set(G.uf, t2, EClassId(3)) == true - # DOES NOT UPWARD MERGE + g = EGraph(testexpr) + t2 = addexpr!(g, testmatch) + union!(g, t2, Id(3)) + @test find(g, t2) == find(g, Id(3)) end # testexpr = :(42a + b * (foo($(Dict(:x => 2)), 42))) @testset "Simple congruence - rebuilding" begin - G = EGraph() - ec1 = addexpr!(G, :(f(a, b))) - ec2 = addexpr!(G, :(f(a, c))) + g = EGraph() + ec1 = addexpr!(g, :(f(a, b))) + ec2 = addexpr!(g, :(f(a, c))) testexpr = :(f(a, b) + f(a, c)) - testec = addexpr!(G, testexpr) + testec = addexpr!(g, testexpr) - t1 = addexpr!(G, :b) - t2 = addexpr!(G, :c) + t1 = addexpr!(g, :b) + t2 = addexpr!(g, :c) - c_id = merge!(G, t2, t1) - @test in_same_set(G.uf, c_id, t1) - @test in_same_set(G.uf, t2, t1) - rebuild!(G) - @test in_same_set(G.uf, ec1, ec2) + union!(g, t2, t1) + @test find(g, t2) == find(g, t1) + @test find(g, t2) == find(g, t1) + rebuild!(g) + @test find(g, ec1) == find(g, ec2) end @@ -40,34 +36,36 @@ end apply(n, f, x) = n == 0 ? x : apply(n - 1, f, f(x)) f(x) = Expr(:call, :f, x) - G = EGraph(:a) + g = EGraph{Expr}(:a) - t1 = addexpr!(G, apply(6, f, :a)) - t2 = addexpr!(G, apply(9, f, :a)) + a = addexpr!(g, :a) - c_id = merge!(G, t1, EClassId(1)) # a == apply(6,f,a) - c2_id = merge!(G, t2, EClassId(1)) # a == apply(9,f,a) + t1 = addexpr!(g, apply(6, f, :a)) + t2 = addexpr!(g, apply(9, f, :a)) + c_id = union!(g, t1, a) # a == apply(6,f,a) + c2_id = union!(g, t2, a) # a == apply(9,f,a) - rebuild!(G) + rebuild!(g) + pretty_dict(g) - t3 = addexpr!(G, apply(3, f, :a)) - t4 = addexpr!(G, apply(7, f, :a)) + t3 = addexpr!(g, apply(3, f, :a)) + t4 = addexpr!(g, apply(7, f, :a)) # f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a - @test in_same_set(G.uf, t1, EClassId(1)) == true - @test in_same_set(G.uf, t2, EClassId(1)) == true - @test in_same_set(G.uf, t3, EClassId(1)) == true - @test in_same_set(G.uf, t4, EClassId(1)) == false + @test find(g, t1) == find(g, a) + @test find(g, t2) == find(g, a) + @test find(g, t3) == find(g, a) + @test find(g, t4) != find(g, a) # if m or n is prime, f(a) = a - t5 = addexpr!(G, apply(11, f, :a)) - t6 = addexpr!(G, apply(1, f, :a)) - c5_id = merge!(G, t5, EClassId(1)) # a == apply(11,f,a) + t5 = addexpr!(g, apply(11, f, :a)) + t6 = addexpr!(g, apply(1, f, :a)) + c5_id = union!(g, t5, a) # a == apply(11,f,a) - rebuild!(G) + rebuild!(g) - @test in_same_set(G.uf, t5, EClassId(1)) == true - @test in_same_set(G.uf, t6, EClassId(1)) == true + @test find(g, t5) == find(g, a) + @test find(g, t6) == find(g, a) end diff --git a/test/egraphs/ematch.jl b/test/egraphs/ematch.jl index 72a6e58f..414c7439 100644 --- a/test/egraphs/ematch.jl +++ b/test/egraphs/ematch.jl @@ -1,77 +1,202 @@ using Metatheory +using Metatheory: OptBuffer using Test using Metatheory.Library -falseormissing(x) = x === missing || !x +# Simple E-Matching + +include("../../examples/prove.jl") + +b = OptBuffer{UInt128}(10) + +@testset "Simple Literal" begin + r = @rule 2 --> true + g = EGraph(2) + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 +end + +@testset "Composite Ground Terms" begin + r = @rule f(2, 3) --> true + g = EGraph(:(f(2, 3))) + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + + g = EGraph(:(f(2, 4))) + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + + + r = @rule f(2, h(3, 4)) --> true + g = EGraph(:(f(2, h(3, 4)))) + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 +end + +@testset "Pattern Variables" begin + g = EGraph(:(f(2, 1))) + r = @rule ~a --> true + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 1 + @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 1 +end + +@testset "Type Assertions" begin + r = @rule ~a::Int --> true + g = EGraph(:(f(2, 1))) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + + g = EGraph(:3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + new_id = addexpr!(g, :f) + union!(g, g.root, new_id) + + new_id = addexpr!(g, 4) + union!(g, g.root, new_id) + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 2 +end + +@testset "Predicate Assertions" begin + r = @rule ~a::iseven --> true + Base.iseven(g, ec::EClass) = + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Number && iseven(c) + end + false + end + + g = EGraph(:(f(2, 1))) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + + g = EGraph(:2) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + g = EGraph(:3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + + new_id = addexpr!(g, :f) + union!(g, g.root, new_id) + + new_id = addexpr!(g, 4) + union!(g, g.root, new_id) + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 +end + + +@testset "Non-Ground Terms" begin + g = EGraph(:(f(2, 1))) + r = @rule f(2, ~a) --> true + + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + + r = @rule f(~a, ~a) --> true + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + + g = EGraph(:(f(2, 2))) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + g = EGraph(:(f(h(3, 4), h(3, 4)))) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 +end + r = @theory begin - max(~x, ~y) → 2 * ~x % ~y - max(~x, ~y) → sin(~x) - sin(~x) → max(~x, ~x) + max(~x, ~y) --> 2 * ~x % ~y + max(~x, ~y) --> sin(~x) + sin(~x) --> max(~x, ~x) end @testset "Basic Equalities 1" begin - @test (@areequal r max(b, c) max(d, d)) == false + g = EGraph(:(max(b, c))) + + t2 = addexpr!(g, :(max(d, d))) + saturate!(g, r) + + t1 = addexpr!(g, :(max(b, c))) + + @test !in_same_class(g, t1, t2) end r = @theory begin - ~a * 1 → :foo - ~a * 2 → :bar - 1 * ~a → :baz - 2 * ~a → :mag + ~a * 1 --> :foo + ~a * 2 --> :bar + 1 * ~a --> :baz + 2 * ~a --> :mag end @testset "Matching Literals" begin - g = EGraph(:(a * 1)) - addexpr!(g, :foo) + g = EGraph() + ec_1 = addexpr!(g, :(a * 1)) + ec_2 = addexpr!(g, :(a * 2)) + ec_1r = addexpr!(g, :(1 * a)) + ec_2r = addexpr!(g, :(2 * a)) + ec_foo = addexpr!(g, :foo) + ec_bar = addexpr!(g, :bar) + ec_baz = addexpr!(g, :baz) + ec_mag = addexpr!(g, :mag) + saturate!(g, r) - @test (@areequal r a * 1 foo) == true - @test (@areequal r a * 2 foo) == false - @test (@areequal r a * 1 bar) == false - @test (@areequal r a * 2 bar) == true + @test in_same_class(g, ec_1, ec_foo) + @test !in_same_class(g, ec_2, ec_foo) + @test !in_same_class(g, ec_1, ec_bar) + @test in_same_class(g, ec_2, ec_bar) - @test (@areequal r 1 * a baz) == true - @test (@areequal r 2 * a baz) == false - @test (@areequal r 1 * a mag) == false - @test (@areequal r 2 * a mag) == true + @test in_same_class(g, ec_1r, ec_baz) + @test in_same_class(g, ec_2r, ec_mag) + @test !in_same_class(g, ec_2r, ec_baz) + @test !in_same_class(g, ec_1r, ec_mag) end comm_monoid = @commutative_monoid (*) 1 + @testset "Basic Equalities - Commutative Monoid" begin - @test true == (@areequal comm_monoid a * (c * (1 * d)) c * (1 * (d * a))) - @test true == (@areequal comm_monoid x * y y * x) - @test true == (@areequal comm_monoid (x * x) * (x * 1) x * (x * x)) + @test test_equality(comm_monoid, :(a * (c * (1 * d))), :(c * (1 * (d * a)))) + @test test_equality(comm_monoid, :(x * y), :(y * x)) + @test test_equality(comm_monoid, :((x * x) * (x * 1)), :(x * (x * x))) end comm_group = @commutative_group (+) 0 inv t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) - @testset "Basic Equalities - Comm. Monoid, Abelian Group, Distributivity" begin - @test true == (@areequal t (a * b) + (a * c) a * (b + c)) - @test true == (@areequal t a * (c * (1 * d)) c * (1 * (d * a))) - @test true == (@areequal t a + (b * (c * d)) ((d * c) * b) + a) - @test true == (@areequal t (x + y) * (a + b) ((a * (x + y)) + b * (x + y)) ((x * (a + b)) + y * (a + b))) - @test true == (@areequal t (((x * a + x * b) + y * a) + y * b) (x + y) * (a + b)) - @test true == (@areequal t a + (b * (c * d)) ((d * c) * b) + a) - @test true == (@areequal t a + inv(a) 0 (x * y) + inv(x * y) 1 * 0) + @test test_equality(t, :((a * b) + (a * c)), :(a * (b + c))) + @test test_equality(t, :(a * (c * (1 * d))), :(c * (1 * (d * a)))) + @test test_equality(t, :(a + (b * (c * d))), :(((d * c) * b) + a)) + @test test_equality(t, :((x + y) * (a + b)), :((a * (x + y)) + b * (x + y)), :((x * (a + b)) + y * (a + b))) + @test test_equality(t, :(((x * a + x * b) + y * a) + y * b), :((x + y) * (a + b))) + @test test_equality(t, :(a + (b * (c * d))), :(((d * c) * b) + a)) + @test test_equality(t, :(a + inv(a)), 0, :((x * y) + inv(x * y)), :(1 * 0)) end - @testset "Basic Equalities - False statements" begin - @test falseormissing(@areequal t (a * b) + (a * c) a * (b + a)) - @test falseormissing(@areequal t (a * c) + (a * c) a * (b + c)) - @test falseormissing(@areequal t a * (c * c) c * (1 * (d * a))) - @test falseormissing(@areequal t c + (b * (c * d)) ((d * c) * b) + a) - @test falseormissing(@areequal t (x + y) * (a + c) ((a * (x + y)) + b * (x + y))) - @test falseormissing(@areequal t ((x * (a + b)) + y * (a + b)) (x + y) * (a + c)) - @test falseormissing(@areequal t (((x * a + x * b) + y * a) + y * b) (x + y) * (a + x)) - @test falseormissing(@areequal t a + (b * (c * a)) ((d * c) * b) + a) - @test falseormissing(@areequal t a + inv(a) a) - @test falseormissing(@areequal t (x * y) + inv(x * y) 1) + @test !test_equality(t, :((a * b) + (a * c)), :(a * (b + a))) + @test !test_equality(t, :((a * c) + (a * c)), :(a * (b + c))) + @test !test_equality(t, :(a * (c * c)), :(c * (1 * (d * a)))) + @test !test_equality(t, :(c + (b * (c * d))), :(((d * c) * b) + a)) + @test !test_equality(t, :((x + y) * (a + c)), :((a * (x + y)) + b * (x + y))) + @test !test_equality(t, :((x * (a + b)) + y * (a + b)), :((x + y) * (a + c))) + @test !test_equality(t, :(((x * a + x * b) + y * a) + y * b), :((x + y) * (a + x))) + @test !test_equality(t, :(a + (b * (c * a))), :(((d * c) * b) + a)) + @test !test_equality(t, :(a + inv(a)), :a) + @test !test_equality(t, :((x * y) + inv(x * y)), 1) end # Issue 21 @@ -83,34 +208,30 @@ saturate!(g, simp_theory) @test extract!(g, astsize) == :foo module Bar -foo = 42 -export foo +var = :bar using Metatheory t = @theory begin - :woo => foo + woo(:foo) => var end -export t end module Foo -foo = 12 +var = :foo using Metatheory t = @theory begin - :woo => foo + woo(:foo) => var end -export t end -g = EGraph(:woo); +g = EGraph{Expr}(:(woo(foo))); saturate!(g, Bar.t); saturate!(g, Foo.t); -foo = 12 @testset "Different modules" begin - @test @areequalg g t 42 12 + @test in_same_class(g, addexpr!(g, :foo), addexpr!(g, :bar)) end @@ -123,15 +244,17 @@ end G = EGraph(:(3 * 4)) @testset "Basic Constant Folding Example - Commutative Monoid" begin - @test (true == @areequalg G comm_monoid 3 * 4 12) - @test (true == @areequalg G comm_monoid 3 * 4 12 4 * 3 6 * 2) + @test test_equality(comm_monoid, :(3 * 4), 12) + @test test_equality(comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2)) end @testset "Basic Constant Folding Example 2 - Commutative Monoid" begin ex = :(a * 3 * b * 4) - G = EGraph(ex) - @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) + g = EGraph(ex) + ids = [addexpr!(g, e) for e in (:((3a) * (4b)), :((12a) * b), :(((6 * 2) * b) * a))] + saturate!(g, comm_monoid) + @test in_same_class(g, ids...) end @testset "Type Assertions in Ematcher" begin @@ -145,33 +268,52 @@ end g = EGraph(:(2 * 3)) saturate!(g, some_theory) - @test true == areequal(g, some_theory, :(2 * 3), :(sin(2, 3))) - @test true == areequal(g, some_theory, :(sin(2, 3)), :(cos(3, 2))) + @test test_equality(some_theory, :(2 * 3), :(sin(2, 3)); g) + @test test_equality(some_theory, :(sin(2, 3)), :(cos(3, 2)); g) end -Base.iszero(ec::EClass) = ENodeLiteral(0) ∈ ec @testset "Predicates in Ematcher" begin + g = EGraph(:(2 * 3)) + zero_id = addexpr!(g, 0) + some_theory = @theory begin ~a::iszero * ~b --> 0 ~a * ~b --> ~b * ~a end + Base.iszero(g::EGraph, ec::EClass) = in_same_class(g, zero_id, ec.id) + + saturate!(g, some_theory) + + @test test_equality(some_theory, :(a * b * 0), 0) +end + +@testset "Dynamic rule predicates in EMatcher" begin g = EGraph(:(2 * 3)) + zero_id = addexpr!(g, 0) + + some_theory = @theory begin + ~a * ~b => 0 where (iszero(a) || iszero(b)) + ~a * ~b --> ~b * ~a + end + + Base.iszero(ec::EClass) = in_same_class(g, zero_id, ec.id) + saturate!(g, some_theory) - @test true == areequal(g, some_theory, :(a * b * 0), 0) + @test test_equality(some_theory, :(a * b * 0), 0) end @testset "Inequalities" begin failme = @theory p begin - p ≠ !p + p != !p :foo == !:foo :foo --> :bazoo :bazoo --> :wazoo end - g = EGraph(:foo) + g = EGraph{Expr}(:foo) report = saturate!(g, failme) @test report.reason === :contradiction end diff --git a/test/egraphs/extract.jl b/test/egraphs/extract.jl new file mode 100644 index 00000000..97c667f5 --- /dev/null +++ b/test/egraphs/extract.jl @@ -0,0 +1,184 @@ + +using Metatheory +using Metatheory.Library + +comm_monoid = @commutative_monoid (*) 1 + +fold_mul = @theory begin + ~a::Number * ~b::Number => ~a * ~b +end + + + +@testset "Extraction 1 - Commutative Monoid" begin + t = comm_monoid ∪ fold_mul + g = EGraph(:(3 * 4)) + saturate!(g, t) + @test (12 == extract!(g, astsize)) + + ex = :(a * 3 * b * 4) + g = EGraph(ex) + params = SaturationParams(timeout = 15) + saturate!(g, t, params) + extr = extract!(g, astsize) + @test extr == :((12 * a) * b) || + extr == :(12 * (a * b)) || + extr == :(12 * (b * a)) || + extr == :(a * (b * 12)) || + extr == :((a * b) * 12) || + extr == :((12a) * b) || + extr == :(a * (12b)) || + extr == :((b * (12a))) || + extr == :((b * 12) * a) || + extr == :((b * a) * 12) || + extr == :(b * (a * 12)) || + extr == :((12b) * a) +end + +fold_add = @theory begin + ~a::Number + ~b::Number => ~a + ~b +end + +@testset "Extraction 2" begin + comm_group = @commutative_group (+) 0 inv + + + t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add + + ex = :((x * (a + b)) + (y * (a + b))) + g = EGraph(ex) + saturate!(g, t) + extract!(g, astsize) == :((y + x) * (b + a)) +end + +comm_monoid = @commutative_monoid (*) 1 + +comm_group = @commutative_group (+) 0 inv + +powers = @theory begin + ~a * ~a --> (~a)^2 + ~a --> (~a)^1 + (~a)^~n * (~a)^~m --> (~a)^(~n + ~m) +end +logids = @theory begin + log((~a)^~n) --> ~n * log(~a) + log(~x * ~y) --> log(~x) + log(~y) + log(1) --> 0 + log(:e) --> 1 + :e^(log(~x)) --> ~x +end + +@testset "Extraction 3" begin + g = EGraph(:(log(e))) + params = SaturationParams(timeout = 9) + saturate!(g, logids, params) + @test extract!(g, astsize) == 1 +end + +t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add + +@testset "Complex Extraction" begin + g = EGraph(:(log(e) * log(e))) + params = SaturationParams(timeout = 9) + saturate!(g, t, params) + @test extract!(g, astsize) == 1 + + g = EGraph(:(log(e) * (log(e) * e^(log(3))))) + params = SaturationParams(timeout = 7) + saturate!(g, t, params) + @test extract!(g, astsize) == 3 + + + g = EGraph(:(a^3 * a^2)) + saturate!(g, t) + ex = extract!(g, astsize) + @test ex == :(a^5) +end + +@testset "Custom Cost Function 1" begin + function cust_astsize(n::VecExpr, head, children_costs::Vector{Float64})::Float64 + v_isexpr(n) || return 1 + cost = 1 + v_arity(n) + + if head == :^ + cost += 2 + end + + cost + sum(children_costs) + end + + g = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2)))) + saturate!(g, t) + ex = extract!(g, cust_astsize) + @test ex == :(5 * log(a)) || ex == :(log(a) * 5) +end + +@testset "Symbols in Right hand" begin + expr = :(a * (a * (b * (a * b)))) + g = EGraph(expr) + + a_id = addexpr!(g, :a) + + function costfun(n::VecExpr, op, children_costs::Vector{Float64})::Float64 + v_isexpr(n) || return 1 + v_arity(n) == 2 || return 1 + + left = v_children(n)[1] + in_same_class(g, left, a_id) ? 1 : 100 + end + + + moveright = @theory begin + (:b * (:a * ~c)) --> (:a * (:b * ~c)) + end + + res = rewrite(expr, moveright) + + saturate!(g, moveright) + resg = extract!(g, costfun) + + @test resg == res == :(a * (a * (a * (b * b)))) +end + +@testset "Consistency with classical backend" begin + co = @theory begin + sum(~x ⋅ :bazoo ⋅ :woo) --> sum(:n * ~x) + end + + ex = :(sum(wa(rio) ⋅ bazoo ⋅ woo)) + g = EGraph(ex) + saturate!(g, co) + + res = extract!(g, astsize) + resclassic = rewrite(ex, co) + + @test res == resclassic +end + + +@testset "No arguments" begin + ex = :(f()) + g = EGraph(ex) + @test :(f()) == extract!(g, astsize) + + ex = :(sin() + cos()) + + t = @theory begin + sin() + cos() --> tan() + end + + gg = EGraph(ex) + saturate!(gg, t) + res = extract!(gg, astsize) + + @test res == :(tan()) +end + + +@testset "Symbol or function object operators in expressions in EGraphs" begin + ex = :(($+)(x, y)) + t = RewriteRule[@rule a b a + b => 2] + g = EGraph(ex) + saturate!(g, t) + @test extract!(g, astsize) == 2 +end diff --git a/test/egraphs/unionfind.jl b/test/egraphs/unionfind.jl new file mode 100644 index 00000000..86c4a674 --- /dev/null +++ b/test/egraphs/unionfind.jl @@ -0,0 +1,22 @@ +using Metatheory +using Test + +n = 10 + +uf = UnionFind() +for _ in 1:n + push!(uf) +end + +union!(uf, Id(1), Id(2)) +union!(uf, Id(1), Id(3)) +union!(uf, Id(1), Id(4)) + +union!(uf, Id(6), Id(8)) +union!(uf, Id(6), Id(9)) +union!(uf, Id(6), Id(10)) + +for i in 1:n + find(uf, Id(i)) +end +@test uf.parents == Id[1, 1, 1, 1, 5, 6, 7, 6, 6, 6] diff --git a/test/integration/broken/cas.jl b/test/integration/cas.jl similarity index 57% rename from test/integration/broken/cas.jl rename to test/integration/cas.jl index 21758b71..5fd1c068 100644 --- a/test/integration/broken/cas.jl +++ b/test/integration/cas.jl @@ -1,8 +1,6 @@ -using Test -using Metatheory +using Metatheory, TermInterface, Test using Metatheory.Library using Metatheory.Schedulers -using TermInterface mult_t = @commutative_monoid (*) 1 plus_t = @commutative_monoid (+) 0 @@ -64,24 +62,48 @@ end fold_t = @theory a b begin -(a::Number) => -a a::Number + b::Number => a + b + a::Number - b::Number => a - b a::Number * b::Number => a * b - a::Number^b::Number => begin + a::Number ^ b::Number => begin + a == 0 && b <= 0 && return nothing + a < 0 && b != round(b) && return nothing # only allow integer exponents for negative base b < 0 && a isa Int && (a = float(a)) a^b end a::Number / b::Number => a / b end -using Calculus: differentiate -function ∂ end +diff_t_onearg = @theory x begin + diff(sqrt(x), x) --> 1 / 2 / sqrt(x) + diff(cbrt(x), x) --> 1 / 3 / cbrt(x)^2 + diff(log(x), x) --> 1 / x + diff(exp(x), x) --> exp(x) + diff(sin(x), x) --> cos(x) + diff(cos(x), x) --> -sin(x) + diff(tan(x), x) --> (1 + tan(x)^2) + diff(sec(x), x) --> sec(x) * tan(x) + diff(csc(x), x) --> -csc(x) * cot(x) + diff(cot(x), x) --> -(1 + cot(x)^2) +end -diff_t = @theory x y begin - ∂(y, x::Symbol) => begin - z = extract!(_egraph, simplcost; root = y.id) - differentiate(z, x) - end +diff_t_base = @theory x y n begin + diff(x, x) --> 1 + diff(n::Number, x) --> 0 + # diff(x^1, x) --> 1 # special case of next rule + diff(x^(n::Number), x) --> n * x^(n - 1) +end + +diff_t_composite = @theory x a b c begin + diff(a + b, x) == diff(a, x) + diff(b, x) + diff(a * b, x) == diff(a, x) * b + a * diff(b, x) # product rule + # if diff(b,x) == 0 + # return :( $y * $xp * ($x ^ ($y - 1)) ) + + diff(a^b, x) == a^b * (diff(b, x) * log(a) + b * (diff(a, x) / a)) end +diff_t = diff_t_base ∪ diff_t_onearg ∪ diff_t_composite + cas = fold_t ∪ mult_t ∪ plus_t ∪ minus_t ∪ mulplus_t ∪ pow_t ∪ div_t ∪ trig_t ∪ diff_t @@ -116,40 +138,29 @@ canonical_t = @theory x y n xs ys begin end -function simplcost(n::ENodeTerm, g::EGraph) - cost = 0 + arity(n) - if operation(n) == :∂ - cost += 20 - end - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, simplcost) && (cost += Inf; break) - cost += last(getdata(eclass, simplcost)) - end - return cost +function simplcost(n::VecExpr, op, costs) + v_isexpr(n) || return 1 + # @show op + # @show(sum(costs)) + 1 + sum(costs) + (op in (:∂, diff, :diff) ? 200 : 0) end -simplcost(n::ENodeLiteral, g::EGraph) = 0 - function simplify(ex; steps = 4) params = SaturationParams( - scheduler = ScoredScheduler, - eclasslimit = 5000, - timeout = 7, - schedulerparams = (1000, 5, Schedulers.exprsize), - #stopwhen=stopwhen, + # scheduler = ScoredScheduler, + # eclasslimit = 5000, + # timeout = 7, + # schedulerparams = (match_limit = 1000, ban_length = 5), + #stopwhen=stopwhen, ) hist = UInt64[] push!(hist, hash(ex)) for i in 1:steps g = EGraph(ex) - @profview_allocs saturate!(g, cas, params) + saturate!(g, cas, params) ex = extract!(g, simplcost) ex = rewrite(ex, canonical_t) - if !TermInterface.istree(ex) - return ex - end - if hash(ex) ∈ hist + if !isexpr(ex) || hash(ex) ∈ hist return ex end push!(hist, hash(ex)) @@ -177,32 +188,21 @@ end @test :(y + sec(x)^2) == simplify(:(1 + y + tan(x)^2)) @test :(y + csc(x)^2) == simplify(:(1 + y + cot(x)^2)) +@test simplify(:(diff(x^2, x))) == :(2x) +@test_broken simplify(:(diff(x^(cos(x)), x))) == :((cos(x) / x + -(sin(x)) * log(x)) * x^cos(x)) +@test simplify(:(x * diff(x^2, x) * x)) == :(2x^3) +@test simplify(:(diff(y^3, y) * diff(x^2 + 2, x) / y * x)) == :(6 * y * x ^ 2) # :(3y * 2x^2) -# simplify(:( ∂(x^2, x))) - -simplify(:(∂(x^(cos(x)), x))) - -@test :(2x^3) == simplify(:(x * ∂(x^2, x) * x)) - -# @simplify ∂(y^3, y) * ∂(x^2 + 2, x) / y * x - -# @simplify (6 * x * x * y) +@test simplify(:(6 * x * x * y)) == :(6 * y * x^2) +@test simplify(:(diff(y^3, y) / y)) == :(3y) -# @simplify ∂(y^3, y) / y - -# # ex = :( ∂(x^(cos(x)), x) ) -# ex = :( (6 * x * x * y) ) -# g = EGraph(ex) -# saturate!(g, cas) -# g.classes -# extract!(g, simplcost; root=g.root) # params = SaturationParams( # scheduler=BackoffScheduler, # eclasslimit=5000, # timeout=7, -# schedulerparams=(1000,5), +# (match_limit = 1000, ban_length = 5), # #stopwhen=stopwhen, # ) @@ -214,63 +214,28 @@ simplify(:(∂(x^(cos(x)), x))) # ex = rewrite(ex, canonical_t; clean=false) -# FIXME this is a hack to get the test to work. -if VERSION < v"1.9.0-DEV" - function EGraphs.make(::Val{:type_analysis}, g::EGraph, n::ENodeLiteral) - v = n.value - if v == :im - typeof(im) - else - typeof(v) - end - end - - function EGraphs.make(::Val{:type_analysis}, g::EGraph, n::ENodeTerm) - symtype(n) !== Expr && return Any - if exprhead(n) != :call - # println("$n is not a call") - t = Any - # println("analyzed type of $n is $t") - return t - end - sym = operation(n) - if !(sym isa Symbol) - # println("head $sym is not a symbol") - t = Any - # println("analyzed type of $n is $t") - return t - end - - symval = getfield(@__MODULE__, sym) - child_classes = map(x -> g[x], arguments(n)) - child_types = Tuple(map(x -> getdata(x, :type_analysis, Any), child_classes)) +function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr) + h = get_constant(g, v_head(n)) + v_isexpr(n) || return (h in (:im, im) ? Complex : typeof(h)) + v_iscall(n) || return (Any) - # t = t_arr[1] - t = Core.Compiler.return_type(symval, child_types) - - if t == Union{} - throw(MethodError(symval, child_types)) - end - # println("analyzed type of $n is $t") - return t - end + op = (h isa Symbol) ? getfield(@__MODULE__, h) : op + child_types = map(id -> g[id].data, v_children(n)) + return Base.promote_op(op, child_types...) +end - EGraphs.join(::Val{:type_analysis}, from, to) = typejoin(from, to) - EGraphs.islazy(::Val{:type_analysis}) = true +EGraphs.join(from::Type, to::Type) = typejoin(from, to) - function infer(e) - g = EGraph(e) - analyze!(g, :type_analysis) - getdata(g[g.root], :type_analysis) - end +function infer(e) + g = EGraph{Expr,Type}(e) + g[g.root].data +end - ex1 = :(cos(1 + 3.0) + 4 + (4 - 4im)) - ex2 = :("ciao" * 2) - ex3 = :("ciao" * " mondo") +ex1 = :(cos(1 + 3.0) + 4 + (4 - 4im)) +ex2 = :("ciao" * 2) +ex3 = :("ciao" * " mondo") - @test ComplexF64 == infer(ex1) - @test_throws MethodError infer(ex2) - @test String == infer(ex3) -end +@test Complex == infer(ex1) +@test String == infer(ex3) diff --git a/test/integration/kb_benchmark.jl b/test/integration/kb_benchmark.jl deleted file mode 100644 index dee9d1f5..00000000 --- a/test/integration/kb_benchmark.jl +++ /dev/null @@ -1,72 +0,0 @@ -using Test -using Metatheory -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.Rules -using Metatheory.EGraphs.Schedulers - -function rep(x, op, n::Int) - foldl((x, y) -> :(($op)($x, $y)), repeat([x], n)) -end - -macro rep(x, op, n::Int) - expr = rep(x, op, n) - esc(expr) -end - -rep(:a, :*, 3) - -@rule (@rep :a (*) 3) => :b - -Mid = @theory a begin - a * :ε --> a - :ε * a --> a -end - -Massoc = @theory a b c begin - a * (b * c) --> (a * b) * c - (a * b) * c --> a * (b * c) -end - - -T = [ - @rule :b * :B --> :ε - @rule :a * :a --> :ε - @rule :b * :b * :b --> :ε - @rule :B * :B --> :B - @rule (@rep (:a * :b) (*) 7) --> :ε - @rule (@rep (:a * :b * :a * :B) (*) 7) --> :ε -] - -G = Mid ∪ Massoc ∪ T - - -another_expr = :(b * B) -g = EGraph(another_expr) -saturate!(g, G) -ex = extract!(g, astsize) -@test ex == :ε - -another_expr = :(a * a * a * a) -g = EGraph(another_expr) -some_eclass = addexpr!(g, another_expr) -saturate!(g, G) -ex = extract!(g, astsize; root = some_eclass) -@test ex == :ε - -another_expr = :(((((((a * b) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) -g = EGraph(another_expr) -some_eclass = addexpr!(g, another_expr) -saturate!(g, G) -ex = extract!(g, astsize; root = some_eclass) -@test ex == :ε - - -expr = :(a * b * a * a * a * b * b * b * a * B * B * B * B * a) -g = EGraph(expr) -params = SaturationParams(timeout = 9, scheduler = BackoffScheduler)# , schedulerparams=(128,4))#, scheduler=SimpleScheduler) -# params = SaturationParams(timeout = 9, scheduler = SimpleScheduler)# , schedulerparams=(128,4))#, scheduler=SimpleScheduler) -report = saturate!(g, G, params) -ex = extract!(g, astsize) -@test_broken ex == :ε - diff --git a/test/integration/knuth_bendix.jl b/test/integration/knuth_bendix.jl new file mode 100644 index 00000000..6b3642db --- /dev/null +++ b/test/integration/knuth_bendix.jl @@ -0,0 +1,86 @@ +using Test +using Metatheory +using Metatheory.Library +using Metatheory.EGraphs +using Metatheory.Rules +using Metatheory.EGraphs.Schedulers + +function rep(x, op, n::Int) + foldl((x, y) -> :(($op)($x, $y)), repeat([x], n)) +end + +macro rep(x, op, n::Int) + expr = rep(x, op, n) + esc(expr) +end + +rep(:a, :*, 3) + +@rule (@rep :a (*) 3) => :b + +Mid = @theory a begin + a * :ε --> a + :ε * a --> a +end + +Massoc = @theory a b c begin + a * (b * c) == (a * b) * c + # (a * b) * c --> a * (b * c) +end + + +macro kb_theory_237_abab(n) + quote + T = [ + @rule :b * :B --> :ε + @rule :a * :a --> :ε + @rule (:b * :b) * :b --> :ε + @rule :B * :B --> :B + @rule (@rep (:a * :b) (*) 7) --> :ε + @rule (@rep (:a * :b * :a * :B) (*) $n) --> :ε + ] + group_theory = Mid ∪ Massoc ∪ T + end |> esc +end + +@kb_theory_237_abab 5 + +astsize_prefer_empty(n::VecExpr, op, costs)::Float64 = op == :ε ? 0 : astsize(n, op, costs) + +function test_kb(expr, t, params = SaturationParams()) + g = EGraph(expr) + saturate!(g, t, params) + ex = extract!(g, astsize_prefer_empty) + + # TODO: Check if group is trivial + # a = addexpr!(g, :a) + # b = addexpr!(g, :b) + # B = addexpr!(g, :B) + # ε = addexpr!(g, :ε) + # @show in_same_class(g, a, ε) + # @show in_same_class(g, b, ε) + # @show in_same_class(g, B, ε) + ex == :ε +end + + +for n in 5:8 + t = @eval @kb_theory_237_abab $n + + @test test_kb(:(b * B), group_theory) + @test test_kb(:(a * a * a * a), group_theory) + @test test_kb(:(((((((a * b) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)), group_theory) + @test test_kb( + :(a * b * a * a * a * b * b * b * a * B * B * B * B * a), + group_theory, + SaturationParams(timeout = 5, scheduler = SimpleScheduler), + ) + + @test !test_kb(:(((((((a * b) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (b * a)), group_theory) + @test !test_kb(:(a * a * b * a), group_theory) + @test !test_kb( + :(a * b * b * a * a * b * b * b * a * B * B * B * B * a), + group_theory, + SaturationParams(timeout = 5, scheduler = SimpleScheduler), + ) +end diff --git a/test/integration/lambda_theory.jl b/test/integration/lambda_theory.jl deleted file mode 100644 index 5e3f9ec6..00000000 --- a/test/integration/lambda_theory.jl +++ /dev/null @@ -1,148 +0,0 @@ -using Metatheory -using Metatheory.EGraphs -using Metatheory.Library -using TermInterface -using Test - -abstract type LambdaExpr end - -@matchable struct IfThenElse <: LambdaExpr - guard - then - otherwise -end - -@matchable struct Variable <: LambdaExpr - x::Symbol -end - -@matchable struct Fix <: LambdaExpr - variable - expression -end - -@matchable struct Let <: LambdaExpr - variable - value - body -end -@matchable struct λ <: LambdaExpr - x::Symbol - body -end - -@matchable struct Apply <: LambdaExpr - lambda - value -end - -@matchable struct Add <: LambdaExpr - x - y -end - -TermInterface.exprhead(::LambdaExpr) = :call - -function EGraphs.egraph_reconstruct_expression(::Type{<:LambdaExpr}, op, args; metadata = nothing, exprhead = :call) - op(args...) -end - -#%% -EGraphs.make(::Val{:freevar}, ::EGraph, n::ENodeLiteral) = Set{Int64}() - -function EGraphs.make(::Val{:freevar}, g::EGraph, n::ENodeTerm) - free = Set{Int64}() - if exprhead(n) == :call - op = operation(n) - args = arguments(n) - - if op == Variable - push!(free, args[1]) - elseif op == Let - v, a, b = args[1:3] - adata = getdata(g[a], :freevar, Set{Int64}()) - bdata = getdata(g[a], :freevar, Set{Int64}()) - union!(free, adata) - delete!(free, v) - union!(free, bdata) - elseif op == λ - v, b = args[1:2] - bdata = getdata(g[b], :freevar, Set{Int64}()) - union!(free, bdata) - delete!(free, v) - end - end - - return free -end - -EGraphs.join(::Val{:freevar}, from, to) = union(from, to) - -islazy(::Val{:freevar}) = false - -open_term = @theory x e then alt a b c begin - # if-true - IfThenElse(true, then, alt) --> then - IfThenElse(false, then, alt) --> alt - # if-elim - IfThenElse(Variable(x) == e, then, alt) => - if addexpr!(_egraph, Let(x, e, then)) == addexpr!(_egraph, Let(x, e, alt)) - alt - else - _lhs_expr - end - Add(a, b) == Add(b, a) - Add(a, Add(b, c)) == Add(Add(a, b), c) - # (a == b) == (b == a) -end - -subst_intro = @theory v body e begin - Fix(v, e) --> Let(v, Fix(v, e), e) - # beta reduction - Apply(λ(v, body), e) --> Let(v, e, body) -end - -subst_prop = @theory v e a b then alt guard begin - # let-Apply - Let(v, e, Apply(a, b)) --> Apply(Let(v, e, a), Let(v, e, b)) - # let-add - Let(v, e, a + b) --> Let(v, e, a) + Let(v, e, b) - # let-eq - # Let(v, e, a == b) --> Let(v, e, a) == Let(v, e, b) - # let-IfThenElse (let-if) - Let(v, e, IfThenElse(guard, then, alt)) --> IfThenElse(Let(v, e, guard), Let(v, e, then), Let(v, e, alt)) -end - - -subst_elim = @theory v e c v1 v2 body begin - # let-const - Let(v, e, c::Any) --> c - # let-Variable-same - Let(v1, e, Variable(v1)) --> e - # TODO fancy let-Variable-diff - Let(v1, e, Variable(v2)) => if addexpr!(_egraph, v1) != addexpr!(_egraph, v2) - :(Variable($v2)) - else - _lhs_expr - end - # let-lam-same - Let(v1, e, λ(v1, body)) --> λ(v1, body) - # let-lam-diff #TODO captureavoid - Let(v1, e, λ(v2, body)) => if v2.id ∈ getdata(e, :freevar, Set()) # is free - :(λ($fresh, Let($v1, $e, Let($v2, Variable($fresh), $body)))) - else - :(λ($v2, Let($v1, $e, $body))) - end -end - -λT = open_term ∪ subst_intro ∪ subst_prop ∪ subst_elim - -ex = λ(:x, Add(4, Apply(λ(:y, Variable(:y)), 4))) -g = EGraph(ex) - -settermtype!(g, LambdaExpr) -saturate!(g, λT) -@test λ(:x, Add(4, 4)) == extract!(g, astsize) # expected: :(λ(x, 4 + 4)) - -#%% -@test @areequal λT 2 Apply(λ(x, Variable(x)), 2) \ No newline at end of file diff --git a/test/integration/stream_fusion.jl b/test/integration/stream_fusion.jl index e3a25606..b119c069 100644 --- a/test/integration/stream_fusion.jl +++ b/test/integration/stream_fusion.jl @@ -1,8 +1,7 @@ using Metatheory +using TermInterface using Metatheory.Rewriters using Test -using TermInterface -# using SymbolicUtils apply(f, x) = f(x) fand(f, g) = x -> f(x) && g(x) @@ -42,7 +41,7 @@ end asymptot_t = @theory x y z n m f g begin (length(filter(f, x)) <= length(x)) => true length(cat(x, y)) --> length(x) + length(y) - length(map(f, x)) => length(map) + length(map(f, x)) --> length(x) length(x::UnitRange) => length(x) end @@ -50,7 +49,7 @@ fold_theory = @theory x y z begin x::Number * y::Number => x * y x::Number + y::Number => x + y x::Number / y::Number => x / y - x::Number - y::Number => x / y + x::Number - y::Number => x - y # etc... end @@ -60,34 +59,39 @@ import Base.Cartesian: inlineanonymous tryinlineanonymous(x) = nothing function tryinlineanonymous(ex::Expr) - exprhead(ex) != :call && return nothing - f = operation(ex) - (!(f isa Expr) || exprhead(f) !== :->) && return nothing - arg = arguments(ex)[1] + iscall(ex) || return nothing + op = operation(ex) + (!(op isa Expr) || op.head !== :->) && return nothing + args = arguments(ex)[1] + # TODO more args? try - return inlineanonymous(f, arg) + return inlineanonymous(op, args) catch e return nothing end end normalize_theory = @theory x y z f g begin - fand(f, g) => Expr(:->, :x, :(($f)(x) && ($g)(x))) + fand(f, g) => :(x -> ($f)(x) && ($g)(x)) apply(f, x) => Expr(:call, f, x) end -params = SaturationParams() + +function stream_fusion_cost(n::VecExpr, op, costs::Vector{Float64})::Float64 + v_isexpr(n) || return 1 + cost = 1 + v_arity(n) + op ∈ (:map, :filter) && (cost += 10) + cost + sum(costs) +end function stream_optimize(ex) g = EGraph(ex) - saturate!(g, array_theory, params) - ex = extract!(g, astsize) # TODO cost fun with asymptotic complexity - ex = Fixpoint(Postwalk(Chain([tryinlineanonymous, normalize_theory..., fold_theory...])))(ex) + saturate!(g, array_theory) + ex = extract!(g, stream_fusion_cost) # TODO cost fun with asymptotic complexity + ex = Fixpoint(Postwalk(Chain([tryinlineanonymous; normalize_theory; fold_theory])))(ex) return ex end -build_fun(ex) = eval(:(() -> $ex)) - @testset "Stream Fusion" begin ex = :(map(x -> 7 * x, fill(3, 4))) @@ -101,13 +105,10 @@ end # ['a','1','2','3','4'] ex = :(filter(ispow2, filter(iseven, reverse(reverse(fill(4, 100)))))) -opt = stream_optimize(ex) +@test Base.remove_linenums!(stream_optimize(ex)) == + Base.remove_linenums!(:(filter(x -> ispow2(x) && iseven(x), fill(4, 100)))) ex = :(map(x -> 7 * x, reverse(reverse(fill(13, 40))))) -opt = stream_optimize(ex) -opt = stream_optimize(opt) +@test stream_optimize(ex) == :(fill(91, 40)) -macro stream_optimize(ex) - stream_optimize(ex) -end diff --git a/test/integration/taylor.jl b/test/integration/taylor.jl index ff7e703f..710d3790 100644 --- a/test/integration/taylor.jl +++ b/test/integration/taylor.jl @@ -1,4 +1,4 @@ -using Metatheory +using Metatheory, Test struct Σ end diff --git a/test/integration/while_superinterpreter.jl b/test/integration/while_superinterpreter.jl index 2587be16..c21f4ad6 100644 --- a/test/integration/while_superinterpreter.jl +++ b/test/integration/while_superinterpreter.jl @@ -2,23 +2,23 @@ # # Turing Complete Interpreter using Metatheory, Test - +include(joinpath(dirname(pathof(Metatheory)), "../examples/prove.jl")) include(joinpath(dirname(pathof(Metatheory)), "../examples/while_superinterpreter_theory.jl")) @testset "Reading Memory" begin ex = :((x), $(Mem(:x => 2))) - @test true == areequal(read_mem, ex, 2) + @test true == test_equality(read_mem, ex, 2) end @testset "Arithmetic" begin - @test areequal(read_mem ∪ arithm_rules, :((2 + 3), $(Mem())), 5) + @test test_equality(read_mem ∪ arithm_rules, :((2 + 3), $(Mem())), 5) end @testset "Booleans" begin t = read_mem ∪ arithm_rules ∪ bool_rules - @test areequal(t, :((false || false), $(Mem())), false) + @test test_equality(t, :((false || false), $(Mem())), false) exx = :((false || false) || !(false || false), $(Mem(:x => 2))) g = EGraph(exx) @@ -26,52 +26,49 @@ end ex = extract!(g, astsize) @test ex == true params = SaturationParams(timeout = 12) - @test areequal(t, exx, true; params = params) + @test test_equality(t, exx, true; params = params) - @test areequal(t, :((2 < 3) && (3 < 4), $(Mem(:x => 2))), true) - @test areequal(t, :((2 < x) || !(3 < 4), $(Mem(:x => 2))), false) - @test areequal(t, :((2 < x) || !(3 < 4), $(Mem(:x => 4))), true) + @test test_equality(t, :((2 < 3) && (3 < 4), $(Mem(:x => 2))), true) + @test test_equality(t, :((2 < x) || !(3 < 4), $(Mem(:x => 2))), false) + @test test_equality(t, :((2 < x) || !(3 < 4), $(Mem(:x => 4))), true) end @testset "If Semantics" begin - @test areequal(if_language, 2, :(if true + @test test_equality(if_language, :(if true x else 0 - end, $(Mem(:x => 2)))) - @test areequal(if_language, 0, :(if false + end, $(Mem(:x => 2))), 2) + @test test_equality(if_language, :(if false x else 0 - end, $(Mem(:x => 2)))) - @test areequal(if_language, 2, :(if !(false) + end, $(Mem(:x => 2))), 0) + @test test_equality(if_language, :(if !(false) x else 0 - end, $(Mem(:x => 2)))) + end, $(Mem(:x => 2))), 2) params = SaturationParams(timeout = 10) - @test areequal(if_language, 0, :(if !(2 < x) + @test test_equality(if_language, :(if !(2 < x) x else 0 - end, $(Mem(:x => 3))); params = params) + end, $(Mem(:x => 3))), 0; params = params) end @testset "While Semantics" begin exx = :((x = 3), $(Mem(:x => 2))) g = EGraph(exx) saturate!(g, while_language) - ex = extract!(g, astsize) + @test Mem(:x => 3) == extract!(g, astsize) - @test areequal(while_language, Mem(:x => 3), exx) exx = :((x = 4; x = x + 1), $(Mem(:x => 3))) g = EGraph(exx) saturate!(g, while_language) - ex = extract!(g, astsize) + @test Mem(:x => 5) == extract!(g, astsize) - params = SaturationParams(timeout = 10) - @test areequal(while_language, Mem(:x => 5), exx; params = params) params = SaturationParams(timeout = 14, timer = false) exx = :(( @@ -81,14 +78,15 @@ end skip end ), $(Mem(:x => 3))) - @test areequal(while_language, Mem(:x => 4), exx; params = params) + @test test_equality(while_language, exx, Mem(:x => 4); params = params) exx = :((while x < 10 x = x + 1 end; x), $(Mem(:x => 3))) g = EGraph(exx) - params = SaturationParams(timeout = 100) + params = SaturationParams(timeout = 250) saturate!(g, while_language, params) @test 10 == extract!(g, astsize) end + diff --git a/test/runtests.jl b/test/runtests.jl index a02330b4..6743a1f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using Documenter using Metatheory using Test -doctest(Metatheory) +# doctest(Metatheory) function test(file::String) @info file @@ -15,6 +15,7 @@ end allscripts(dir) = [joinpath(@__DIR__, dir, x) for x in readdir(dir) if endswith(x, ".jl")] const TEST_FILES = [ + allscripts("unit") allscripts("classic") allscripts("egraphs") allscripts("integration") diff --git a/test/thesis_example.jl b/test/thesis_example.jl index 3ad808a7..bebf5a95 100644 --- a/test/thesis_example.jl +++ b/test/thesis_example.jl @@ -1,32 +1,32 @@ using Metatheory using Metatheory.EGraphs -using TermInterface using Test -# TODO update - -function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENodeLiteral) - if n.value isa Real - if n.value == Inf - Inf - elseif n.value == -Inf - -Inf - elseif n.value isa Real # in Julia NaN is a Real - sign(n.value) - else - nothing - end - elseif n.value isa Symbol - s = n.value - s == :x && return 1 - s == :y && return -1 - s == :z && return 0 - s == :k && return Inf - return nothing +function make_value(v::Real) + if v == Inf + Inf + elseif v == -Inf + -Inf + elseif v isa Real # in Julia NaN is a Real + sign(v) + else + nothing end end -function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENodeTerm) +function make_value(v::Symbol) + s = v + s == :x && return 1 + s == :y && return -1 + s == :z && return 0 + s == :k && return Inf + return nothing +end + + +function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENode) + isexpr(n) || return make_value(operation(n)) + # Let's consider only binary function call terms. if exprhead(n) == :call && arity(n) == 2 # get the symbol name of the operation diff --git a/test/tutorials/calculational_logic.jl b/test/tutorials/calculational_logic.jl index 27f35439..4b8fc374 100644 --- a/test/tutorials/calculational_logic.jl +++ b/test/tutorials/calculational_logic.jl @@ -1,6 +1,7 @@ # # Rewriting Calculational Logic -using Metatheory +using Metatheory, Test +include(joinpath(dirname(pathof(Metatheory)), "../examples/prove.jl")) include(joinpath(dirname(pathof(Metatheory)), "../examples/calculational_logic_theory.jl")) @@ -9,20 +10,19 @@ include(joinpath(dirname(pathof(Metatheory)), "../examples/calculational_logic_t saturate!(g, calculational_logic_theory) extract!(g, astsize) - @test @areequal calculational_logic_theory true ((!p == p) == false) - @test @areequal calculational_logic_theory true ((!p == !p) == true) - @test @areequal calculational_logic_theory true ((!p || !p) == !p) (!p || p) !(!p && p) - @test @areequal calculational_logic_theory true ((p ⟹ (p || p)) == true) - params = SaturationParams(timeout = 12, eclasslimit = 10000, schedulerparams = (1000, 5)) + @test test_equality(calculational_logic_theory, :((!p || !p) == !p), :(!p || p), :(!(!p && p))) - @test areequal(calculational_logic_theory, true, :(((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true); params = params) - # Frege's theorem - @test areequal(calculational_logic_theory, true, :((p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))); params = params) + @test prove(calculational_logic_theory, :((!p == p) == false)) + @test prove(calculational_logic_theory, :((!p == !p) == true)) + @test prove(calculational_logic_theory, :((p ⟹ (p || p)) == true)) - # Demorgan's - @test @areequal calculational_logic_theory true (!(p || q) == (!p && !q)) + params = SaturationParams(timeout = 12, eclasslimit = 10000, schedulerparams = (match_limit = 1000, ban_length = 5)) + @test prove(calculational_logic_theory, :(((p ⟹ (p || p)) == ((!(p) && q) ⟹ q))), 1, 10, params) - # Consensus theorem - areequal(calculational_logic_theory, :((x && y) || (!x && z) || (y && z)), :((x && y) || (!x && z)); params = params) + ex = :((p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))) # Frege's theorem + params = SaturationParams(timeout = 12, eclasslimit = 10000, schedulerparams = (match_limit = 6000, ban_length = 5)) + @test prove(calculational_logic_theory, ex, 2, 10, params) + + @test prove(calculational_logic_theory, :(!(p || q) == (!p && !q))) # Demorgan's end diff --git a/test/tutorials/custom_types.jl b/test/tutorials/custom_types.jl index 9a8dc3c8..69c7febc 100644 --- a/test/tutorials/custom_types.jl +++ b/test/tutorials/custom_types.jl @@ -16,18 +16,18 @@ # ## Concrete example -using Metatheory, TermInterface, Test +using Metatheory, Test using Metatheory.EGraphs +using TermInterface # We first define our custom expression type in `MyExpr`: -# It behaves like `Expr`, but it adds some extra fields. struct MyExpr head::Any args::Vector{Any} foo::String # additional metadata end -MyExpr(head, args) = MyExpr(head, args, "") -MyExpr(head) = MyExpr(head, []) +MyExpr(op, args) = MyExpr(op, args, "") +MyExpr(op) = MyExpr(op, []) # We also need to define equality for our expression. function Base.:(==)(a::MyExpr, b::MyExpr) @@ -37,30 +37,38 @@ end # ## Overriding `TermInterface`` methods # First, we need to discern when an expression is a leaf or a tree node. -# We can do it by overriding `istree`. -TermInterface.istree(::MyExpr) = true - -# The `operation` function tells us what's the node's represented operation. -TermInterface.operation(e::MyExpr) = e.head -# `arguments` tells the system how to extract the children nodes. -TermInterface.arguments(e::MyExpr) = e.args - -# A particular function is `exprhead`. It is used to bridge our custom `MyExpr` -# type, together with the `Expr` functionality that is used in Metatheory rule syntax. -# In this example we say that all expressions of type `MyExpr`, can be represented (and matched against) by -# a pattern that is represented by a `:call` Expr. -TermInterface.exprhead(::MyExpr) = :call - -# While for common usage you will always define `exprhead` it to be `:call`, +# We can do it by overriding `isexpr`. +TermInterface.isexpr(::MyExpr) = true +# By default, our expression trees always represent a function call +TermInterface.iscall(::MyExpr) = true + +# The `head` function tells us what's the node's represented operation. +TermInterface.head(e::MyExpr) = e.head +# `children` tells the system how to extract the children nodes. +TermInterface.children(e::MyExpr) = e.args + +# `operation` and `arguments` are functions used by the pattern matcher, required +# when `iscall` is true on an expression. Since our custom expression type +# **always represents function calls**, we can just define them to be `head` and `children`. +TermInterface.operation(e::MyExpr) = head(e) +TermInterface.arguments(e::MyExpr) = children(e) + +# While for common usage you will always define `iscall` to be `true`, # there are some cases where you would like to match your expression types -# against more complex patterns, for example, to match an expression `x` against an `a[b]` kind of pattern, -# you would need to inform the system that `exprhead(x)` is `:ref`, because +# against more complex patterns that are not function calls, for example, to match an expression `x` against an `a[b]` kind of pattern, +# you would need to inform the system that `iscall` is `false`, and that its operation can match against `:ref` or `getindex` because ex = :(a[b]) (ex.head, ex.args) # `metadata` should return the extra metadata. If you have many fields, i suggest using a `NamedTuple`. -TermInterface.metadata(e::MyExpr) = e.foo +# TermInterface.metadata(e::MyExpr) = e.foo + +# struct MetadataAnalysis +# metadata +# end + +# function EGraphs.make(g::EGraph{MyExprHead,MetadataAnalysis}, n::VecExpr) = # Additionally, you can override `EGraphs.preprocess` on your custom expression # to pre-process any expression before insertion in the E-Graph. @@ -68,22 +76,13 @@ TermInterface.metadata(e::MyExpr) = e.foo EGraphs.preprocess(e::MyExpr) = MyExpr(e.head, e.args, uppercase(e.foo)) -# `TermInterface` provides a very important function called `similarterm`. +# `TermInterface` provides a very important function called `maketerm`. # It is used to create a term that is in the same closure of types of `x`. -# Given an existing term `x`, it is used to instruct Metatheory how to recompose -# a similar expression, given a `head` (the result of `operation`), some children (given by `arguments`) -# and additionally, `metadata` and `exprehead`, in case you are recomposing an `Expr`. -function TermInterface.similarterm(x::MyExpr, head, args; metadata = nothing, exprhead = :call) - MyExpr(head, args, isnothing(metadata) ? "" : metadata) -end +# Given an existing head `h`, it is used to instruct Metatheory how to recompose +# a similar expression, given some children in `c` +# and additionally, `metadata` and `type`, in case you are recomposing an `Expr`. +TermInterface.maketerm(::Type{MyExpr}, h, c, metadata) = MyExpr(h, c, isnothing(metadata) ? "" : metadata) -# Since `similarterm` works by making a new term similar to an existing term `x`, -# in the e-graphs system, there won't be enough information such as a 'reference' object. -# Only the type of the object is known. This extra function adds a bit of verbosity, due to compatibility -# with SymbolicUtils.jl -function EGraphs.egraph_reconstruct_expression(::Type{MyExpr}, op, args; metadata = nothing, exprhead = nothing) - MyExpr(op, args, (isnothing(metadata) ? () : metadata)) -end # ## Theory Example @@ -96,15 +95,17 @@ end # Let's create an example expression and e-graph hcall = MyExpr(:h, [4], "hello") ex = MyExpr(:f, [MyExpr(:z, [2]), hcall]) -g = EGraph(ex; keepmeta = true) - -# We use `settermtype!` on an existing e-graph to inform the system about +# We use the first type parameter an existing e-graph to inform the system about # the *default* type of expressions that we want newly added expressions to have. -settermtype!(g, MyExpr) +g = EGraph{MyExpr}(ex) # Now let's test that it works. saturate!(g, t) -expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "") + +# TODO metadata +# expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "") +expected = MyExpr(:f, [MyExpr(:h, [4], "")], "") + extracted = extract!(g, astsize) @test expected == extracted diff --git a/test/tutorials/fibonacci.jl b/test/tutorials/fibonacci.jl index 4f7acb08..c8ea5556 100644 --- a/test/tutorials/fibonacci.jl +++ b/test/tutorials/fibonacci.jl @@ -1,7 +1,6 @@ # # Benchmarking Fibonacci. E-Graphs memoize computation. -using Metatheory -using Test +using Metatheory, Test function fib end diff --git a/test/tutorials/lambda_theory.jl b/test/tutorials/lambda_theory.jl new file mode 100644 index 00000000..9f74a725 --- /dev/null +++ b/test/tutorials/lambda_theory.jl @@ -0,0 +1,250 @@ +using Metatheory, Test, TermInterface + +# # Lambda theory +# +# This tutorial demonstrates how to implement a simple lambda calculus in Metatheory. +# Importantly, it shows a practical example of [*e-graph analysis*](/egraphs/#EGraph-Analyses). +# The three building blocks of lambda calculus are *variables*, $\lambda$-functions, and *function +# application*, which we can implement as subtypes of an abstract `LambdaExpr`ession: + +abstract type LambdaExpr end + +function TermInterface.maketerm(::Type{<:LambdaExpr}, head, children, metadata = nothing) + head(children...) +end + +@matchable struct Variable <: LambdaExpr + x +end +Base.show(io::IO, x::Variable) = print(io, "$(x.x)") + +@matchable struct λ <: LambdaExpr + x + body +end +function Base.show(io::IO, x::λ) + b = x.body isa Variable ? "$(x.body)" : "($(x.body))" + print(io, "λ$(x.x).$b") +end + +@matchable struct Apply <: LambdaExpr + lambda + value +end +function Base.show(io::IO, x::Apply) + l = x.lambda isa Variable ? "$(x.lambda)" : "($(x.lambda))" + v = x.value isa Variable ? "$(x.value)" : "($(x.value))" + print(io, "$l$v") +end + +# With the above we can construct arbitrary lambda expressions: + +x = Variable(:x) +λ(:x, Apply(x, x)) + +# The $\beta$-reduction can be implemented via an additional type `Let`. To get started we can ignore +# the cases where we need $\alpha$-conversion and already implement + +@matchable struct Let <: LambdaExpr + variable + value + body +end +Base.show(io::IO, x::Let) = print(io, "$(x.body)[$(x.variable) := $(x.value)]") + +λT = @theory v e c v1 v2 a b body begin + Let(v, e, c::Any) --> c # let-const + Let(v1, e, Variable(v1)) --> e # let-Variable-same + Let(v1, e, Variable(v2)) => v1 == v2 ? e : Variable(v2) # let-Variable-diff + Let(v1, e, λ(v1, body)) --> λ(v1, body) # let-lam-same + Let(v1, e, λ(v2, body)) --> λ(v2, Let(v1, e, body)) # let-lam-diff + Apply(λ(v, body), e) --> Let(v, e, body) # beta reduction + Let(v, e, Apply(a, b)) --> Apply(Let(v, e, a), Let(v, e, b)) # let-Apply +end + +x = Variable(:x) +y = Variable(:y) +ex = Apply(λ(:x, λ(:y, Apply(x, y))), y) +g = EGraph(ex) +saturate!(g, λT) +extract!(g, astsize) + + +# Unfortunately, the above does not correctly perform $\alpha$-conversion. To do +# so we need to keep track of free and bound variables in each eclass. +# Essentially, we want to add a rule to our theory which reads: +# +# ```julia +# Let(v1, e, λ(v2, body)) => if isfree(_egraph,e,v2) +# fresh = freshvar() +# λ(fresh, Let(v1, e, Let(v2, Variable(fresh), body))) +# else +# λ(v2, Let(v1, e, body)) +# end +# ``` +# +# > Recently, a much better way to represent languages with bound variables with +# > [*slotted E-Graphs*](https://pldi24.sigplan.org/details/egraphs-2024-papers/10/Slotted-E-Graphs) +# > has been proposed. They make bound variables a built in feature of the e-graph. +# +# In the more basic implementation here we just want to be able to check if a variable is free: + +function isfree(g::EGraph, eclass, var) + @assert length(var.nodes) == 1 + var_sym = get_constant(g, v_head(var.nodes[1])) + @assert var_sym isa Symbol + var_sym ∈ getdata(eclass) +end + +# This can be done via a `LambdaAnalysis` datastructure which we can include in an +# `EClass`. We overload Egraphs.make such that whenever we add a new enode to +# the egraph we keep track of the free variables. + +const LambdaAnalysis = Set{Symbol} + +getdata(eclass) = eclass.data + +function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr) where {ExprType} + v_isexpr(n) || return LambdaAnalysis() + if v_iscall(n) + h = v_head(n) + op = get_constant(g, h) + args = v_children(n) + eclass = g[args[1]] + free = copy(getdata(eclass)) + + if op == Variable + push!(free, get_constant(g, v_head(eclass.nodes[1]))) + elseif op == Let + v, a, b = args[1:3] # v=a in b + vclass = g[v] + vsy = get_constant(g, v_head(vclass.nodes[1])) + adata = getdata(g[a]) + bdata = getdata(g[b]) + union!(free, bdata) + delete!(free, vsy) + union!(free, adata) + elseif op == λ + v, b = args[1:2] + vclass = g[v] + vsy = get_constant(g, v_head(vclass.nodes[1])) + bdata = getdata(g[b]) + union!(free, bdata) + delete!(free, vsy) + elseif op == Apply + l, v = args[1:2] + ldata = getdata(g[l]) + vdata = getdata(g[v]) + union!(free, ldata) + union!(free, vdata) + end + return free + end +end + +function EGraphs.join(from::LambdaAnalysis, to::LambdaAnalysis) + if issubset(from, to) # includes case from==to + from + elseif issubset(to, from) + to + else + error("inconsistent free variable sets from: $from to: $to") + end +end + +function fresh_var_generator() + idx = 0 + function generate() + idx += 1 + chars = collect(string(idx)) + subs = map(digit -> Char(Int(digit) + Int('₀') - Int('0')), chars) + Symbol("a$(String(subs))") + end +end + +freshvar = fresh_var_generator() + +# The final ruleset then looks like below and correctly renames variables when needed: + +λT = @theory v e c v1 v2 a b body begin + # let(v,e,body) means let v = e in body + Let(v, e, c::Any) --> c + Let(v1, e, Variable(v1)) --> e + Let(v1, e, Variable(v2)) => v1 == v2 ? e : Variable(v2) + Let(v1, e, λ(v1, body)) --> λ(v1, body) + Apply(λ(v, body), e) --> Let(v, e, body) + Let(v, e, Apply(a, b)) --> Apply(Let(v, e, a), Let(v, e, b)) + Let(v1, e, λ(v2, body)) => if isfree(_egraph, e, v2) + fresh = freshvar() + λ(fresh, Let(v1, e, Let(v2, Variable(fresh), body))) + else + λ(v2, Let(v1, e, body)) + end +end + +x = Variable(:x) +y = Variable(:y) +ex = Apply(λ(:x, λ(:y, Apply(x, y))), y) +g = EGraph{LambdaExpr,LambdaAnalysis}(ex) +params = SaturationParams( + timer = false, + check_memo = true, + check_analysis = true +) +saturate!(g, λT, params) +@test λ(:a₄, Apply(y, Variable(:a₄))) == extract!(g, astsize) +@test Set([:y]) == g[g.root].data + + +# With the above we can implement, for example, Church numerals. + +s = Variable(:s) +z = Variable(:z) +n = Variable(:n) +zero = λ(:s, λ(:z, z)) +one = λ(:s, λ(:z, Apply(s, z))) +two = λ(:s, λ(:z, Apply(s, Apply(s, z)))) +suc = λ(:n, λ(:x, λ(:y, Apply(x, Apply(Apply(n, x), y))))) + +# Compute the successor of `one`: + +freshvar = fresh_var_generator() +g = EGraph{LambdaExpr,LambdaAnalysis}(Apply(suc, one)) +params = SaturationParams( + timeout = 20, + scheduler = Schedulers.BackoffScheduler, + schedulerparams = (match_limit = 6000, ban_length = 5), + timer = false, + check_memo = true, + check_analysis = true +) +saturate!(g, λT, params) +two_ = extract!(g, astsize) +@test two_ == λ(:x, λ(:y, Apply(Variable(:x), Apply(Variable(:x), Variable(:y))))) +@test g[g.root].data == Set([]) +two_ + +# which is the same as `two` up to $\alpha$-conversion: + +two + +# check semantic analysis for free variables +function test_free_variable_analysis(expr, free) + g = EGraph{LambdaExpr,LambdaAnalysis}(expr) + g[g.root].data == free +end + +@test test_free_variable_analysis(Variable(:x), Set([:x])) +@test test_free_variable_analysis(Apply(Variable(:x), Variable(:y)), Set([:x, :y])) +@test test_free_variable_analysis(λ(:z, Variable(:x)), Set([:x])) +@test test_free_variable_analysis(λ(:z, Variable(:z)), Set{Symbol}()) +@test test_free_variable_analysis(λ(:z, λ(:x, Variable(:x))), Set{Symbol}()) + +let_expr = Let(:x, Variable(:z), λ(:x, Variable(:y))) +@test test_free_variable_analysis(let_expr, Set([:z, :y])) +# after saturation the expression becomes λ(:x, Variable(:y)) where only :y is left as free variable +freshvar = fresh_var_generator() +g = EGraph{LambdaExpr,LambdaAnalysis}(let_expr) +saturate!(g, λT, params) +@test extract!(g, astsize) == λ(:x, Variable(:y)) +@test g[g.root].data == Set([:y]) \ No newline at end of file diff --git a/test/tutorials/mu.jl b/test/tutorials/mu.jl index 8fd4ed67..995dd655 100644 --- a/test/tutorials/mu.jl +++ b/test/tutorials/mu.jl @@ -3,30 +3,48 @@ # by repeatedly applying the given rules. In other words, MU is not a theorem of # the MIU formal system. To prove this, one must step "outside" the formal system # itself. [Wikipedia](https://en.wikipedia.org/wiki/MU_puzzle#Solution) +# using Metatheory, Test -# Here are the axioms of MU: +include("../../examples/prove.jl") + +# +# Original source: Douglas Hofstadter: Gödel, Escher, Bach: An Eternal Golden Braid, 1999, pp 42-43 +# Rule 1: If you possess a string whose last letter is I, you can add a U at the end. +# Rule 2: Suppose you have Mx. Then you may add Mxx to your collection. +# Rule 3: If III occurs in one of the strings in your collection, you may make a +# new string with U in place of III. +# Rule 4: If UU occurs in one of your strings, you can drop it. + +# Here are the axioms of MU for equality saturation: # * Composition of the string monoid is associative -# * Add a uf to the end of any string ending in I +# * Add a U to the end of any string ending in I # * Double the string after the M # * Replace any III with a U # * Remove any UU +# We enforce an :END symbol, so that we do not need to handle the empty chain in UU --> \eps. function ⋅ end miu = @theory x y z begin - x ⋅ (y ⋅ z) --> (x ⋅ y) ⋅ z - x ⋅ :I ⋅ :END --> x ⋅ :I ⋅ :U ⋅ :END + (x ⋅ y) ⋅ z == x ⋅ (y ⋅ z) + :I ⋅ :END --> :I ⋅ :U ⋅ :END :M ⋅ x ⋅ :END --> :M ⋅ x ⋅ x ⋅ :END :I ⋅ :I ⋅ :I --> :U - x ⋅ :U ⋅ :U ⋅ y --> x ⋅ y + :U ⋅ :U ⋅ y --> y end # No matter the timeout we set here, # MU is not a theorem of the MIU system -params = SaturationParams(timeout = 12, eclasslimit = 8000) +params = SaturationParams(timeout = 20, eclasslimit = 20000) start = :(M ⋅ I ⋅ END) -g = EGraph(start) -saturate!(g, miu) -@test false == areequal(g, miu, start, :(M ⋅ U ⋅ END); params = params) +@test false == test_equality(miu, start, :(M ⋅ U ⋅ END); params) +# Examples given in Douglas Hofstadter: Gödel, Escher, Bach: An Eternal Golden Braid, 1999, page 44 +@test true == test_equality(miu, start, :(M ⋅ I ⋅ END); params) # (1) inital axiom +@test true == test_equality(miu, start, :(M ⋅ I ⋅ I ⋅ END); params) # (2) from (1) by Rule 2 +@test true == test_equality(miu, start, :(M ⋅ I ⋅ I ⋅ I ⋅ I ⋅ END); params) # (3) from (2) by Rule 2 [this is incorrectly given as MIII in the book] +@test true == test_equality(miu, start, :(M ⋅ I ⋅ I ⋅ I ⋅ I ⋅ U ⋅ END); params) # (4) from (3) by Rule 1 +@test true == test_equality(miu, start, :(M ⋅ U ⋅ I ⋅ U ⋅ END); params) # (5) from (4) by Rule 3 +@test true == test_equality(miu, start, :(M ⋅ U ⋅ I ⋅ U ⋅ U ⋅ I ⋅ U ⋅ END); params) # (6) from (5) by Rule 2 +@test true == test_equality(miu, start, :(M ⋅ U ⋅ I ⋅ I ⋅ U ⋅ END); params) # (7) from (6) by Rule 4 diff --git a/test/tutorials/propositional_logic.jl b/test/tutorials/propositional_logic.jl index 05367064..23722f0f 100644 --- a/test/tutorials/propositional_logic.jl +++ b/test/tutorials/propositional_logic.jl @@ -1,29 +1,31 @@ # Proving Propositional Logic Statements -using Test -using Metatheory -using TermInterface +using Metatheory, Test +include(joinpath(dirname(pathof(Metatheory)), "../examples/prove.jl")) include(joinpath(dirname(pathof(Metatheory)), "../examples/propositional_logic_theory.jl")) @testset "Prop logic" begin ex = rewrite(:(((p ⟹ q) && (r ⟹ s) && (p || r)) ⟹ (q || s)), impl) - @test prove(propositional_logic_theory, ex, 5, 10, 5000) + @test prove(propositional_logic_theory, ex, 5, 10) - @test @areequal propositional_logic_theory true ((!p == p) == false) - @test @areequal propositional_logic_theory true ((!p == !p) == true) - @test @areequal propositional_logic_theory true ((!p || !p) == !p) (!p || p) !(!p && p) - @test @areequal propositional_logic_theory p (p || p) - @test @areequal propositional_logic_theory true ((p ⟹ (p || p))) - @test @areequal propositional_logic_theory true ((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true + @test prove(propositional_logic_theory, :((!p == p) == false)) + @test prove(propositional_logic_theory, :((!p == !p) == true)) + @test test_equality(propositional_logic_theory, :((!p || !p) == !p), :(!p || p), :(!(!p && p))) + @test prove(propositional_logic_theory, :((p || p) == p)) + @test prove(propositional_logic_theory, :((p ⟹ (p || p)))) + @test prove(propositional_logic_theory, :((p ⟹ (p || p)) == ((!(p) && q) ⟹ q))) - # Frege's theorem - @test @areequal propositional_logic_theory true (p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r)) + @test prove(propositional_logic_theory, :((p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))))# Frege's theorem - # Demorgan's - @test @areequal propositional_logic_theory true (!(p || q) == (!p && !q)) - - # Consensus theorem - # @test_broken @areequal propositional_logic_theory true ((x && y) || (!x && z) || (y && z)) ((x && y) || (!x && z)) + @test prove(propositional_logic_theory, :(!(p || q) == (!p && !q))) # Demorgan's end + +# Consensus theorem +# @test_broken test_equality( +# propositional_logic_theory, +# :((x && y) || (!x && z) || (y && z)), +# :((x && y) || (!x && z)), +# true, +# ) diff --git a/test/tutorials/while_interpreter.jl b/test/tutorials/while_interpreter.jl index 72c9b459..6c8c6b46 100644 --- a/test/tutorials/while_interpreter.jl +++ b/test/tutorials/while_interpreter.jl @@ -40,7 +40,8 @@ using Test, Metatheory # For example, if a `σ::Mem` holds the value `σ[:a] = 2`, this means that at that given moment, in our program # the variable `a` holds the value 2. -Mem = Dict{Symbol,Union{Bool,Int}} +const WhileLangValue = Union{Bool,Int} +Mem = Dict{Symbol,WhileLangValue} # We are now ready to define our first rewrite rule. # In WHILE, un-evaluated expressions are represented by a tuple of `(program, state)`. @@ -159,7 +160,7 @@ eval_bool(ex, mem) = strategy(bool_rules)(:($ex, $mem)) eval_bool(:((false || false) || !(false || false)), Mem(:x => 2)) == true eval_bool(:((2 < 3) && (3 < 4)), Mem(:x => 2)) == true eval_bool(:((2 < x) || !(3 < 4)), Mem(:x => 2)) == false - eval_bool(:((2 < x) || !(3 < 4)), Mem(:x => 4)) == true + eval_bool(:((2 < x)), Mem(:x => 4)) == true ], ) @@ -202,8 +203,10 @@ end # `store(a, 5)` will store the value 5 in the `a` variable inside the program's memory. write_mem = @theory sym val σ begin - (store(sym::Symbol, val), σ) => (σ[sym] = eval_if(val, σ); - σ) + (store(sym::Symbol, val), σ) => begin + σ[sym] = eval_if(val, σ) + σ + end end # ## While loops and sequential computation. @@ -213,7 +216,7 @@ while_rules = @theory guard a b σ begin ((:skip; b), σ::Mem) --> (b, σ) (seq(a, b), σ::Mem) --> (b, merge((a, σ), σ)) merge(a::Mem, σ::Mem) => merge(σ, a) - merge(a::Union{Bool,Int}, σ::Mem) --> σ + merge(a::WhileLangValue, σ::Mem) --> σ (loop(guard, a), σ::Mem) --> (cond(guard, seq(a, loop(guard, a)), :skip), σ) end diff --git a/test/unit/rules.jl b/test/unit/rules.jl new file mode 100644 index 00000000..8989e4a2 --- /dev/null +++ b/test/unit/rules.jl @@ -0,0 +1,48 @@ +using Metatheory, Test + +@testset "Fully Qualified Function names" begin + r = @rule Main.identity(~a) --> ~a + + @test operation(r.left) == identity + @test r.right == PatVar(:a, 1) + + expr = :(Main.test(11, 12)) + rule = @rule Main.test(~a, ~b) --> ~b + @test rule(expr) == 12 +end + +@testset begin + r = @rule f(~x) --> ~x + + @test isempty(r.name) + + r = @rule "totti" f(~x) --> ~x + @test r.name == "totti" + @test operation(r.left) == :f + @test arguments(r.left) == [PatVar(:x, 1)] + @test r.right == PatVar(:x, 1) +end + + +@testset "String representation" begin + r = @rule f(~x) --> ~x + r == eval(:(@rule $(Meta.parse(repr(r))))) + + r = @rule Main.f(~~x) --> ~x + r == eval(:(@rule $(Meta.parse(repr(r))))) +end + + +@testset "EqualityRule to DirectedRule(s)" begin + r = @rule "distributive" x y z x*(y + z) == x*y + x*z + r_ltr = @rule "distributive" x y z x * (y + z) --> x*y + x*z + r_rtl = @rule "distributive" x y z x*y + x*z --> x * (y + z) + r1 = direct(r) + r2 = Metatheory.direct_right_to_left(r) + + @test r1 isa DirectedRule + @test r2 isa DirectedRule + @test repr(r1) == repr(r_ltr) + @test repr(r2) == repr(r_rtl) +end +