Skip to content

Commit

Permalink
feat: support getname for symbolic complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 19, 2024
1 parent c31c3fd commit 4ff0484
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
17 changes: 17 additions & 0 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,20 @@ Base.iszero(x::Complex{<:Num}) = iszero(real(x)) && iszero(imag(x))
Base.isone(x::Complex{<:Num}) = isone(real(x)) && iszero(imag(x))
_iszero(x::Complex{<:Num}) = _iszero(unwrap(x))
_isone(x::Complex{<:Num}) = _isone(unwrap(x))

function SymbolicIndexingInterface.hasname(x::ComplexTerm)
a = arguments(unwrap(x.im))[1]
b = arguments(unwrap(x.re))[1]
return isequal(a, b) && hasname(a)
end

function _getname(x::ComplexTerm, val)
a = arguments(unwrap(x.im))[1]
b = arguments(unwrap(x.re))[1]
if isequal(a, b)
return _getname(a, val)
end
if val == _fail
throw(ArgumentError("Variable $x doesn't have a name."))
end
end
9 changes: 7 additions & 2 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,18 @@ function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: Abstract
ArraySymbolic()
end

SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x))
SymbolicIndexingInterface.hasname(x::Union{Num,Arr,Complex{Num}}) = hasname(unwrap(x))

function SymbolicIndexingInterface.hasname(x::Symbolic)
issym(x) || !iscall(x) || iscall(x) && (issym(operation(x)) || operation(x) == getindex)
end

SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val)
# This is type piracy, but changing it breaks precompilation for MTK because it relies on this falling back to
# `_getname` which calls `nameof` which returns the name of the system, when `x::AbstractSystem`.
# FIXME: In a breaking release
function SymbolicIndexingInterface.getname(x, val = _fail)
_getname(unwrap(x), val)
end

function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic, Equation, Inequality}, d::Dict; kwargs...)
val = fixpoint_sub(ex, d; kwargs...)
Expand Down
12 changes: 12 additions & 0 deletions test/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,15 @@ end
z2 = 1.0 + z*im
@test isnothing(metadata(unwrap(z1.re)))
end

@testset "getname" begin
@variables t a b x::Complex y(t)::Complex z(a, b)::Complex
@test hasname(x)
@test getname(x) == :x
@test hasname(y)
@test getname(y) == :y
@test hasname(z)
@test getname(z) == :z
@test !hasname(2x)
@test !hasname(x + y)
end

0 comments on commit 4ff0484

Please sign in to comment.