Skip to content

Commit

Permalink
Merge pull request #536 from JuliaSymbolics/myb/reg
Browse files Browse the repository at this point in the history
Register NaNMath functions
  • Loading branch information
YingboMa authored Aug 1, 2023
2 parents 325b673 + 4f1540b commit 5a923ea
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ const NaNMathFuns = (
sqrt,
)
function function_to_expr(op, O, st)
op in NaNMathFuns || return nothing
(get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing
name = nameof(op)
fun = GlobalRef(NaNMath, name)
args = map(Base.Fix2(toexpr, st), arguments(O))
Expand Down
8 changes: 6 additions & 2 deletions src/methods.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import NaNMath
import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx,
dawson, digamma, trigamma, invdigamma, polygamma,
airyai, airyaiprime, airybi, airybiprime, besselj0,
Expand All @@ -12,9 +13,12 @@ const monadic = [deg2rad, rad2deg, transpose, asind, log1p, acsch,
atand, sec, acscd, cot, exp2, expm1, atanh, gamma,
loggamma, erf, erfc, erfcinv, erfi, erfcx, dawson, digamma,
trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi,
airybiprime, besselj0, besselj1, bessely0, bessely1, isfinite]
airybiprime, besselj0, besselj1, bessely0, bessely1, isfinite,
NaNMath.sin, NaNMath.cos, NaNMath.tan, NaNMath.asin, NaNMath.acos,
NaNMath.acosh, NaNMath.atanh, NaNMath.log, NaNMath.log2,
NaNMath.log10, NaNMath.lgamma, NaNMath.log1p, NaNMath.sqrt]

const diadic = [max, min, hypot, atan, mod, rem, copysign,
const diadic = [max, min, hypot, atan, NaNMath.atanh, mod, rem, copysign,
besselj, bessely, besseli, besselk, hankelh1, hankelh2,
polygamma, beta, logbeta]
const previously_declared_for = Set([])
Expand Down
12 changes: 11 additions & 1 deletion test/code.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test, SymbolicUtils
using NaNMath
using SymbolicUtils.Code
using SymbolicUtils.Code: LazyState
using StaticArrays
Expand All @@ -7,6 +8,8 @@ using SparseArrays
using LinearAlgebra

test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b))
nanmath_st = Code.NameState()
nanmath_st.rewrites[:nanmath] = true

@testset "Code" begin
@syms a b c d e p q t x(t) y(t) z(t)
Expand Down Expand Up @@ -83,6 +86,13 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
end)
@test toexpr(SetArray(true, a, [x(t), AtIndex(9, b), c])).head == :macrocall

f = GlobalRef(NaNMath, :sin)
test_repr(toexpr(LiteralExpr(:(let x=1, y=2
$(sin(a+b))
end)), nanmath_st),
:(let x = 1, y = 2
$(f)($(+)(a, b))
end))
test_repr(toexpr(LiteralExpr(:(let x=1, y=2
$(sin(a+b))
end))),
Expand Down Expand Up @@ -190,7 +200,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
@test f(1) == 1
@test f(2) == 2

f = eval(toexpr(Func([a, b], [], sqrt(a - b))))
f = eval(toexpr(Func([a, b], [], sqrt(a - b)), nanmath_st))
@test isnan(f(0, 10))
@test f(10, 2) sqrt(8)
end
Expand Down

0 comments on commit 5a923ea

Please sign in to comment.