diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 7a51b0dbfb7c4..04d54911d88aa 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -470,8 +470,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal} @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) -lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul()) -rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul()) +# B .= A * B +function lmul!(A::Bidiagonal, B::AbstractVecOrMat) + _muldiag_size_check(A, B) + (; dv, ev) = A + if A.uplo == 'U' + for k in axes(B,2) + for i in axes(ev,1) + B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k] + end + B[end,k] = dv[end] * B[end,k] + end + else + for k in axes(B,2) + for i in reverse(axes(dv,1)[2:end]) + B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k] + end + B[1,k] = dv[1] * B[1,k] + end + end + return B +end +# B .= D * B +function lmul!(D::Diagonal, B::Bidiagonal) + _muldiag_size_check(D, B) + (; dv, ev) = B + isL = B.uplo == 'L' + dv[1] = D.diag[1] * dv[1] + for i in axes(ev,1) + ev[i] = D.diag[i + isL] * ev[i] + dv[i+1] = D.diag[i+1] * dv[i+1] + end + return B +end +# B .= B * A +function rmul!(B::AbstractMatrix, A::Bidiagonal) + _muldiag_size_check(A, B) + (; dv, ev) = A + if A.uplo == 'U' + for k in reverse(axes(dv,1)[2:end]) + for i in axes(B,1) + B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1] + end + end + for i in axes(B,1) + B[i,1] *= dv[1] + end + else + for k in axes(ev,1) + for i in axes(B,1) + B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k] + end + end + for i in axes(B,1) + B[i,end] *= dv[end] + end + end + return B +end +# B .= B * D +function rmul!(B::Bidiagonal, D::Diagonal) + _muldiag_size_check(B, D) + (; dv, ev) = B + isU = B.uplo == 'U' + dv[1] *= D.diag[1] + for i in axes(ev,1) + ev[i] *= D.diag[i + isU] + dv[i+1] *= D.diag[i+1] + end + return B +end function check_A_mul_B!_sizes(C, A, B) mA, nA = size(A) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 1ff7d371043ee..b3826a2aa7f82 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -327,8 +327,49 @@ function (*)(D::Diagonal, V::AbstractVector) return D.diag .* V end -rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) -lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B) +function rmul!(A::AbstractMatrix, D::Diagonal) + _muldiag_size_check(A, D) + for I in CartesianIndices(A) + row, col = Tuple(I) + @inbounds A[row, col] *= D.diag[col] + end + return A +end +# T .= T * D +function rmul!(T::Tridiagonal, D::Diagonal) + _muldiag_size_check(T, D) + (; dl, d, du) = T + d[1] *= D.diag[1] + for i in axes(dl,1) + dl[i] *= D.diag[i] + du[i] *= D.diag[i+1] + d[i+1] *= D.diag[i+1] + end + return T +end + +function lmul!(D::Diagonal, B::AbstractVecOrMat) + _muldiag_size_check(D, B) + for I in CartesianIndices(B) + row = I[1] + @inbounds B[I] = D.diag[row] * B[I] + end + return B +end + +# in-place multiplication with a diagonal +# T .= D * T +function lmul!(D::Diagonal, T::Tridiagonal) + _muldiag_size_check(D, T) + (; dl, d, du) = T + d[1] = D.diag[1] * d[1] + for i in axes(dl,1) + dl[i] = D.diag[i+1] * dl[i] + du[i] = D.diag[i] * du[i] + d[i+1] = D.diag[i+1] * d[i+1] + end + return T +end function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} require_one_based_indexing(out, B) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 2380a93d90a74..2ff3e9b423702 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -933,6 +933,41 @@ end @test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)] end +@testset "rmul!/lmul! with banded matrices" begin + dv, ev = rand(4), rand(3) + for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L)) + @testset "$(nameof(typeof(B)))" for B in ( + Bidiagonal(dv, ev, :U), + Bidiagonal(dv, ev, :L), + Diagonal(dv) + ) + @test_throws ArgumentError rmul!(B, A) + @test_throws ArgumentError lmul!(A, B) + end + D = Diagonal(dv) + @test rmul!(copy(A), D) ≈ A * D + @test lmul!(D, copy(A)) ≈ D * A + end + @testset "non-commutative" begin + S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2)) + S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3)) + S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2)) + for uplo in (:L, :U) + B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo) + D = Diagonal(fill(S22, size(B,2))) + @test rmul!(copy(B), D) ≈ B * D + D = Diagonal(fill(S33, size(B,1))) + @test lmul!(D, copy(B)) ≈ D * B + end + + B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U) + D = Diagonal(fill(S32, 4)) + @test lmul!(B, Array(D)) ≈ B * D + B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U) + @test rmul!(Array(D), B) ≈ D * B + end +end + @testset "conversion to Tridiagonal for immutable bands" begin n = 4 dv = FillArrays.Fill(3, n) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 4009f841a355c..1a3b8d4fd0ea7 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1322,4 +1322,17 @@ end @test M == D end +@testset "rmul!/lmul! with banded matrices" begin + @testset "$(nameof(typeof(B)))" for B in ( + Bidiagonal(rand(4), rand(3), :L), + Tridiagonal(rand(3), rand(4), rand(3)) + ) + BA = Array(B) + D = Diagonal(rand(size(B,1))) + DA = Array(D) + @test rmul!(copy(B), D) ≈ B * D ≈ BA * DA + @test lmul!(D, copy(B)) ≈ D * B ≈ DA * BA + end +end + end # module TestDiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index a067c18d7665d..fae708c4c8db4 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -892,4 +892,23 @@ end end end +@testset "rmul!/lmul! with banded matrices" begin + dl, d, du = rand(3), rand(4), rand(3) + A = Tridiagonal(dl, d, du) + D = Diagonal(d) + @test rmul!(copy(A), D) ≈ A * D + @test lmul!(D, copy(A)) ≈ D * A + + @testset "non-commutative" begin + S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2)) + S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3)) + S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2)) + T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3)) + D = Diagonal(fill(S22, size(T,2))) + @test rmul!(copy(T), D) ≈ T * D + D = Diagonal(fill(S33, size(T,1))) + @test lmul!(D, copy(T)) ≈ D * T + end +end + end # module TestTridiagonal diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index 2d37cead61a08..bc02fb5cbbd20 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1 Base.last(r::SOneTo) = length(r) Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")") +Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s +Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s + struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N} data::A function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}