Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added naming based on input types #660

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end
"""
$(SIGNATURES)

Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types)
Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types)
of `x`. By default this is just `typeof(x)`.
Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules
specific to numbers (such as commutativity of multiplication). Or such
Expand Down Expand Up @@ -561,9 +561,9 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata)
st = symtype(T)
pst = _promote_symtype(head, args)
# Use promoted symtype only if not a subtype of the existing symtype of T.
# This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])`
# Where the result would have a symtype of Bool.
# Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609
# This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])`
# Where the result would have a symtype of Bool.
# Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609
# TODO this should be optimized.
new_st = if st <: AbstractArray
st
Expand Down Expand Up @@ -816,27 +816,37 @@ function show_ref(io, f, args)
print(io, "]")
end

import Base.nameof
# To fall through the `nameof` in the `show_call` below
Base.nameof(f, arg, args...) = nameof(f)

"""
show_call(io, f, args)
Displays the function call with given args. There are different outputs if `f`
is unary, binary or otherwise. `f`'s output can also be decorated using
`Base.nameof` provided with the function as well as with the `symtype`
of `f`'s arguments.
"""
function show_call(io, f, args)
fname = iscall(f) ? Symbol(repr(f)) : nameof(f)
fname = nameof(f, symtype.(args)...)
frep = Symbol(repr(f))

len_args = length(args)
if Base.isunaryoperator(fname) && len_args == 1

if Base.isunaryoperator(frep) && len_args == 1
print(io, "$fname")
print_arg(io, first(args), paren=true)
elseif Base.isbinaryoperator(fname) && len_args > 1
elseif Base.isbinaryoperator(frep) && len_args > 1
for (i, t) in enumerate(args)
i != 1 && print(io, " $fname ")
print_arg(io, t, paren=true)
end
else
if issym(f)
Base.show_unquoted(io, nameof(f))
else
Base.show(io, f)
end
print(io, "$fname")
print(io, "(")
for i=1:length(args)
for i=1:len_args
print(io, args[i])
i != length(args) && print(io, ", ")
i != len_args && print(io, ", ")
end
print(io, ")")
end
Expand Down
43 changes: 38 additions & 5 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ using SymbolicUtils
using IfElse: ifelse
using Setfield
using Test, ReferenceTests
import Base.nameof

include("utils.jl")

@testset "@syms" begin
let
@syms a b::Float64 f(::Real) g(p, h(q::Real))::Int
@syms a b::Float64 f(::Real) g(p, h(q::Real))::Int

@test issym(a) && symtype(a) == Number
@test a.name === :a
Expand Down Expand Up @@ -235,6 +236,38 @@ end
@test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14)))
end

let

sq(x) = return SymbolicUtils.Term{Number}(sq, [x])

function Base.nameof(::typeof(sq), arg)
if arg <: Real
return :sqrt_R
elseif arg <: Complex
return :sqrt_C
else
return :sqrt
end
end

@testset "call printing" begin
get_print(sym) = begin b = IOBuffer(); print(b, sym); String(take!(b)); end

x,y,z = @syms x::Real y::Complex z
@syms e() f(x) g(x,y) h(x,y,z)

@test get_print(e()) == "e()"
@test get_print(f(x)) == "f(x)"
@test get_print(g(x,y)) == "g(x, y)"
@test get_print(h(x,y,z)) == "h(x, y, z)"

@test get_print(sq(x)) == "sqrt_R(x)"
@test get_print(sq(y)) == "sqrt_C(y)"
@test get_print(sq(z)) == "sqrt(z)"
end

end

@testset "maketerm" begin
@syms a b c
@test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1))
Expand All @@ -249,7 +282,7 @@ end
# test that maketerm sets metadata correctly
metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1")
metadata2 = Base.ImmutableDict{DataType, Any}(Ctx2, "meta_2")

d = b * c
@set! d.metadata = metadata2

Expand Down Expand Up @@ -277,12 +310,12 @@ end
@test symtype(new_expr) == Bool

# Doesn't know return type, promoted symtype is Any
foo(x,y) = x^2 + x
foo(x,y) = x^2 + x
new_expr = SymbolicUtils.maketerm(typeof(ref_expr), foo, [a, b], nothing)
@test symtype(new_expr) == Number

# Promoted symtype is a subtype of referred
@syms x::Int y::Int
@syms x::Int y::Int
new_expr = SymbolicUtils.maketerm(typeof(ref_expr), (+), [x, y], nothing)
@test symtype(new_expr) == Int64

Expand Down Expand Up @@ -384,5 +417,5 @@ end
ax = adjoint(x)
@test isequal(ax, x)
@test ax === x
@test isequal(adjoint(y), conj(y))
@test isequal(adjoint(y), conj(y))
end
Loading