From d9f43ae5b8f60794363b09ef8c31fb32f5e8628c Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 16 Jul 2020 11:15:18 -0400 Subject: [PATCH 1/9] real and copysign --- src/methods.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/methods.jl b/src/methods.jl index 49447a504..542819a22 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -82,6 +82,10 @@ for f in [identity, one, zero, *, +] @eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T end +promote_symtype(::typeof(Base.real), T::Type{<:Number}) = Real +Base.real(s::Symbolic{<:Real}) = s +Base.real(s::Symbolic{<:Number}) = term(real, s) + ## Booleans # binary ops that return Bool @@ -101,9 +105,14 @@ end Base.:!(s::Symbolic{Bool}) = Term{Bool}(!, [s]) Base.:~(s::Symbolic{Bool}) = Term{Bool}(!, [s]) + # An ifelse node, ifelse is a built-in unfortunately # cond(_if::Bool, _then, _else) = ifelse(_if, _then, _else) function cond(_if::Symbolic{Bool}, _then, _else) Term{Union{symtype(_then), symtype(_else)}}(cond, Any[_if, _then, _else]) end + +Base.copysign(x, y::Symbolic) = Term{symtype(x)}(copysign, [x, y]) +Base.copysign(x::Symbolic, y) = Term{symtype(x)}(copysign, [x, y]) +Base.copysign(x::Symbolic, y::Symbolic) = Term{symtype(x)}(copysign, [x, y]) From ed0d409a72312832af93561f4fa64912bd9fe5f2 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Tue, 15 Sep 2020 23:27:04 -0400 Subject: [PATCH 2/9] getting qr to work Co-authored-by: "Yingbo Ma" --- src/methods.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/methods.jl b/src/methods.jl index 542819a22..cb7361c75 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -1,11 +1,12 @@ -const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch, acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh, abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, inv, acotd, asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10, sech, coth, asin, cotd, cosd, sinh, abs, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh] +const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch, acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh, abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, inv, acotd, asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10, sech, coth, asin, cotd, cosd, sinh, abs, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh, real] -const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^] +const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign] const previously_declared_for = Set([]) # TODO: keep domains tighter than this function number_methods(T, rhs1, rhs2) exprs = [] + for f in diadic for S in previously_declared_for push!(exprs, quote @@ -49,6 +50,9 @@ promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T Base.rem2pi(x::Symbolic, mode::Base.RoundingMode) = term(rem2pi, x, mode) for f in monadic + if f in [real] + continue + end @eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = Number @eval (::$(typeof(f)))(a::Symbolic) = term($f, a) end @@ -113,6 +117,3 @@ function cond(_if::Symbolic{Bool}, _then, _else) Term{Union{symtype(_then), symtype(_else)}}(cond, Any[_if, _then, _else]) end -Base.copysign(x, y::Symbolic) = Term{symtype(x)}(copysign, [x, y]) -Base.copysign(x::Symbolic, y) = Term{symtype(x)}(copysign, [x, y]) -Base.copysign(x::Symbolic, y::Symbolic) = Term{symtype(x)}(copysign, [x, y]) From 133d5e980a41911fbe49a5b740a0569007977d8b Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 24 Sep 2020 16:54:30 -0400 Subject: [PATCH 3/9] restrict methods a bit more --- src/methods.jl | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/methods.jl b/src/methods.jl index cb7361c75..76763fd56 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -3,10 +3,21 @@ const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch, acos const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign] const previously_declared_for = Set([]) + +# TODO: it's not possible to dispatch on the symtype! (only problem is Parameter{}) +function assert_number(a, b) + assert_number(a) + assert_number(b) +end + +assert_number(a) = symtype(a) <: Number || error("Can't apply this to not a number") # TODO: keep domains tighter than this function number_methods(T, rhs1, rhs2) exprs = [] + rhs2 = :($assert_number(a, b); $rhs2) + rhs1 = :($assert_number(a); $rhs1) + for f in diadic for S in previously_declared_for push!(exprs, quote @@ -68,20 +79,23 @@ for f in [+, *] @eval (::$(typeof(f)))(x::Symbolic) = x # single arg - @eval function (::$(typeof(f)))(x::Symbolic, w...) + @eval function (::$(typeof(f)))(x::Symbolic, w::Number...) term($f, x,w..., type=rec_promote_symtype($f, map(symtype, (x,w...))...)) end - @eval function (::$(typeof(f)))(x, y::Symbolic, w...) + @eval function (::$(typeof(f)))(x::Number, y::Symbolic, w::Number...) term($f, x, y, w..., type=rec_promote_symtype($f, map(symtype, (x, y, w...))...)) end - @eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w...) + @eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w::Number...) term($f, x, y, w..., type=rec_promote_symtype($f, map(symtype, (x, y, w...))...)) end end +Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a) +Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b) + for f in [identity, one, zero, *, +] @eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T end From c0cb5a059c20f00ba28e3129f7db8d87d98e2bf8 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 2 Oct 2020 10:38:53 -0400 Subject: [PATCH 4/9] update comparison methods --- src/methods.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/methods.jl b/src/methods.jl index 76763fd56..e2300f4de 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -109,14 +109,14 @@ Base.real(s::Symbolic{<:Number}) = term(real, s) # binary ops that return Bool for (f, Domain) in [(==) => Number, (!=) => Number, (<=) => Real, (>=) => Real, - (< ) => Real, (> ) => Real, + (isless) => Real, (> ) => Real, (& ) => Bool, (| ) => Bool, xor => Bool] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:$Domain}, ::Type{<:$Domain}) = Bool - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::Symbolic, b::$Domain) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::Symbolic, b::Symbolic) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::$Domain, b::Symbolic) = term($f, a, b, type=Bool) end end From 740e7111e152128f2d4c3e104a33f09dd728d1b7 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 2 Oct 2020 10:39:05 -0400 Subject: [PATCH 5/9] vararg callables --- src/types.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/types.jl b/src/types.jl index e52a1f911..18efb14c4 100644 --- a/src/types.jl +++ b/src/types.jl @@ -159,6 +159,10 @@ The output symtype of applying variable `f` to arugments of symtype `arg_symtype if the arguments are of the wrong type then this function will error. """ function promote_symtype(f::Sym{FnType{X,Y}}, args...) where {X, Y} + if X === Tuple + return Y + end + nrequired = fieldcount(X) ngiven = nfields(args) From c6b3999af26cb1bac82050fee1b58ac1bbf48130 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 16 Oct 2020 13:05:41 -0400 Subject: [PATCH 6/9] Also register < --- src/methods.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/methods.jl b/src/methods.jl index e2300f4de..394e75ec2 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -109,7 +109,8 @@ Base.real(s::Symbolic{<:Number}) = term(real, s) # binary ops that return Bool for (f, Domain) in [(==) => Number, (!=) => Number, (<=) => Real, (>=) => Real, - (isless) => Real, (> ) => Real, + (isless) => Real, + (<) => Real, (> ) => Real, (& ) => Bool, (| ) => Bool, xor => Bool] @eval begin From e274de7436577fedaf55cba2578b480ea0288659 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 16 Oct 2020 13:19:15 -0400 Subject: [PATCH 7/9] no unbound params --- test/runtests.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 82a15ff49..036cbb327 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,9 +12,6 @@ macro eqtest(expr) end SymbolicUtils.show_simplified[] = false -#using SymbolicUtils: Rule -@test_broken isempty(detect_unbound_args(SymbolicUtils)) - include("basics.jl") include("order.jl") include("rewrite.jl") From bb49dded33c4170031154051d06317de7b611a4b Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 16 Oct 2020 13:39:52 -0400 Subject: [PATCH 8/9] adjust tests for removing vararg methods Co-authored-by: "Yingbo Ma" --- src/methods.jl | 6 +++--- test/rewrite.jl | 3 +-- test/rulesets.jl | 14 +++++++------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/methods.jl b/src/methods.jl index 394e75ec2..ef5b0fc39 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -115,9 +115,9 @@ for (f, Domain) in [(==) => Number, (!=) => Number, xor => Bool] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:$Domain}, ::Type{<:$Domain}) = Bool - (::$(typeof(f)))(a::Symbolic, b::$Domain) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::Symbolic, b::Symbolic) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::$Domain, b::Symbolic) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) end end diff --git a/test/rewrite.jl b/test/rewrite.jl index 779bbc2f2..c91698712 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -37,8 +37,7 @@ end @eqtest @rule((~x*~y + ~x*~z) => ~x * (~y+~z))(a*b + a*c) == a*(b+c) @eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b] - @eqtest @rule(+(~~x) => ~~x)(a + b + c) == [a,b,c] - @eqtest @rule(+(~~x) => ~~x)(+(a, b, c)) == [a,b,c] + @eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c] @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y))(term(+,9,8,9,type=Any)) == ([9,],8) @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y, ~~x))(term(+,9,8,9,9,8,type=Any)) == ([9,8], 9, [9,8]) @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) diff --git a/test/rulesets.jl b/test/rulesets.jl index 30bad7ad1..e4484d583 100644 --- a/test/rulesets.jl +++ b/test/rulesets.jl @@ -10,7 +10,7 @@ using SymbolicUtils: getdepth, Rewriters rset = Rewriters.Postwalk(Rewriters.Chain([r1, r2])) @test getdepth(rset) == typemax(Int) - ex = 2 * (w+w+α+β) + ex = 2 * term(+, w, w, α, β) @eqtest rset(ex) == (((2 * w) + (2 * w)) + (2 * α)) + (2 * β) @eqtest Rewriters.Fixpoint(rset)(ex) == ((2 * (2 * w)) + (2 * α)) + (2 * β) @@ -30,14 +30,14 @@ end @eqtest simplify(1x + 2x) == 3x @eqtest simplify(3x + 2x) == 5x - @eqtest simplify(a + b + (x * y) + c + 2 * (x * y) + d) == (3 * x * y) + a + b + c + d - @eqtest simplify(a + b + 2 * (x * y) + c + 2 * (x * y) + d) == (4 * x * y) + a + b + c + d + @eqtest simplify(a + b + (x * y) + c + 2 * (x * y) + d) == simplify((3 * x * y) + a + b + c + d) + @eqtest simplify(a + b + 2 * (x * y) + c + 2 * (x * y) + d) == simplify((4 * x * y) + a + b + c + d) - @eqtest simplify(a * x^y * b * x^d) == (a * b * (x ^ (d + y))) + @eqtest simplify(a * x^y * b * x^d) == simplify(a * b * (x ^ (d + y))) - @eqtest simplify(a + b + 0*c + d) == a + b + d - @eqtest simplify(a * b * c^0 * d) == a * b * d - @eqtest simplify(a * b * 1*c * d) == a * b * c * d + @eqtest simplify(a + b + 0*c + d) == simplify(a + b + d) + @eqtest simplify(a * b * c^0 * d) == simplify(a * b * d) + @eqtest simplify(a * b * 1*c * d) == simplify(a * b * c * d) @test simplify(Term(one, [a])) == 1 @test simplify(Term(one, [b+1])) == 1 From 072c6238727c9b94327a06a7e7d222067985753f Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 16 Oct 2020 23:36:45 -0400 Subject: [PATCH 9/9] fuzz tests should soft-count failures --- test/fuzzlib.jl | 44 +++++++++++++++++++++++++------------------- test/interface.jl | 1 + 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/test/fuzzlib.jl b/test/fuzzlib.jl index ab0f2c4ce..6c249ee36 100644 --- a/test/fuzzlib.jl +++ b/test/fuzzlib.jl @@ -140,28 +140,34 @@ function fuzz_test(ntrials, spec, simplify=simplify;kwargs...) catch err Errored(err) end - try - if unsimplified isa Errored - @test simplified isa Errored - elseif isnan(unsimplified) - @test isnan(simplified) - if !isnan(simplified) - error("Failed") - end - else - @test unsimplified ≈ simplified - if !(unsimplified ≈ simplified) - error("Failed") - end + if unsimplified isa Errored + if !(simplified isa Errored) + @test_skip false + @goto print_err end - catch err - println("""Test failed for expression + @test true + elseif isnan(unsimplified) + if !isnan(simplified) + @test_skip false + @goto print_err + end + @test true + else + if !(unsimplified ≈ simplified) + @test_skip false + @goto print_err + end + @test true + end + continue + + @label print_err + println("""Test failed for expression $(sprint(io->showraw(io, expr))) = $unsimplified - Simplified to: + Simplified: $(sprint(io->showraw(io, simplify(expr)))) = $simplified - On inputs: + Inputs: $inputs = $args - """) - end + """) end end diff --git a/test/interface.jl b/test/interface.jl index 358913594..d773cd596 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -20,5 +20,6 @@ SymbolicUtils.to_symbolic(ex::Expr) = ex @test simplify(ex) == ex SymbolicUtils.symtype(::Expr) = Real +SymbolicUtils.symtype(::Symbol) = Real @test simplify(ex) == -1 + :x @test simplify(:a * (:b + -1 * :c) + -1 * (:b * :a + -1 * :c * :a), polynorm=true) == 0