Skip to content

Commit

Permalink
refactor: don't use NaNMath.pow in codegen rewriters if integral ex…
Browse files Browse the repository at this point in the history
…ponent
  • Loading branch information
AayushSabharwal committed Jan 13, 2025
1 parent da3bd6d commit 194aa85
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
18 changes: 7 additions & 11 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,14 @@ end

function function_to_expr(op::typeof(^), O, st)
args = arguments(O)
if length(args) == 2 && args[2] isa Real && args[2] < 0
ex = args[1]
if args[2] == -1
return toexpr(Term(inv, Any[ex]), st)
else
args = Any[Term(inv, Any[ex]), -args[2]]
op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow
return toexpr(Term(op, args), st)
end
if args[2] isa Real && args[2] < 0
args[1] = Term(inv, Any[args[1]])
args[2] = -args[2]
end
if get(st.rewrites, :nanmath, false) === true && !(args[2] isa Integer)
op = NaNMath.pow
end
get(st.rewrites, :nanmath, false) === true || return nothing
return toexpr(Term(NaNMath.pow, args), st)
return toexpr(Term(op, args), st)
end

function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
Expand Down
10 changes: 5 additions & 5 deletions test/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,19 @@ nanmath_st.rewrites[:nanmath] = true
@test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b))

@test toexpr(a^2) == :($(^)(a, 2))
@test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2))
@test toexpr(a^2, nanmath_st) == :($(^)(a, 2))
@test toexpr(NaNMath.pow(a, 2)) == :($(^)(a, 2))
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2))
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(^)(a, 2))

@test toexpr(a^-1) == :($(/)(1, a))
@test toexpr(a^-1, nanmath_st) == :($(/)(1, a))
@test toexpr(NaNMath.pow(a, -1)) == :($(inv)(a))
@test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(inv)(a))

@test toexpr(a^-2) == :($(/)(1, $(^)(a, 2)))
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2)))
@test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)($(inv)(a), 2))
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)($(inv)(a), 2))
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(^)(a, 2)))
@test toexpr(NaNMath.pow(a, -2)) == :($(^)($(inv)(a), 2))
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(^)($(inv)(a), 2))

f = GlobalRef(NaNMath, :sin)
test_repr(toexpr(LiteralExpr(:(let x=1, y=2
Expand Down

0 comments on commit 194aa85

Please sign in to comment.