From 9285d05d3d33a2b891da353f6da7d785a10b725a Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 26 Jul 2023 17:04:14 -0400 Subject: [PATCH 1/3] Add NaNMath lowering --- src/code.jl | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/code.jl b/src/code.jl index 1eb720516..3873ebff5 100644 --- a/src/code.jl +++ b/src/code.jl @@ -1,6 +1,6 @@ module Code -using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra +using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex, @@ -96,7 +96,29 @@ Base.convert(::Type{Assignment}, p::Pair) = Assignment(pair[1], pair[2]) toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st))) -function_to_expr(op, args, st) = nothing +const NaNMathFuns = ( + :sin, + :cos, + :tan, + :asin, + :acos, + :acosh, + :atanh, + :log, + :log2, + :log10, + :lgamma, + :log1p, + :sqrt, +) +function function_to_expr(op, args, st) + (op isa Function && (name = nameof(op)) in NaNMathFuns) && return nothing + fun = GlobalRef(NaNMath, name) + args = map(Base.Fix2(toexpr, st), arguments(O)) + expr = Expr(:call, fun) + expr.args = args + return expr +end function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st) out = get(st.rewrites, O, nothing) From 88e2e138263b63f9ddc115a7ddf63f79bab485cd Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 26 Jul 2023 17:10:19 -0400 Subject: [PATCH 2/3] Fix typo and test --- src/code.jl | 6 +++--- test/code.jl | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/code.jl b/src/code.jl index 3873ebff5..52be9c85d 100644 --- a/src/code.jl +++ b/src/code.jl @@ -111,12 +111,12 @@ const NaNMathFuns = ( :log1p, :sqrt, ) -function function_to_expr(op, args, st) - (op isa Function && (name = nameof(op)) in NaNMathFuns) && return nothing +function function_to_expr(op, O, st) + (op isa Function && (name = nameof(op)) in NaNMathFuns) || return nothing fun = GlobalRef(NaNMath, name) args = map(Base.Fix2(toexpr, st), arguments(O)) expr = Expr(:call, fun) - expr.args = args + append!(expr.args, args) return expr end diff --git a/test/code.jl b/test/code.jl index 636023a54..5675b878d 100644 --- a/test/code.jl +++ b/test/code.jl @@ -189,6 +189,10 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen f = eval(toexpr(Func([a+b], [], a+b))) @test f(1) == 1 @test f(2) == 2 + + f = eval(toexpr(Func([a, b], [], sqrt(a - b)))) + @test isnan(f(0, 10)) + @test f(10, 2) ≈ sqrt(8) end let From 382a88ffd25eb0d725cca27c1819992b7afaea9a Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 26 Jul 2023 17:17:45 -0400 Subject: [PATCH 3/3] Check function identity --- src/code.jl | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/code.jl b/src/code.jl index 52be9c85d..6674da274 100644 --- a/src/code.jl +++ b/src/code.jl @@ -1,6 +1,6 @@ module Code -using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath +using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex, @@ -97,22 +97,23 @@ Base.convert(::Type{Assignment}, p::Pair) = Assignment(pair[1], pair[2]) toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st))) const NaNMathFuns = ( - :sin, - :cos, - :tan, - :asin, - :acos, - :acosh, - :atanh, - :log, - :log2, - :log10, - :lgamma, - :log1p, - :sqrt, + sin, + cos, + tan, + asin, + acos, + acosh, + atanh, + log, + log2, + log10, + lgamma, + log1p, + sqrt, ) function function_to_expr(op, O, st) - (op isa Function && (name = nameof(op)) in NaNMathFuns) || return nothing + op in NaNMathFuns || return nothing + name = nameof(op) fun = GlobalRef(NaNMath, name) args = map(Base.Fix2(toexpr, st), arguments(O)) expr = Expr(:call, fun)