From b350b61368475259251b91dcb9c89c0914415f1b Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 7 May 2020 11:26:46 -0400 Subject: [PATCH 1/2] document context --- src/matchers.jl | 4 ++-- src/rule_dsl.jl | 22 ++++++++++++++++++---- src/simplify.jl | 32 ++++++++++++++++++++++++++------ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 6e277c1b0..e8486772f 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -29,8 +29,8 @@ Base.show(io::IO, s::Segment) = (print(io, "~~"); print(io, s.name)) makesegment(s::Symbol, keys) = (push!(keys, s); Segment(s)) """ -A wrapper for slot and segment predicates which allows them to -take two arguments: the value and a Context +A wrapper indicating that the function inside must be called with +2 arguments. An expression, and the current context. """ struct Contextual{F} f::F diff --git a/src/rule_dsl.jl b/src/rule_dsl.jl index dea9aac16..832fede1e 100644 --- a/src/rule_dsl.jl +++ b/src/rule_dsl.jl @@ -40,7 +40,7 @@ function (r::Rule)(term, ctx=EmptyCtx()) end """ - `@rule LHS => RHS` + @rule LHS => RHS Creates a `Rule` object. A rule object is callable, and takes an expression and rewrites it if it matches the LHS pattern to the RHS pattern, returns `nothing` otherwise. @@ -141,6 +141,19 @@ sin((a + c)) ``` Predicate function gets an array of values if attached to a segment variable (`~~x`). + +**Context**: + +_In predicates_: Contextual predicates are functions wrapped in the `Contextual` type. +The function is called with 2 arguments: the expression and a context object +passed during a call to the Rule object (maybe done by passing a context to `simplify` or +a `RuleSet` object). + +The function can use the inputs however it wants, and must return a boolean indicating +whether the predicate holds or not. + +_In the consequent pattern_: Use `(@ctx)` to access the context object on the right hand side +of an expression. """ macro rule(expr) @assert expr.head == :call && expr.args[1] == :(=>) @@ -206,11 +219,12 @@ end #### Rulesets """ - RuleSet(rules::Vector{AbstractRules})(expr; depth=typemax(Int), applyall=false, recurse=true) + RuleSet(rules::Vector{AbstractRules}, context=EmptyCtx())(expr; depth=typemax(Int), applyall=false, recurse=true) -`RuleSet` is an `AbstractRule` which applies the given `rules` throughout an `expr`. +`RuleSet` is an `AbstractRule` which applies the given `rules` throughout an `expr` with the +context `context`. -`RuleSet(rules)(expr)` Note that this only applies the rules in one pass, not until there are no +Note that this only applies the rules in one pass, not until there are no changes to be applied. Use `SymbolicUtils.fixpoint(ruleset, expr)` to apply a RuleSet until there are no changes. diff --git a/src/simplify.jl b/src/simplify.jl index d3667ebf9..78b7f1238 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -1,16 +1,36 @@ ##### Numeric simplification -default_rules(::EmptyCtx) = SIMPLIFY_RULES """ - simplify(x, [rules=SIMPLIFY_RULES]; fixpoint=true, applyall=true, recurse=true) + default_rules(context::T)::RuleSet -Apply a `RuleSet` of rules provided in `rules`. By default -these rules are `SymbolicUtils.SIMPLIFY_RULES`. If `fixpoint=true` -repeatedly applies the set of rules until there are no changes. +The `RuleSet` to be used by default for a given context. Julia packages +defining their own context types should define this method. + +By default, returns SIMPLIFY_RULES +""" +default_rules(::Any) = SIMPLIFY_RULES + +""" + simplify(x, ctx=EmptyCtx(); + rules=default_rules(ctx), + fixpoint=true, + applyall=true, + recurse=true) + +Simplify an expression by applying `rules` until there are no changes. +The second argument, the context is passed to every [`Contextual`](#Contextual) +predicate and can be accessed as `(@ctx)` in the right hand side of `@rule` expression. + +By default the context is an `EmptyCtx()` -- which means there is no contextual information. +Any arbitrary type can be used as a context, and packages defining their own contexts +should define `default_rules(ctx::TheContextType)` to return a `RuleSet` that will +be used by default while simplifying under that context. + +If `fixpoint=true` this will repeatedly apply the set of rules until there are no changes. Applies them once if `fixpoint=false`. The `applyall` and `recurse` keywords are forwarded to the enclosed -`RuleSet`. +`RuleSet`, they are mainly used for internal optimization. """ function simplify(x, ctx=EmptyCtx(); rules=default_rules(ctx), fixpoint=true, applyall=true, recurse=true) if fixpoint From a49ed2aa991cf19d623d3d540f01db9b7debfaf3 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 7 May 2020 11:49:19 -0400 Subject: [PATCH 2/2] also use expression to pick default ruleset --- src/simplify.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/simplify.jl b/src/simplify.jl index 78b7f1238..6de3b6aeb 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -1,18 +1,19 @@ ##### Numeric simplification """ - default_rules(context::T)::RuleSet + default_rules(expr, context::T)::RuleSet -The `RuleSet` to be used by default for a given context. Julia packages -defining their own context types should define this method. +The `RuleSet` to be used by default for a given expression and the context. +Julia packages defining their own context types should define this method. -By default, returns SIMPLIFY_RULES +By default SymbolicUtils will try to apply appropriate rules for expressions +of symtype Number. """ -default_rules(::Any) = SIMPLIFY_RULES +default_rules(x, ctx) = SIMPLIFY_RULES """ simplify(x, ctx=EmptyCtx(); - rules=default_rules(ctx), + rules=default_rules(x, ctx), fixpoint=true, applyall=true, recurse=true) @@ -32,7 +33,7 @@ Applies them once if `fixpoint=false`. The `applyall` and `recurse` keywords are forwarded to the enclosed `RuleSet`, they are mainly used for internal optimization. """ -function simplify(x, ctx=EmptyCtx(); rules=default_rules(ctx), fixpoint=true, applyall=true, recurse=true) +function simplify(x, ctx=EmptyCtx(); rules=default_rules(x, ctx), fixpoint=true, applyall=true, recurse=true) if fixpoint SymbolicUtils.fixpoint(rules, x, ctx; recurse=recurse, applyall=recurse) else