Skip to content

Commit

Permalink
CSE only exprs that appear more than once.
Browse files Browse the repository at this point in the history
Co-authored-by: "Yingbo Ma" <[email protected]>
  • Loading branch information
shashi and YingboMa committed Sep 28, 2021
1 parent 1ba01aa commit 2c60e25
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
SpawnFetch, Multithreaded, cse

import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, Sym, Term, istree, operation, arguments,
symtype, similarterm
symtype, similarterm, unsorted_arguments, metadata

##== state management ==##

Expand Down Expand Up @@ -618,17 +619,15 @@ function _cse!(mem, expr)
end

function cse(expr)
!istree(expr) && return expr
assignments = Assignment[]
final = _cse!((assignments, Dict{Any,Int}()), expr)
Let(assignments, final)
state = Dict{Any, Int}()
cse_state!(state, expr)
cse_block(state, expr)
end


function _cse(exprs::AbstractArray)
assignments = Assignment[]
mem = (assignments, Dict{Any,Int}())
final = map(Base.Fix1(_cse!, mem), exprs)
assignments, final
letblock = cse(Term{Any}(tuple, exprs))
letblock.pairs, arguments(letblock.body)
end

function cse(x::MakeArray)
Expand All @@ -652,4 +651,37 @@ function cse(x::MakeSparseArray)
end
end


function cse_state!(state, t)
!istree(t) && return t
state[t] = Base.get!(state, t, 0) + 1
foreach(x->cse_state!(state, x), unsorted_arguments(t))
end

function cse_block!(assignments, counter, names, name, state, x)
if get(state, x, 0) > 1
if haskey(names, x)
return names[x]
else
sym = Sym{symtype(x)}(Symbol(name, counter[]))
names[x] = sym
push!(assignments, sym x)
counter[] += 1
return sym
end
else
return x
end
end

function cse_block(state, t, name=Symbol("var-", hash(t)))
assignments = Assignment[]
counter = Ref{Int}(1)
names = Dict{Any, Sym}()
t′ = Rewriters.Postwalk(x->cse_block!(assignments, counter, names, name, state, x),
similarterm = (t, f, args) -> similarterm(t, f, args, symtype(t),
metadata=metadata(t)))(t)
Let(assignments, t′)
end

end

0 comments on commit 2c60e25

Please sign in to comment.