Skip to content

Commit

Permalink
avoid fixpoint in unflatten_long_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi committed Feb 19, 2021
1 parent ae4f53c commit ddcba6c
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,19 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
_build_function(target,args...;kwargs...)
end

function unflatten_long_ops(op, N=4)
rule1 = @rule((+)((~~x)) => length(~~x) > N ?
+(+((~~x)[1:N]...) + (+)((~~x)[N+1:end]...)) : nothing)
rule2 = @rule((*)((~~x)) => length(~~x) > N ?
*(*((~~x)[1:N]...) * (*)((~~x)[N+1:end]...)) : nothing)
function unflatten_args(f, args, N=4)
length(args) < N && return Term{Number}(f, args)
unflatten_args(f, [Term{Number}(f, group)
for group in Iterators.partition(args, N)], N)
end

function unflatten_long_ops(op, N=4)
op = value(op)
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op)
!istree(op) && return
rule1 = @rule((+)(~~x) => length(~~x) > N ? unflatten_args(+, ~~x, 4) : nothing)
rule2 = @rule((*)(~~x) => length(~~x) > N ? unflatten_args(*, ~~x, 4) : nothing)

Rewriters.Postwalk(Rewriters.Chain([rule1, rule2]))(op)
end


Expand All @@ -76,7 +81,7 @@ function _build_function(target::JuliaTarget, op, args...;
linenumbers = true)

dargs = map(destructure_arg, [args...])
expr = toexpr(Func(dargs, [], op))
expr = toexpr(Func(dargs, [], unflatten_long_ops(op)))

if expression == Val{true}
expr
Expand Down Expand Up @@ -247,7 +252,7 @@ function _make_array(rhss::AbstractArray, similarto)
end
end

_make_array(x, similarto) = x
_make_array(x, similarto) = unflatten_long_ops(x)

## In-place version

Expand Down Expand Up @@ -298,7 +303,7 @@ function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros
end)
end

_set_array(out, outputidxs, rhs, checkbounds, skipzeros) = rhs
_set_array(out, outputidxs, rhs, checkbounds, skipzeros) = unflatten_long_ops(rhs)


function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict())
Expand Down

0 comments on commit ddcba6c

Please sign in to comment.