Skip to content

Commit

Permalink
Rewrite ^ with NaNMath.pow in nanmath-mode
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Nov 8, 2024
1 parent a587847 commit 1c38ee8
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicUtils"
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
authors = ["Shashi Gowda"]
version = "3.7.2"
version = "3.7.3"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
9 changes: 6 additions & 3 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,20 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
end
end

function function_to_expr(::typeof(^), O, st)
function function_to_expr(op::typeof(^), O, st)
args = arguments(O)
if length(args) == 2 && args[2] isa Real && args[2] < 0
ex = args[1]
if args[2] == -1
return toexpr(Term(inv, Any[ex]), st)
else
return toexpr(Term(^, Any[Term(inv, Any[ex]), -args[2]]), st)
args = Any[Term(inv, Any[ex]), -args[2]]
op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow
return toexpr(Term(op, args), st)
end
end
return nothing
get(st.rewrites, :nanmath, false) === true || return nothing
return toexpr(Term(NaNMath.pow, args), st)
end

function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
Expand Down
30 changes: 27 additions & 3 deletions test/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ nanmath_st.rewrites[:nanmath] = true
@test toexpr(a*b*c*d*e) == :($(*)($(*)($(*)($(*)(a, b), c), d), e))
@test toexpr(a+b+c+d+e) == :($(+)($(+)($(+)($(+)(a, b), c), d), e))
@test toexpr(a+b) == :($(+)(a, b))
@test toexpr(a^b) == :($(^)(a, b))
@test toexpr(a^2) == :($(^)(a, 2))
@test toexpr(a^-2) == :($(/)(1, $(^)(a, 2)))
@test toexpr(x(t)+y(t)) == :($(+)(x(t), y(t)))
@test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x(t), y(t)), x($(+)(1, t))))
s = LazyState()
Expand Down Expand Up @@ -87,8 +84,35 @@ nanmath_st.rewrites[:nanmath] = true
end)
@test toexpr(SetArray(true, a, [x(t), AtIndex(9, b), c])).head == :macrocall

for fname in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, :log1p, :sqrt)
f = getproperty(Base, fname)
@test toexpr(f(a)) == :($f(a))
@test toexpr(f(a), nanmath_st) == :($(GlobalRef(NaNMath, fname))(a))

nanmath_f = getproperty(NaNMath, fname)
@test toexpr(nanmath_f(a)) == :($nanmath_f(a))
@test toexpr(nanmath_f(a), nanmath_st) == :($nanmath_f(a))
end

@test toexpr(a^b) == :($(^)(a, b))
@test toexpr(a^b, nanmath_st) == :($(NaNMath.pow)(a, b))
@test toexpr(NaNMath.pow(a, b)) == :($(NaNMath.pow)(a, b))
@test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b))

@test toexpr(a^2) == :($(^)(a, 2))
@test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2))
@test toexpr(NaNMath.pow(a, 2)) == :($(NaNMath.pow)(a, 2))
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2))

@test toexpr(a^-1) == :($(/)(1, a))
@test toexpr(a^-1, nanmath_st) == :($(/)(1, a))
@test toexpr(NaNMath.pow(a, -1)) == :($(NaNMath.pow)(a, -1))
@test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(NaNMath.pow)(a, -1))

@test toexpr(a^-2) == :($(/)(1, $(^)(a, 2)))
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2)))
@test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)(a, -2))
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)(a, -2))

f = GlobalRef(NaNMath, :sin)
test_repr(toexpr(LiteralExpr(:(let x=1, y=2
Expand Down

0 comments on commit 1c38ee8

Please sign in to comment.