From 1c38ee8ad9c93d5bdc1d66b7d24805edf4324d47 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 8 Nov 2024 09:49:38 +0100 Subject: [PATCH] Rewrite `^` with `NaNMath.pow` in nanmath-mode --- Project.toml | 2 +- src/code.jl | 9 ++++++--- test/code.jl | 30 +++++++++++++++++++++++++++--- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index ec3162512..db1566020 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicUtils" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" authors = ["Shashi Gowda"] -version = "3.7.2" +version = "3.7.3" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/code.jl b/src/code.jl index 4128a39fd..8d3b9599e 100644 --- a/src/code.jl +++ b/src/code.jl @@ -138,17 +138,20 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st) end end -function function_to_expr(::typeof(^), O, st) +function function_to_expr(op::typeof(^), O, st) args = arguments(O) if length(args) == 2 && args[2] isa Real && args[2] < 0 ex = args[1] if args[2] == -1 return toexpr(Term(inv, Any[ex]), st) else - return toexpr(Term(^, Any[Term(inv, Any[ex]), -args[2]]), st) + args = Any[Term(inv, Any[ex]), -args[2]] + op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow + return toexpr(Term(op, args), st) end end - return nothing + get(st.rewrites, :nanmath, false) === true || return nothing + return toexpr(Term(NaNMath.pow, args), st) end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) diff --git a/test/code.jl b/test/code.jl index c05200167..0e25437a1 100644 --- a/test/code.jl +++ b/test/code.jl @@ -20,9 +20,6 @@ nanmath_st.rewrites[:nanmath] = true @test toexpr(a*b*c*d*e) == :($(*)($(*)($(*)($(*)(a, b), c), d), e)) @test toexpr(a+b+c+d+e) == :($(+)($(+)($(+)($(+)(a, b), c), d), e)) @test toexpr(a+b) == :($(+)(a, b)) - @test toexpr(a^b) == :($(^)(a, b)) - @test toexpr(a^2) == :($(^)(a, 2)) - @test toexpr(a^-2) == :($(/)(1, $(^)(a, 2))) @test toexpr(x(t)+y(t)) == :($(+)(x(t), y(t))) @test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x(t), y(t)), x($(+)(1, t)))) s = LazyState() @@ -87,8 +84,35 @@ nanmath_st.rewrites[:nanmath] = true end) @test toexpr(SetArray(true, a, [x(t), AtIndex(9, b), c])).head == :macrocall + for fname in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, :log1p, :sqrt) + f = getproperty(Base, fname) + @test toexpr(f(a)) == :($f(a)) + @test toexpr(f(a), nanmath_st) == :($(GlobalRef(NaNMath, fname))(a)) + nanmath_f = getproperty(NaNMath, fname) + @test toexpr(nanmath_f(a)) == :($nanmath_f(a)) + @test toexpr(nanmath_f(a), nanmath_st) == :($nanmath_f(a)) + end + + @test toexpr(a^b) == :($(^)(a, b)) + @test toexpr(a^b, nanmath_st) == :($(NaNMath.pow)(a, b)) @test toexpr(NaNMath.pow(a, b)) == :($(NaNMath.pow)(a, b)) + @test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b)) + + @test toexpr(a^2) == :($(^)(a, 2)) + @test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2)) + @test toexpr(NaNMath.pow(a, 2)) == :($(NaNMath.pow)(a, 2)) + @test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2)) + + @test toexpr(a^-1) == :($(/)(1, a)) + @test toexpr(a^-1, nanmath_st) == :($(/)(1, a)) + @test toexpr(NaNMath.pow(a, -1)) == :($(NaNMath.pow)(a, -1)) + @test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(NaNMath.pow)(a, -1)) + + @test toexpr(a^-2) == :($(/)(1, $(^)(a, 2))) + @test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2))) + @test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)(a, -2)) + @test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)(a, -2)) f = GlobalRef(NaNMath, :sin) test_repr(toexpr(LiteralExpr(:(let x=1, y=2