diff --git a/src/QSymbolicsBase/basic_ops_homogeneous.jl b/src/QSymbolicsBase/basic_ops_homogeneous.jl index ed41ff8..1263438 100644 --- a/src/QSymbolicsBase/basic_ops_homogeneous.jl +++ b/src/QSymbolicsBase/basic_ops_homogeneous.jl @@ -29,7 +29,7 @@ arguments(x::SScaled) = [x.coeff,x.obj] operation(x::SScaled) = * head(x::SScaled) = :* children(x::SScaled) = [:*,x.coeff,x.obj] -function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} +function Base.:(*)(c::U, x::Symbolic{T}) where {U<:Union{Number, Symbolic{<:Number}},T<:QObj} if (isa(c, Number) && iszero(c)) || iszero(x) SZero{T}() elseif _isone(c) @@ -40,9 +40,9 @@ function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} SScaled{T}(c, x) end end -Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c*x +Base.:(*)(x::Symbolic{T}, c::Number) where {T<:QObj} = c*x Base.:(*)(x::Symbolic{T}, y::Symbolic{S}) where {T<:QObj,S<:QObj} = throw(ArgumentError("multiplication between $(typeof(x)) and $(typeof(y)) is not defined; maybe you are looking for a tensor product `tensor`")) -Base.:(/)(x::Symbolic{T}, c) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x +Base.:(/)(x::Symbolic{T}, c::Number) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x basis(x::SScaled) = basis(x.obj) const SScaledKet = SScaled{AbstractKet} @@ -94,13 +94,13 @@ arguments(x::SAdd) = x._arguments_precomputed operation(x::SAdd) = + head(x::SAdd) = :+ children(x::SAdd) = [:+; x._arguments_precomputed] -function Base.:(+)(xs::Vararg{Symbolic{T},N}) where {T<:QObj,N} +function Base.:(+)(x::Symbolic{T}, xs::Vararg{Symbolic{T}, N}) where {T<:QObj, N} + xs = (x, xs...) xs = collect(xs) f = first(xs) nonzero_terms = filter!(x->!iszero(x),xs) isempty(nonzero_terms) ? f : SAdd{T}(countmap_flatten(nonzero_terms, SScaled{T})) end -Base.:(+)(xs::Vararg{Symbolic{<:QObj},0}) = 0 # to avoid undefined type parameters issue in the above method basis(x::SAdd) = basis(first(x.dict).first) const SAddBra = SAdd{AbstractBra} @@ -137,7 +137,8 @@ arguments(x::SMulOperator) = x.terms operation(x::SMulOperator) = * head(x::SMulOperator) = :* children(x::SMulOperator) = [:*;x.terms] -function Base.:(*)(xs::Symbolic{AbstractOperator}...) +function Base.:(*)(x::Symbolic{AbstractOperator}, xs::Vararg{Symbolic{AbstractOperator}, N}) where {N} + xs = (x, xs...) zero_ind = findfirst(x->iszero(x), xs) if isnothing(zero_ind) if any(x->!(samebases(basis(x),basis(first(xs)))),xs) diff --git a/src/QSymbolicsBase/basic_superops.jl b/src/QSymbolicsBase/basic_superops.jl index b52792c..2cea4c2 100644 --- a/src/QSymbolicsBase/basic_superops.jl +++ b/src/QSymbolicsBase/basic_superops.jl @@ -29,6 +29,8 @@ kraus(xs::Symbolic{AbstractOperator}...) = KrausRepr(collect(xs)) basis(x::KrausRepr) = basis(first(x.krausops)) Base.:(*)(sop::KrausRepr, op::Symbolic{AbstractOperator}) = (+)((i*op*dagger(i) for i in sop.krausops)...) Base.:(*)(sop::KrausRepr, k::Symbolic{AbstractKet}) = (+)((i*SProjector(k)*dagger(i) for i in sop.krausops)...) +Base.:(*)(sop::KrausRepr, k::SZeroOperator) = SZeroOperator() +Base.:(*)(sop::KrausRepr, k::SZeroKet) = SZeroOperator() Base.show(io::IO, x::KrausRepr) = print(io, "𝒦("*join([symbollabel(i) for i in x.krausops], ",")*")") ## diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 24f6440..df8fec6 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,7 +1,42 @@ @testitem "Aqua" tags=[:aqua] begin using Aqua - Aqua.test_all(QuantumSymbolics, - ambiguities=(;broken=true), - piracies=(;broken=true), - ) + + # Add any new types needed to QObj, or here if QObj if not appropriate. + # Add types from elsewhere in the ecosystem here or preferably to QObj + own_types = [Base.uniontypes(QObj)...,] + own_types_union = Union{SymQObj,} + + Aqua.test_all(QuantumSymbolics, piracies=(;treat_as_own=own_types)) + + function normalize_arguments(method) + args = Base.unwrap_unionall(method.sig).types[2:end] + normalized_args = [] + # handle few edge cases specific to our analysis + for arg in args + # mutation and order of if-conditions is intedtional here + if (arg isa UnionAll) && (arg.body <: Type) arg = arg.body.parameters[1] end + if (arg isa Core.TypeofVararg) arg = arg.T end + if (arg isa TypeVar) arg = arg.ub end + push!(normalized_args, arg) + end + return normalized_args + end + + # Custom type-piracy detection, to catch uses of QuantumInterface types without a Symbolic + filtered_piracies = filter(Aqua.Piracy.hunt(QuantumSymbolics)) do m + !any(normalize_arguments(m) .<: own_types_union) + end + + aqua_piracies = Aqua.Piracy.hunt(QuantumSymbolics, treat_as_own=own_types) + internally_detected_piracies = setdiff(filtered_piracies, aqua_piracies) + if !isempty(internally_detected_piracies) + printstyled( + stderr, + "Internally flagged possible type-piracy:\n"; + color = Base.warn_color() + ) + show(stderr, MIME"text/plain"(), internally_detected_piracies) + println(stderr, "\n") + end + @test isempty(internally_detected_piracies) end