Skip to content

Commit

Permalink
Enables full (nested) associative-commutative pattern matching for `+…
Browse files Browse the repository at this point in the history
…` and `*` operators.
  • Loading branch information
zengmao committed Aug 29, 2024
1 parent aab293a commit 0a380ed
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 8 deletions.
34 changes: 28 additions & 6 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
# 3. Callback: takes arguments Dictionary × Number of elements matched
#
function matcher(val::Any)
iscall(val) && return term_matcher(val)
matcher(val, false)
end

# `fullac_flag == true` enables fully nested associative-commutative pattern matching
function matcher(val::Any, fullac_flag)
iscall(val) && return term_matcher(val, fullac_flag)
function literal_matcher(next, data, bindings)
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
end
end

function matcher(slot::Slot)
function matcher(slot::Slot, fullac_flag) # fullac_flag unused but needed to keep the interface uniform
function slot_matcher(next, data, bindings)
!islist(data) && return
val = get(bindings, slot.name, nothing)
Expand Down Expand Up @@ -56,7 +61,7 @@ function trymatchexpr(data, value, n)
end
end

function matcher(segment::Segment)
function matcher(segment::Segment, fullac_flag) # fullac_flag unused but needed to keep the interface uniform
function segment_matcher(success, data, bindings)
val = get(bindings, segment.name, nothing)

Expand Down Expand Up @@ -84,8 +89,8 @@ function matcher(segment::Segment)
end
end

function term_matcher(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
function term_matcher(term, fullac_flag = false)
matchers = (matcher(operation(term), fullac_flag), map(a -> matcher(a, fullac_flag), arguments(term))...,)
function term_matcher(success, data, bindings)

!islist(data) && return nothing
Expand All @@ -103,6 +108,23 @@ function term_matcher(term)
end
end

loop(car(data), bindings, matchers) # Try to eat exactly one term
if !(fullac_flag && iscall(term) && operation(term) in ((+), (*)))
loop(car(data), bindings, matchers) # Try to eat exactly one term
else # try all permutations of `car(data)` to see if a match is possible
data1 = car(data)
args = arguments(data1)
op = operation(data1)
data_arg_perms = permutations(args)
result = nothing
T = symtype(data)
for perm in data_arg_perms
data_permuted = Term{T}(op, perm)
result = loop(data_permuted, bindings, matchers) # Try to eat exactly one term
if !(result isa Nothing)
break
end
end
return result
end
end
end
25 changes: 23 additions & 2 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,30 @@ whether the predicate holds or not.
_In the consequent pattern_: Use `(@ctx)` to access the context object on the right hand side
of an expression.
**Full (nested) associative-commutative matching**:
@rule LHS => RHS fullac
creates a rule that fully respects associative-commutative (AC) operations. Unlike `@acrule LHS => RHS` which only considers AC properties of the top-level function, here we impose AC properties on all subexpressions.
```
julia> @syms a b;
julia> r = @rule ~a + ~a*~b => ~a * (1+~b) fullac;
julia> r(b + a*b)
(1 + a)*b
```
"""
macro rule(expr)
macro rule(expr, option...)
@assert expr.head == :call && expr.args[1] == :(=>)
fullac = false
if length(option) > 0
@assert option[1] == :fullac "@rule only accepts one option `fullac` after the rule itself"
fullac = true
end
lhs = expr.args[2]
rhs = rewrite_rhs(expr.args[3])
keys = Symbol[]
Expand All @@ -310,7 +331,7 @@ macro rule(expr)
lhs_pattern = $(lhs_term)
Rule($(QuoteNode(expr)),
lhs_pattern,
matcher(lhs_pattern),
matcher(lhs_pattern, $fullac),
__MATCHES__ -> $(makeconsequent(rhs)),
rule_depth($lhs_term))
end
Expand Down
8 changes: 8 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ end
@eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, [])
end

@testset "Full associative-commutative matching" begin
@eqtest (@rule ~a + ~a*~b => ~a * (1+~b) fullac)(a + a*b) == a * (1+b)
@eqtest (@rule ~a + ~a*~b => ~a * (1+~b) fullac)(b + a*b) == b * (1+a) # fails with @acrule
@eqtest (@rule ~a*~b + ~a => ~a * (1+~b) fullac)(b + a*b) == b * (1+a) # fails with @acrule
@eqtest (@rule ~a*~b + ~a*~c => ~a * (~b+~c) fullac)(a*b + a*c) == a * (b+c)
@eqtest (@rule ~a*~b + ~a*~c => ~a * (~b+~c) fullac)(a*b + b*c) == b * (a+c) # fails with @acrule
end

using SymbolicUtils: @capture

@testset "Capture form" begin
Expand Down

0 comments on commit 0a380ed

Please sign in to comment.