Skip to content

Commit

Permalink
Fix unflatten_long_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Feb 19, 2021
1 parent ddcba6c commit 7d5839d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,18 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
end

function unflatten_args(f, args, N=4)
length(args) < N && return Term{Number}(f, args)
unflatten_args(f, [Term{Number}(f, group)
length(args) < N && return Term{Real}(f, args)
unflatten_args(f, [Term{Real}(f, group)
for group in Iterators.partition(args, N)], N)
end

function unflatten_long_ops(op, N=4)
op = value(op)
!istree(op) && return
!istree(op) && return Num(op)
rule1 = @rule((+)(~~x) => length(~~x) > N ? unflatten_args(+, ~~x, 4) : nothing)
rule2 = @rule((*)(~~x) => length(~~x) > N ? unflatten_args(*, ~~x, 4) : nothing)

Rewriters.Postwalk(Rewriters.Chain([rule1, rule2]))(op)
Num(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2]))(op))
end


Expand Down

0 comments on commit 7d5839d

Please sign in to comment.