Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix (l/r)mul! with Diagonal/Bidiagonal #55052

Merged
merged 8 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))

lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
# B .= A * B
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in axes(B,2)
for i in axes(ev,1)
B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k]
end
B[end,k] = dv[end] * B[end,k]
end
else
for k in axes(B,2)
for i in reverse(axes(dv,1)[2:end])
B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k]
end
B[1,k] = dv[1] * B[1,k]
end
end
return B
end
# B .= D * B
function lmul!(D::Diagonal, B::Bidiagonal)
_muldiag_size_check(D, B)
(; dv, ev) = B
isL = B.uplo == 'L'
dv[1] = D.diag[1] * dv[1]
for i in axes(ev,1)
ev[i] = D.diag[i + isL] * ev[i]
dv[i+1] = D.diag[i+1] * dv[i+1]
end
return B
end
# B .= B * A
function rmul!(B::AbstractMatrix, A::Bidiagonal)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in reverse(axes(dv,1)[2:end])
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1]
end
end
for i in axes(B,1)
B[i,1] *= dv[1]
end
else
for k in axes(ev,1)
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k]
end
end
for i in axes(B,1)
B[i,end] *= dv[end]
end
end
return B
end
# B .= B * D
function rmul!(B::Bidiagonal, D::Diagonal)
_muldiag_size_check(B, D)
(; dv, ev) = B
isU = B.uplo == 'U'
dv[1] *= D.diag[1]
for i in axes(ev,1)
ev[i] *= D.diag[i + isU]
dv[i+1] *= D.diag[i+1]
end
return B
end

function check_A_mul_B!_sizes(C, A, B)
mA, nA = size(A)
Expand Down
45 changes: 43 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,49 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
function rmul!(A::AbstractMatrix, D::Diagonal)
_muldiag_size_check(A, D)
for I in CartesianIndices(A)
row, col = Tuple(I)
@inbounds A[row, col] *= D.diag[col]
end
return A
end
# T .= T * D
function rmul!(T::Tridiagonal, D::Diagonal)
_muldiag_size_check(T, D)
(; dl, d, du) = T
d[1] *= D.diag[1]
for i in axes(dl,1)
dl[i] *= D.diag[i]
du[i] *= D.diag[i+1]
d[i+1] *= D.diag[i+1]
end
return T
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
_muldiag_size_check(D, B)
for I in CartesianIndices(B)
row = I[1]
@inbounds B[I] = D.diag[row] * B[I]
end
return B
end

# in-place multiplication with a diagonal
# T .= D * T
function lmul!(D::Diagonal, T::Tridiagonal)
_muldiag_size_check(D, T)
(; dl, d, du) = T
d[1] = D.diag[1] * d[1]
for i in axes(dl,1)
dl[i] = D.diag[i+1] * dl[i]
du[i] = D.diag[i] * du[i]
d[i+1] = D.diag[i+1] * d[i+1]
end
return T
end

function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out, B)
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,41 @@ end
@test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)]
end

@testset "rmul!/lmul! with banded matrices" begin
dv, ev = rand(4), rand(3)
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(dv, ev, :U),
Bidiagonal(dv, ev, :L),
Diagonal(dv)
)
@test_throws ArgumentError rmul!(B, A)
@test_throws ArgumentError lmul!(A, B)
end
D = Diagonal(dv)
@test rmul!(copy(A), D) A * D
@test lmul!(D, copy(A)) D * A
end
@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
for uplo in (:L, :U)
B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo)
D = Diagonal(fill(S22, size(B,2)))
@test rmul!(copy(B), D) B * D
D = Diagonal(fill(S33, size(B,1)))
@test lmul!(D, copy(B)) D * B
end

B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U)
D = Diagonal(fill(S32, 4))
@test lmul!(B, Array(D)) B * D
B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U)
@test rmul!(Array(D), B) D * B
end
end

@testset "conversion to Tridiagonal for immutable bands" begin
n = 4
dv = FillArrays.Fill(3, n)
Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1322,4 +1322,17 @@ end
@test M == D
end

@testset "rmul!/lmul! with banded matrices" begin
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(rand(4), rand(3), :L),
Tridiagonal(rand(3), rand(4), rand(3))
)
BA = Array(B)
D = Diagonal(rand(size(B,1)))
DA = Array(D)
@test rmul!(copy(B), D) B * D BA * DA
@test lmul!(D, copy(B)) D * B DA * BA
end
end

end # module TestDiagonal
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -892,4 +892,23 @@ end
end
end

@testset "rmul!/lmul! with banded matrices" begin
dl, d, du = rand(3), rand(4), rand(3)
A = Tridiagonal(dl, d, du)
D = Diagonal(d)
@test rmul!(copy(A), D) A * D
@test lmul!(D, copy(A)) D * A

@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3))
D = Diagonal(fill(S22, size(T,2)))
@test rmul!(copy(T), D) T * D
D = Diagonal(fill(S33, size(T,1)))
@test lmul!(D, copy(T)) D * T
end
end

end # module TestTridiagonal
3 changes: 3 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1
Base.last(r::SOneTo) = length(r)
Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")")

Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s
Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s

struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
data::A
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}
Expand Down