From 43262130b4297f79e5dd15dc17ced4c5ead94aab Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 31 Jul 2023 18:12:13 -0400 Subject: [PATCH 1/3] Register NaNMath functions --- src/methods.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/methods.jl b/src/methods.jl index bdd96512b..17b21b39b 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -1,3 +1,4 @@ +import NaNMath import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx, dawson, digamma, trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi, airybiprime, besselj0, @@ -12,9 +13,12 @@ const monadic = [deg2rad, rad2deg, transpose, asind, log1p, acsch, atand, sec, acscd, cot, exp2, expm1, atanh, gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx, dawson, digamma, trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi, - airybiprime, besselj0, besselj1, bessely0, bessely1, isfinite] + airybiprime, besselj0, besselj1, bessely0, bessely1, isfinite, + NaNMath.sin, NaNMath.cos, NaNMath.tan, NaNMath.asin, NaNMath.acos, + NaNMath.acosh, NaNMath.atanh, NaNMath.log, NaNMath.log2, + NaNMath.log10, NaNMath.lgamma, NaNMath.log1p, NaNMath.sqrt] -const diadic = [max, min, hypot, atan, mod, rem, copysign, +const diadic = [max, min, hypot, atan, NaNMath.atanh, mod, rem, copysign, besselj, bessely, besseli, besselk, hankelh1, hankelh2, polygamma, beta, logbeta] const previously_declared_for = Set([]) From 1189564c8b25235ad216d1d2572ec110549d8c53 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 31 Jul 2023 18:20:33 -0400 Subject: [PATCH 2/3] Fix tests --- test/code.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/code.jl b/test/code.jl index 5675b878d..32d773cdf 100644 --- a/test/code.jl +++ b/test/code.jl @@ -1,4 +1,5 @@ using Test, SymbolicUtils +using NaNMath using SymbolicUtils.Code using SymbolicUtils.Code: LazyState using StaticArrays @@ -84,10 +85,10 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen @test toexpr(SetArray(true, a, [x(t), AtIndex(9, b), c])).head == :macrocall test_repr(toexpr(LiteralExpr(:(let x=1, y=2 - $(sin(a+b)) + $(NaNMath.sin(a+b)) end))), :(let x = 1, y = 2 - $(sin)($(+)(a, b)) + $(NaNMath.sin)($(+)(a, b)) end)) test_repr(toexpr(MakeArray([a,b,a+b], :arr)), From 4f1540bd3aa20ef411795fe8c4532521dfb82a10 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 31 Jul 2023 20:02:37 -0400 Subject: [PATCH 3/3] Add nanmath as an option --- src/code.jl | 2 +- test/code.jl | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/code.jl b/src/code.jl index 6674da274..23ca867e4 100644 --- a/src/code.jl +++ b/src/code.jl @@ -112,7 +112,7 @@ const NaNMathFuns = ( sqrt, ) function function_to_expr(op, O, st) - op in NaNMathFuns || return nothing + (get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing name = nameof(op) fun = GlobalRef(NaNMath, name) args = map(Base.Fix2(toexpr, st), arguments(O)) diff --git a/test/code.jl b/test/code.jl index 32d773cdf..1cf181a78 100644 --- a/test/code.jl +++ b/test/code.jl @@ -8,6 +8,8 @@ using SparseArrays using LinearAlgebra test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b)) +nanmath_st = Code.NameState() +nanmath_st.rewrites[:nanmath] = true @testset "Code" begin @syms a b c d e p q t x(t) y(t) z(t) @@ -84,11 +86,18 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen end) @test toexpr(SetArray(true, a, [x(t), AtIndex(9, b), c])).head == :macrocall + f = GlobalRef(NaNMath, :sin) test_repr(toexpr(LiteralExpr(:(let x=1, y=2 - $(NaNMath.sin(a+b)) + $(sin(a+b)) + end)), nanmath_st), + :(let x = 1, y = 2 + $(f)($(+)(a, b)) + end)) + test_repr(toexpr(LiteralExpr(:(let x=1, y=2 + $(sin(a+b)) end))), :(let x = 1, y = 2 - $(NaNMath.sin)($(+)(a, b)) + $(sin)($(+)(a, b)) end)) test_repr(toexpr(MakeArray([a,b,a+b], :arr)), @@ -191,7 +200,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen @test f(1) == 1 @test f(2) == 2 - f = eval(toexpr(Func([a, b], [], sqrt(a - b)))) + f = eval(toexpr(Func([a, b], [], sqrt(a - b)), nanmath_st)) @test isnan(f(0, 10)) @test f(10, 2) ≈ sqrt(8) end