Skip to content

Commit

Permalink
Add some bounds checks for Diagonal([1]) (#40728)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored May 6, 2021
1 parent 15b5143 commit fbe28e4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
69 changes: 43 additions & 26 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,55 @@ Diagonal(v::AbstractVector{T}) where {T} = Diagonal{T,typeof(v)}(v)
Diagonal{T}(v::AbstractVector) where {T} = Diagonal(convert(AbstractVector{T}, v)::AbstractVector{T})

"""
Diagonal(A::AbstractMatrix)
Diagonal(V::AbstractVector)
Construct a matrix from the diagonal of `A`.
Construct a matrix with `V` as its diagonal.
See also [`diag`](@ref), [`diagm`](@ref).
# Examples
```jldoctest
julia> A = [1 2 3; 4 5 6; 7 8 9]
3×3 Matrix{Int64}:
1 2 3
4 5 6
7 8 9
julia> Diagonal(A)
3×3 Diagonal{Int64, Vector{Int64}}:
1 ⋅ ⋅
⋅ 5 ⋅
⋅ ⋅ 9
julia> Diagonal([1, 10, 100])
3×3 Diagonal{$Int, Vector{$Int}}:
1 ⋅ ⋅
⋅ 10 ⋅
⋅ ⋅ 100
julia> diagm([7, 13])
2×2 Matrix{$Int}:
7 0
0 13
```
"""
Diagonal(A::AbstractMatrix) = Diagonal(diag(A))
Diagonal(V::AbstractVector)

"""
Diagonal(V::AbstractVector)
Diagonal(A::AbstractMatrix)
Construct a matrix with `V` as its diagonal.
Construct a matrix from the diagonal of `A`.
# Examples
```jldoctest
julia> V = [1, 2]
2-element Vector{Int64}:
1
2
julia> Diagonal(V)
2×2 Diagonal{Int64, Vector{Int64}}:
1 ⋅
⋅ 2
julia> A = permutedims(reshape(1:15, 5, 3))
3×5 Matrix{Int64}:
1 2 3 4 5
6 7 8 9 10
11 12 13 14 15
julia> Diagonal(A)
3×3 Diagonal{$Int, Vector{$Int}}:
1 ⋅ ⋅
⋅ 7 ⋅
⋅ ⋅ 13
julia> diag(A, 2)
3-element Vector{$Int}:
3
9
15
```
"""
Diagonal(V::AbstractVector)
Diagonal(A::AbstractMatrix) = Diagonal(diag(A))

Diagonal(D::Diagonal) = D
Diagonal{T}(D::Diagonal{T}) where {T} = D
Expand Down Expand Up @@ -211,12 +220,20 @@ end

function rmul!(A::AbstractMatrix, D::Diagonal)
require_one_based_indexing(A)
nA, nD = size(A, 2), length(D.diag)
if nA != nD
throw(DimensionMismatch("second dimension of A, $nA, does not match the first of D, $nD"))
end
A .= A .* permutedims(D.diag)
return A
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
require_one_based_indexing(B)
nB, nD = size(B, 1), length(D.diag)
if nB != nD
throw(DimensionMismatch("second dimension of D, $nD, does not match the first of B, $nB"))
end
B .= D.diag .* B
return B
end
Expand Down
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -679,13 +679,15 @@ end
@test yt*D*y == (yt*D)*y == (yt*A)*y
end

@testset "Multiplication of single element Diagonal (#36746)" begin
@testset "Multiplication of single element Diagonal (#36746, #40726)" begin
@test_throws DimensionMismatch Diagonal(randn(1)) * randn(5)
@test_throws DimensionMismatch Diagonal(randn(1)) * Diagonal(randn(3, 3))
A = [1 0; 0 2]
v = [3, 4]
@test Diagonal(A) * v == A * v
@test Diagonal(A) * Diagonal(A) == A * A
@test_throws DimensionMismatch [1 0;0 1] * Diagonal([2 3]) # Issue #40726
@test_throws DimensionMismatch lmul!(Diagonal([1]), [1,2,3]) # nearby
end

@testset "Triangular division by Diagonal #27989" begin
Expand Down

2 comments on commit fbe28e4

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here. cc @maleadt

Please sign in to comment.