From 262b40aa63bfc0482909dd48267b074fe3ac6d4b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 11 Jul 2024 11:33:11 +0530 Subject: [PATCH] Fix `(l/r)mul!` with `Diagonal`/`Bidiagonal` (#55052) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, `rmul!(A::AbstractMatrix, D::Diagonal)` calls `mul!(A, A, D)`, but this isn't a valid call, as `mul!` assumes no aliasing between the destination and the matrices to be multiplied. As a consequence, ```julia julia> B = Bidiagonal(rand(4), rand(3), :L) 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.476892 ⋅ ⋅ ⋅ 0.353756 0.139188 ⋅ ⋅ ⋅ 0.685839 0.309336 ⋅ ⋅ ⋅ 0.369038 0.304273 julia> D = Diagonal(rand(size(B,2))); julia> rmul!(B, D) 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 julia> B 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ``` This is clearly nonsense, and happens because the internal `_mul!` function assumes that it can safely overwrite the destination with zeros before carrying out the multiplication. This is fixed in this PR by using broadcasting instead. The current implementation is generally equally performant, albeit occasionally with a minor allocation arising from `reshape`ing an `Array`. A similar problem also exists in `l/rmul!` with `Bidiagonal`, but that's a little harder to fix while remaining equally performant. 
--- stdlib/LinearAlgebra/src/bidiag.jl | 72 ++++++++++++++++++++++++++- stdlib/LinearAlgebra/src/diagonal.jl | 45 ++++++++++++++++- stdlib/LinearAlgebra/test/bidiag.jl | 35 +++++++++++++ stdlib/LinearAlgebra/test/diagonal.jl | 13 +++++ stdlib/LinearAlgebra/test/tridiag.jl | 19 +++++++ test/testhelpers/SizedArrays.jl | 3 ++ 6 files changed, 183 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 7a51b0dbfb7c4..04d54911d88aa 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -470,8 +470,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal} @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) -lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul()) -rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul()) +# B .= A * B +function lmul!(A::Bidiagonal, B::AbstractVecOrMat) + _muldiag_size_check(A, B) + (; dv, ev) = A + if A.uplo == 'U' + for k in axes(B,2) + for i in axes(ev,1) + B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k] + end + B[end,k] = dv[end] * B[end,k] + end + else + for k in axes(B,2) + for i in reverse(axes(dv,1)[2:end]) + B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k] + end + B[1,k] = dv[1] * B[1,k] + end + end + return B +end +# B .= D * B +function lmul!(D::Diagonal, B::Bidiagonal) + _muldiag_size_check(D, B) + (; dv, ev) = B + isL = B.uplo == 'L' + dv[1] = D.diag[1] * dv[1] + for i in axes(ev,1) + ev[i] = D.diag[i + isL] * ev[i] + dv[i+1] = D.diag[i+1] * dv[i+1] + end + return B +end +# B .= B * A +function rmul!(B::AbstractMatrix, A::Bidiagonal) + _muldiag_size_check(A, B) + (; dv, ev) = A + if A.uplo == 'U' + for k in reverse(axes(dv,1)[2:end]) + for i in axes(B,1) + B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1] + end + end + for i in axes(B,1) + B[i,1] *= dv[1] + end + else + for k in axes(ev,1) + for i in 
axes(B,1) + B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k] + end + end + for i in axes(B,1) + B[i,end] *= dv[end] + end + end + return B +end +# B .= B * D +function rmul!(B::Bidiagonal, D::Diagonal) + _muldiag_size_check(B, D) + (; dv, ev) = B + isU = B.uplo == 'U' + dv[1] *= D.diag[1] + for i in axes(ev,1) + ev[i] *= D.diag[i + isU] + dv[i+1] *= D.diag[i+1] + end + return B +end function check_A_mul_B!_sizes(C, A, B) mA, nA = size(A) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 1ff7d371043ee..b3826a2aa7f82 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -327,8 +327,49 @@ function (*)(D::Diagonal, V::AbstractVector) return D.diag .* V end -rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) -lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B) +function rmul!(A::AbstractMatrix, D::Diagonal) + _muldiag_size_check(A, D) + for I in CartesianIndices(A) + row, col = Tuple(I) + @inbounds A[row, col] *= D.diag[col] + end + return A +end +# T .= T * D +function rmul!(T::Tridiagonal, D::Diagonal) + _muldiag_size_check(T, D) + (; dl, d, du) = T + d[1] *= D.diag[1] + for i in axes(dl,1) + dl[i] *= D.diag[i] + du[i] *= D.diag[i+1] + d[i+1] *= D.diag[i+1] + end + return T +end + +function lmul!(D::Diagonal, B::AbstractVecOrMat) + _muldiag_size_check(D, B) + for I in CartesianIndices(B) + row = I[1] + @inbounds B[I] = D.diag[row] * B[I] + end + return B +end + +# in-place multiplication with a diagonal +# T .= D * T +function lmul!(D::Diagonal, T::Tridiagonal) + _muldiag_size_check(D, T) + (; dl, d, du) = T + d[1] = D.diag[1] * d[1] + for i in axes(dl,1) + dl[i] = D.diag[i+1] * dl[i] + du[i] = D.diag[i] * du[i] + d[i+1] = D.diag[i+1] * d[i+1] + end + return T +end function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} require_one_based_indexing(out, B) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl 
b/stdlib/LinearAlgebra/test/bidiag.jl index 2380a93d90a74..2ff3e9b423702 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -933,6 +933,41 @@ end @test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)] end +@testset "rmul!/lmul! with banded matrices" begin + dv, ev = rand(4), rand(3) + for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L)) + @testset "$(nameof(typeof(B)))" for B in ( + Bidiagonal(dv, ev, :U), + Bidiagonal(dv, ev, :L), + Diagonal(dv) + ) + @test_throws ArgumentError rmul!(B, A) + @test_throws ArgumentError lmul!(A, B) + end + D = Diagonal(dv) + @test rmul!(copy(A), D) ≈ A * D + @test lmul!(D, copy(A)) ≈ D * A + end + @testset "non-commutative" begin + S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2)) + S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3)) + S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2)) + for uplo in (:L, :U) + B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo) + D = Diagonal(fill(S22, size(B,2))) + @test rmul!(copy(B), D) ≈ B * D + D = Diagonal(fill(S33, size(B,1))) + @test lmul!(D, copy(B)) ≈ D * B + end + + B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U) + D = Diagonal(fill(S32, 4)) + @test lmul!(B, Array(D)) ≈ B * D + B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U) + @test rmul!(Array(D), B) ≈ D * B + end +end + @testset "conversion to Tridiagonal for immutable bands" begin n = 4 dv = FillArrays.Fill(3, n) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 4009f841a355c..1a3b8d4fd0ea7 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1322,4 +1322,17 @@ end @test M == D end +@testset "rmul!/lmul! 
with banded matrices" begin + @testset "$(nameof(typeof(B)))" for B in ( + Bidiagonal(rand(4), rand(3), :L), + Tridiagonal(rand(3), rand(4), rand(3)) + ) + BA = Array(B) + D = Diagonal(rand(size(B,1))) + DA = Array(D) + @test rmul!(copy(B), D) ≈ B * D ≈ BA * DA + @test lmul!(D, copy(B)) ≈ D * B ≈ DA * BA + end +end + end # module TestDiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index a067c18d7665d..fae708c4c8db4 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -892,4 +892,23 @@ end end end +@testset "rmul!/lmul! with banded matrices" begin + dl, d, du = rand(3), rand(4), rand(3) + A = Tridiagonal(dl, d, du) + D = Diagonal(d) + @test rmul!(copy(A), D) ≈ A * D + @test lmul!(D, copy(A)) ≈ D * A + + @testset "non-commutative" begin + S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2)) + S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3)) + S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2)) + T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3)) + D = Diagonal(fill(S22, size(T,2))) + @test rmul!(copy(T), D) ≈ T * D + D = Diagonal(fill(S33, size(T,1))) + @test lmul!(D, copy(T)) ≈ D * T + end +end + end # module TestTridiagonal diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index 2d37cead61a08..bc02fb5cbbd20 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1 Base.last(r::SOneTo) = length(r) Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")") +Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s +Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s + struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N} data::A function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}