Skip to content

Commit

Permalink
Merge pull request #535 from JuliaSymbolics/myb/nanmath
Browse files Browse the repository at this point in the history
Use NaNMath lowering by default
  • Loading branch information
YingboMa authored Jul 26, 2023
2 parents 3dc99d4 + 382a88f commit 7ef8014
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
27 changes: 25 additions & 2 deletions src/code.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Code

using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra
using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions

export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
Expand Down Expand Up @@ -96,7 +96,30 @@ Base.convert(::Type{Assignment}, p::Pair) = Assignment(pair[1], pair[2])

toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st)))

function_to_expr(op, args, st) = nothing
const NaNMathFuns = (
sin,
cos,
tan,
asin,
acos,
acosh,
atanh,
log,
log2,
log10,
lgamma,
log1p,
sqrt,
)
function function_to_expr(op, O, st)
op in NaNMathFuns || return nothing
name = nameof(op)
fun = GlobalRef(NaNMath, name)
args = map(Base.Fix2(toexpr, st), arguments(O))
expr = Expr(:call, fun)
append!(expr.args, args)
return expr
end

function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
out = get(st.rewrites, O, nothing)
Expand Down
4 changes: 4 additions & 0 deletions test/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
f = eval(toexpr(Func([a+b], [], a+b)))
@test f(1) == 1
@test f(2) == 2

f = eval(toexpr(Func([a, b], [], sqrt(a - b))))
@test isnan(f(0, 10))
@test f(10, 2) sqrt(8)
end

let
Expand Down

0 comments on commit 7ef8014

Please sign in to comment.