diff --git a/src/code.jl b/src/code.jl index 1eb720516..6674da274 100644 --- a/src/code.jl +++ b/src/code.jl @@ -1,6 +1,6 @@ module Code -using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra +using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex, @@ -96,7 +96,30 @@ Base.convert(::Type{Assignment}, p::Pair) = Assignment(pair[1], pair[2]) toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st))) -function_to_expr(op, args, st) = nothing +const NaNMathFuns = ( + sin, + cos, + tan, + asin, + acos, + acosh, + atanh, + log, + log2, + log10, + lgamma, + log1p, + sqrt, +) +function function_to_expr(op, O, st) + op in NaNMathFuns || return nothing + name = nameof(op) + fun = GlobalRef(NaNMath, name) + args = map(Base.Fix2(toexpr, st), arguments(O)) + expr = Expr(:call, fun) + append!(expr.args, args) + return expr +end function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st) out = get(st.rewrites, O, nothing) diff --git a/test/code.jl b/test/code.jl index 636023a54..5675b878d 100644 --- a/test/code.jl +++ b/test/code.jl @@ -189,6 +189,10 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen f = eval(toexpr(Func([a+b], [], a+b))) @test f(1) == 1 @test f(2) == 2 + + f = eval(toexpr(Func([a, b], [], sqrt(a - b)))) + @test isnan(f(0, 10)) + @test f(10, 2) ≈ sqrt(8) end let