Skip to content

Commit

Permalink
Merge pull request #154 from JuliaSymbolics/s/fast-terms
Browse files Browse the repository at this point in the history
WIP: constructor-level simplification
  • Loading branch information
shashi authored Jan 9, 2021
2 parents 2461789 + c62beb7 commit 7d6f362
Show file tree
Hide file tree
Showing 11 changed files with 454 additions and 93 deletions.
6 changes: 6 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ using SymbolicUtils # hide

{{doc Term Term type}}

{{doc Add Add type}}

{{doc Mul Mul type}}

{{doc Pow Pow type}}

{{doc promote_symtype promote_symtype fn}}

## Interfacing
Expand Down
37 changes: 21 additions & 16 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ where appropriate -->

The main features are:

- Symbols (`Sym`s) carry type information. ([read more](#symbolic_expressions))
- Compound expressions composed of `Sym`s propagate type information. ([read more](#symbolic_expressions))
- A flexible [rule-based rewriting language](#rule-based_rewriting) allowing liberal use of user defined matchers and rewriters.
- Fast expressions
- A [combinator library](#composing-rewriters) for making rewriters.
- A [rule-based rewriting language](#rule-based_rewriting).
- Type promotion:
- Symbols (`Sym`s) carry type information. ([read more](#symbolic_expressions))
- Compound expressions composed of `Sym`s propagate type information. ([read more](#symbolic_expressions))
- Set of [simplification rules](#simplification). These can be remixed and extended for special purposes.


## Table of contents

\tableofcontents <!-- you can use \toc as well -->

## Symbolic expressions
## `Sym`s

First, let's use the `@syms` macro to create a few symbols.

Expand Down Expand Up @@ -66,17 +68,6 @@ expr1 + expr2
```
\out{expr}

### Simplified printing

Tip: you can set `SymbolicUtils.show_simplified[] = true` to enable simplification on printing, or call `SymbolicUtils.showraw(expr)` to display an expression without simplification.
In the REPL, if an expression was successfully simplified before printing, it will appear in yellow rather than white, as a visual cue that what you are looking at is not the exact datastructure.

```julia:showraw
using SymbolicUtils: showraw
showraw(expr1 + expr2)
```
\out{showraw}

**Function-like symbols**

Expand Down Expand Up @@ -106,6 +97,20 @@ g(2//5, g(1, β))

This works because `g` "returns" a `Real`.


## Expression interface

Symbolic expressions are of type `Term{T}`, `Add{T}`, `Mul{T}` or `Pow{T}` and denote some function call where one or more arguments are themselves such expressions or `Sym`s.

All the expression types support the following:

- `istree(x)` -- always returns `true` denoting, `x` is not a leaf node like Sym or a literal.
- `operation(x)` -- the function being called
- `arguments(x)` -- a vector of arguments
- `symtype(x)` -- the "inferred" type (`T`)

See more on the interface [here](/interface)

## Rule-based rewriting

Rewrite rules match and transform an expression. A rule is written using either the `@rule` macro or the `@acrule` macro.
Expand Down Expand Up @@ -151,7 +156,7 @@ Notice that there is a subexpression `(2 * w) + (2 * w)` that could be simplifie

### Predicates for matching

Matcher pattern may contain slot variables with attached predicates, written as `~x::f` where `f` is a function that takes a matched expression (a `Term` object a `Sym` or any Julia value that is in the expression tree) and returns a boolean value. Such a slot will be considered a match only if `f` returns true.
Matcher pattern may contain slot variables with attached predicates, written as `~x::f` where `f` is a function that takes a matched expression and returns a boolean value. Such a slot will be considered a match only if `f` returns true.

Similarly `~~x::g` is a way of attaching a predicate `g` to a segment variable. In the case of segment variables `g` gets a vector of 0 or more expressions and must return a boolean value. If the same slot or segment variable appears twice in the matcher pattern, then at most one of the occurance should have a predicate.

Expand Down
6 changes: 4 additions & 2 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ module SymbolicUtils

export @syms, term, showraw

# Sym, Term and other types
# Sym, Term,
# Add, Mul and Pow
using DataStructures
import Base: +, -, *, /, \, ^
include("types.jl")

# Methods on symbolic objects
Expand Down Expand Up @@ -32,7 +35,6 @@ include("matchers.jl")
# Convert to an efficient multi-variate polynomial representation
import AbstractAlgebra.Generic: MPoly, PolynomialRing, ZZ, exponent_vector
using AbstractAlgebra: ismonomial, symbols
using DataStructures
include("abstractalgebra.jl")

# Term ordering
Expand Down
55 changes: 26 additions & 29 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx,
besselj1, bessely0, bessely1, besselj, bessely, besseli,
besselk, hankelh1, hankelh2, polygamma, beta, logbeta

const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
const monadic = [deg2rad, rad2deg, transpose, conj, asind, log1p, acsch,
acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh,
abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, acotd,
asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10,
Expand All @@ -14,10 +14,9 @@ const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi,
airybiprime, besselj0, besselj1, bessely0, bessely1]

const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign,
const diadic = [max, min, hypot, atan, mod, rem, copysign,
besselj, bessely, besseli, besselk, hankelh1, hankelh2,
polygamma, beta, logbeta]

const previously_declared_for = Set([])

# TODO: it's not possible to dispatch on the symtype! (only problem is Parameter{})
Expand All @@ -32,13 +31,17 @@ end
islike(a, T) = symtype(a) <: T

# TODO: keep domains tighter than this
function number_methods(T, rhs1, rhs2)
function number_methods(T, rhs1, rhs2, options=nothing)
exprs = []

skip_basics = !isnothing(options) ? options == :skipbasics : false
basic_monadic = [-, +]
basic_diadic = [+, -, *, /, \, ^]

rhs2 = :($assert_like(f, Number, a, b); $rhs2)
rhs1 = :($assert_like(f, Number, a); $rhs1)

for f in diadic
for f in (skip_basics ? diadic : vcat(basic_diadic, diadic))
for S in previously_declared_for
push!(exprs, quote
(f::$(typeof(f)))(a::$T, b::$S) = $rhs2
Expand All @@ -58,25 +61,38 @@ function number_methods(T, rhs1, rhs2)
push!(exprs, expr)
end

for f in monadic
for f in (skip_basics ? monadic : vcat(basic_monadic, monadic))
push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1))
end
push!(exprs, :(push!($previously_declared_for, $T)))
Expr(:block, exprs...)
end

macro number_methods(T, rhs1, rhs2)
number_methods(T, rhs1, rhs2) |> esc
macro number_methods(T, rhs1, rhs2, options=nothing)
number_methods(T, rhs1, rhs2, options) |> esc
end

@number_methods(Sym, term(f, a), term(f, a, b))
@number_methods(Term, term(f, a), term(f, a, b))
@number_methods(Sym, term(f, a), term(f, a, b), skipbasics)
@number_methods(Term, term(f, a), term(f, a, b), skipbasics)
@number_methods(Add, term(f, a), term(f, a, b), skipbasics)
@number_methods(Mul, term(f, a), term(f, a, b), skipbasics)
@number_methods(Pow, term(f, a), term(f, a, b), skipbasics)

for f in diadic
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
end

for f in [+, -, *, \, /, ^]
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
end
for f in [+, -, *]
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T
end

promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T
Base.rem2pi(x::Symbolic, mode::Base.RoundingMode) = term(rem2pi, x, mode)

Expand All @@ -93,25 +109,6 @@ rec_promote_symtype(f, x) = promote_symtype(f, x)
rec_promote_symtype(f, x,y) = promote_symtype(f, x,y)
rec_promote_symtype(f, x,y,z...) = rec_promote_symtype(f, promote_symtype(f, x,y), z...)

# Variadic methods
for f in [+, *]

@eval (::$(typeof(f)))(x::Symbolic) = x

# single arg
@eval function (::$(typeof(f)))(x::Symbolic, w::Number...)
term($f, x,w...,
type=rec_promote_symtype($f, map(symtype, (x,w...))...))
end
@eval function (::$(typeof(f)))(x::Number, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
@eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
end

Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
Expand Down
19 changes: 10 additions & 9 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,20 @@ function (rw::Fixpoint)(x)
return x
end

struct Walk{ord, C, threaded}
struct Walk{ord, C, F, threaded}
rw::C
thread_cutoff::Int
similarterm::F
end

using .Threads

function Postwalk(rw; threaded::Bool=false, thread_cutoff=100)
Walk{:post, typeof(rw), threaded}(rw, thread_cutoff)
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
end

function Prewalk(rw; threaded::Bool=false, thread_cutoff=100)
Walk{:pre, typeof(rw), threaded}(rw, thread_cutoff)
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
end

struct PassThrough{C}
Expand All @@ -128,22 +129,22 @@ end
(p::PassThrough)(x) = (y=p.rw(x); isnothing(y) ? x : y)

passthrough(x, default) = isnothing(x) ? default : x
function (p::Walk{ord, C, false})(x) where {ord, C}
function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
@assert ord === :pre || ord === :post
if istree(x)
if ord === :pre
x = p.rw(x)
end
if istree(x)
x = similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
x = p.similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
end
return ord === :post ? p.rw(x) : x
else
return p.rw(x)
end
end

function (p::Walk{ord, C, true})(x) where {ord, C}
function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
@assert ord === :pre || ord === :post
if istree(x)
if ord === :pre
Expand All @@ -158,7 +159,7 @@ function (p::Walk{ord, C, true})(x) where {ord, C}
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
t = similarterm(x, operation(x), args)
t = p.similarterm(x, operation(x), args)
end
return ord === :post ? p.rw(t) : t
else
Expand Down
2 changes: 1 addition & 1 deletion src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ function (acr::ACRule)(term)
itr = acr.sets(eachindex(args), acr.arity)

for inds in itr
result = r(similarterm(term, f, @views args[inds]))
result = r(Term{T}(f, @views args[inds]))
if !isnothing(result)
# Assumption: inds are unique
length(args) == length(inds) && return result
Expand Down
Loading

0 comments on commit 7d6f362

Please sign in to comment.