diff --git a/src/ExtendedMul.jl b/src/ExtendedMul.jl index bd09031..ee519f1 100644 --- a/src/ExtendedMul.jl +++ b/src/ExtendedMul.jl @@ -1,105 +1,163 @@ using LinearAlgebra: BlasInt using LinearAlgebra.BLAS -import LinearAlgebra.BLAS: gemm!, @blasfunc, libblas - -export _mul! - -function gemm!(ta::Char, tb::Char, alpha::Float64, - a::Union{Ref{Float64}, AbstractVecOrMat{Float64}}, ma::Int64, na::Int64, - b::Union{Ref{Float64}, AbstractVecOrMat{Float64}}, nb::Int64, - beta::Float64, c::Union{Ref{Float64}, AbstractVecOrMat{Float64}}) - ccall((@blasfunc(dgemm_), libblas), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, - Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, - Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, - Ref{BlasInt}), - ta, tb, ma, nb, - na, alpha, a, max(ma, 1), - b, max(na, 1), beta, c, - max(ma, 1)) -end +import LinearAlgebra.BLAS: @blasfunc, libblas + +export unsafe_mul! + +function ext_gemm!(ta::Char, tb::Char, ma::Int, nb::Int, na::Int, alpha, + a::StridedVecOrMat{Float64}, b::StridedVecOrMat{Float64}, beta, + c::StridedVecOrMat{Float64}, offset_a, offset_b, offset_c) + lda = ndims(a) == 2 ? strides(a)[2] : ma + ldb = ndims(b) == 2 ? strides(b)[2] : na + ldc = ndims(c) == 2 ? strides(c)[2] : ma -# a * b, 3 methods (A_mul_B!) -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, - b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64) - gemm!('N', 'N', 1.0, Ref(a, offset_a), ma, na, Ref(b, offset_b), - nb, 0.0, Ref(c, offset_c)) + ccall((@blasfunc("dgemm_"), libblas), Cvoid, + (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64}, + Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, + Ptr{Float64}, Ref{BlasInt}), ta, tb, ma, nb, na, alpha, Ref(a, offset_a), lda, + Ref(b, offset_b), ldb, beta, Ref(c, offset_c), ldc) + return c end -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, - b::AbstractVecOrMat{Float64}) - nb = (ndims(b) > 1) ? size(b, 2) : 1 - _mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb) +function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, b::StridedVecOrMat; + offset1::Int = 1, offset2::Int = 1, offset3::Int = 1, + rows2::Int = size(a, 1), cols2::Int = size(a, 2), + cols3::Int = size(b, 2)) + blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3) + ext_gemm!('N', 'N', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1) end -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::AbstractVecOrMat{Float64}, - b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64) - ma = size(a, 1) - na = (ndims(a) > 1) ? size(a, 2) : 1 - _mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb) +function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, + bAdj::Adjoint{Float64, <:StridedVecOrMat}; offset1::Int = 1, + offset2::Int = 1, offset3::Int = 1, rows2::Int = size(a, 1), + cols2::Int = size(a, 2), cols3::Int = size(bAdj, 2)) + b = bAdj.parent + blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3) + ext_gemm!('N', 'T', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1) end -# a' * b, 2 methods (At_mul_B!) -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64, - b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64) - ccall((@blasfunc(dgemm_), libblas), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, - Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, - Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, - Ref{BlasInt}), - 'T', 'N', ma, nb, - na, 1.0, Ref(a, offset_a), max(na, 1), - Ref(b, offset_b), max(na, 1), 0.0, Ref(c, offset_c), - max(ma, 1)) +function unsafe_mul!(c::StridedVecOrMat, aAdj::Adjoint{Float64, <:StridedArray}, + b::StridedVecOrMat; offset1::Int = 1, offset2::Int = 1, + offset3::Int = 1, rows2::Int = size(aAdj, 1), + cols2::Int = size(aAdj, 2), cols3::Int = size(b, 2)) + a = aAdj.parent + blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3) + ext_gemm!('T', 'N', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1) end -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::Adjoint{Float64}, - b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64) - na = size(a.parent, 1) - ma = (ndims(a.parent) > 1) ? size(a.parent, 2) : 1 - _mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb) +function unsafe_mul!(c::StridedVecOrMat, aAdj::Adjoint{Float64, <:StridedVecOrMat}, + bAdj::Adjoint{Float64, <:StridedVecOrMat}; offset1::Int = 1, + offset2::Int = 1, offset3::Int = 1, rows2::Int = size(aAdj, 1), + cols2::Int = size(aAdj, 2), cols3::Int = size(bAdj, 2)) + a = aAdj.parent + b = bAdj.parent + blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3) + ext_gemm!('T', 'T', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1) end -# a * b', 2 methods (A_mul_Bt!) -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, - b::Adjoint{Float64}, offset_b::Int64, nb::Int64) - if typeof(b) <: Adjoint{Float64, QuasiUpperTriangular{Float64, Matrix{Float64}}} +function blas_check(c, a, b, offset_c, offset_a, offset_b, ma, na, nb) + @boundscheck begin + # Make sure offset is a sane value + @assert (length(a) >= offset_a >= 1) + @assert (length(b) >= offset_b >= 1) + @assert (length(c) >= offset_c >= 1) + # Assert that there is enough data in each variable + _mb = na + @assert ma * na<=length(a) - offset_a + 1 "You're asking from A more than it has" + @assert _mb * nb<=length(b) - offset_b + 1 "You're asking from B more than it has" + @assert ma * nb<=length(c) - offset_c + 1 "You're assigning into C more than it can take" + + #TODO: if A is a vector, ma is supplied and na is not. Do not default to na=1 end - ccall((@blasfunc(dgemm_), libblas), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, - Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, - Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, - Ref{BlasInt}), - 'N', 'T', ma, nb, - na, 1.0, Ref(a, offset_a), max(ma, 1), - Ref(b, offset_b), max(nb, 1), 0.0, Ref(c, offset_c), - max(ma, 1)) end -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, - b::Adjoint{Float64}) - nb = size(b.parent, 1) - _mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb) +## QuasiUpperTriangular + +# B = alpha*A*B such that A is triangular +function ext_trmm!(side::Char, uplo::Char, ta::Char, diag::Char, mb::Int, nb::Int, alpha, + a::StridedMatrix{Float64}, b::StridedVecOrMat{Float64}, offset_b) + # intentionally assume the triangular data is not in vector form + #TODO: Reconsider ma/na vs mb/nb. trmm docs usage might conflict a little with ours + + # following trmm docs + lda = uppercase(side) == 'L' ? mb : nb + ldb = ndims(b) == 2 ? strides(b)[2] : mb + + ccall((@blasfunc("dtrmm_"), libblas), Cvoid, + (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, + Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}), side, + uplo, ta, diag, mb, nb, alpha, a, lda, Ref(b, offset_b), ldb) + return b end -# a' * b', 1 method (At_mul_Bt!) -function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64, - a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64, - b::Adjoint{Float64}, offset_b::Int64, nb::Int64) - ccall((@blasfunc(dgemm_), libblas), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, - Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, - Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, - Ref{BlasInt}), - 'T', 'T', ma, nb, - na, 1.0, Ref(a, offset_a), max(na, 1), - Ref(b, offset_b), max(ma, 1), 0.0, Ref(c, offset_c), - max(ma, 1)) +# A is QuasiUpper +function unsafe_mul!(c::StridedVecOrMat, a::QuasiUpperTriangular, b::StridedVecOrMat; + offset1::Int = 1, offset2::Int = 1, offset3::Int = 1, + rows2::Int = size(a, 1), cols2::Int = size(a, 2), + cols3::Int = size(b, 2)) + blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3) + + rows3 = cols2 + copyto!(c, offset1, b, offset3, rows3 * cols3) + alpha = 1.0 + ext_trmm!('L', 'U', 'N', 'N', rows3, cols3, alpha, a.data, c, offset1) + + @inbounds for i in 2:rows2 + x = a[i, i - 1] + indb = offset3 + indc = offset1 + 1 + @simd for j in 1:cols3 + c[indc] += x * b[indb] + indb += rows2 + indc += rows2 + end + end end + +# B is QuasiUpper +function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, b::QuasiUpperTriangular; + offset1::Int = 1, offset2::Int = 1, offset3::Int = 1, + rows2::Int = size(a, 1), cols2::Int = size(a, 2), + cols3::Int = size(b, 2)) + blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3) + + copyto!(c, offset1, a, offset2, rows2 * cols2) + alpha = 1.0 + ext_trmm!('R', 'U', 'N', 'N', rows2, cols2, alpha, b.data, c, offset1) + + inda = offset2 + rows2 + indc = offset1 + @inbounds for i in 2:cols2 + x = b[i, i - 1] + @simd for j in 1:rows2 + c[indc] += x * a[inda] + inda += 1 + indc += 1 + end + end +end + +# B is an Adjoint of QuasiUpper +function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, + bAdj::Adjoint{Float64, <:QuasiUpperTriangular}; offset1::Int = 1, + offset2::Int = 1, offset3::Int = 1, rows2::Int = size(a, 1), + cols2::Int = size(a, 2), cols3::Int = size(bAdj, 2)) + b = bAdj.parent + blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3) + + copyto!(c, offset1, a, offset2, rows2 * cols2) + alpha = 1.0 + ext_trmm!('R', 'U', 'T', 'N', rows2, cols2, alpha, b.data, c, offset1) + + inda = offset2 + indc = offset1 + rows2 + @inbounds for j in 2:cols2 + x = alpha * b[j, j - 1] + @simd for i in 1:rows2 + c[indc] += x * a[inda] + inda += 1 + indc += 1 + end + end + c +end \ No newline at end of file diff --git a/src/KroneckerTools.jl b/src/KroneckerTools.jl index 767fedc..4d2883e 100644 --- a/src/KroneckerTools.jl +++ b/src/KroneckerTools.jl @@ -3,11 +3,11 @@ module KroneckerTools using LinearAlgebra using LinearAlgebra.BLAS using QuasiTriangular -import QuasiTriangular: A_mul_B!, At_mul_B!, A_mul_Bt! import Base.convert include("ExtendedMul.jl") -export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!, at_mul_b_kron_c!, a_mul_b_kron_ct! +export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!, + at_mul_b_kron_c!, a_mul_b_kron_ct!, kron_a_mul_b!, kron_at_mul_b! """ Content: @@ -51,19 +51,23 @@ function kron_mul_elem!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix, b begin if p == 1 && q == 1 # a*b - _mul!(c, offset_c, a, b, offset_b, 1) + #OLD _mul!(c, offset_c, a, b, offset_b, 1) + unsafe_mul!(c, a, b) elseif q == 1 # (I_p ⊗ a)*b = vec(a*[b_1 b_2 ... b_p]) - _mul!(c, offset_c, a, b, offset_b, p) + #OLD _mul!(c, offset_c, a, b, offset_b, p) + unsafe_mul!(c, a, b; offset1=offset_c, offset3=offset_b, cols3=p) elseif p == 1 # (a ⊗ I_q)*b = (b'*(a' ⊗ I_q))' = vec(reshape(b,q,m)*a') - _mul!(c, offset_c, b, offset_b, q, n, a') + #OLD _mul!(c, offset_c, b, offset_b, q, n, a') + unsafe_mul!(c, b, a'; offset1=offset_c, offset2=offset_b, rows2=q, cols2=n) else # (I_p ⊗ a ⊗ I_q)*b = vec([(a ⊗ I_q)*b_1 (a ⊗ I_q)*b_2 ... (a ⊗ I_q)*b_p]) mq = m*q nq = n*q for i=1:p - _mul!(c, offset_c, b, offset_b, q, n, a') + #OLD _mul!(c, offset_c, b, offset_b, q, n, a') + unsafe_mul!(c, b, a'; offset1=offset_c, offset2=offset_b, rows2=q, cols2=n) offset_b += nq offset_c += mq end @@ -89,19 +93,24 @@ function kron_mul_elem_t!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix, begin if p == 1 && q == 1 # a'*b - _mul!(c, offset_c, a', b, offset_b, 1) + #OLD _mul!(c, offset_c, a', b, offset_b, 1) + unsafe_mul!(c, a', b; offset1=offset_c, offset3=offset_b, cols3=1) elseif q == 1 # (I_p ⊗ a')*b = vec(a'*[b_1 b_2 ... b_p]) - _mul!(c, offset_c, a', b, offset_b, p) + #OLD _mul!(c, offset_c, a', b, offset_b, p) + unsafe_mul!(c, a', b; offset1=offset_c, offset3=offset_b, cols3=p) + elseif p == 1 # (a' ⊗ I_q)*b = (b'*(a ⊗ I_q))' = vec(reshape(b,q,m)*a) - _mul!(c, offset_c, b, offset_b, q, m, a) + #OLD _mul!(c, offset_c, b, offset_b, q, m, a) + unsafe_mul!(c, b, a; offset1=offset_c, offset2=offset_b, rows2=q, cols2=m) else # (I_p ⊗ a' ⊗ I_q)*b = vec([(a' ⊗ I_q)*b_1 (a' ⊗ I_q)*b_2 ... (a' ⊗ I_q)*b_p]) mq = m*q nq = n*q for i=1:p - _mul!(c, offset_c, b, offset_b, q, m, a) + #OLD _mul!(c, offset_c, b, offset_b, q, m, a) + unsafe_mul!(c, b, a; offset1=offset_c, offset2=offset_b, rows2=q, cols2=m) offset_b += mq offset_c += nq end @@ -207,7 +216,8 @@ computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using data from c at offset_c function kron_at_kron_b_mul_c!(d::AbstractVector, offset_d::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_c::Int64, work1::AbstractVector, work2::AbstractVector, offset_w::Int64) mb,nb = size(b) if order == 0 - _mul!(d,offset_d,b,1,mb,nb,c,offset_c,1) + #OLD _mul!(d, offset_d, b, 1, mb, nb, c, offset_c, 1) + unsafe_mul!(d, b, c; offset1=offset_d, offset3=offset_c, rows2=mb, cols2=nb, cols3=1) else ma, na = size(a) # length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))")) @@ -226,7 +236,7 @@ function kron_at_kron_b_mul_c!(d::AbstractVector, offset_d::Int64, a::AbstractMa copyto!(d, offset_d, work2, offset_w, p*na*q) end end - +# (length(c) / nb) * mb < length(d) """ function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector) computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and work2 @@ -234,7 +244,8 @@ computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and w function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector) mb,nb = size(b) if order == 0 - _mul!(d,1,b,1,mb,nb,c,1,1) + #OLD _mul!(d,1,b,1,mb,nb,c,1,1) + unsafe_mul!(d, b, c) else ma, na = size(a) # length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))")) @@ -254,6 +265,9 @@ function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int6 end end + + + """ function kron_at_kron_b_mul_c!(a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_c::Int64, work::AbstractVector) updates c at offset_c with (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using c and work as work vectors @@ -289,7 +303,6 @@ function kron_at_mul_b!(c::AbstractVector, a::AbstractMatrix, order::Int64, b::A end copyto!(c,1,work2,1,s) end - function kron_a_mul_b!(c::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractVector, q::Int64, work1::AbstractVector, work2::AbstractVector) ma,na = size(a) # length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))")) @@ -313,11 +326,13 @@ function at_mul_b_kron_c!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri mb, nb = size(b) mc, nc = size(c) if mc <= nc - _mul!(work1, 1, a', 1, na, ma, b, 1, nb) + #OLD _mul!(work1, 1, a', 1, na, ma, b, 1, nb) + unsafe_mul!(work1, a', b; rows2=na, cols2=ma, cols3=nb) kron_at_mul_b!(vec(d), c, order, work1, na, vec(d), work2) else kron_at_mul_b!(work1, c, order, b, na, work1, work2) - _mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order) + #OLD _mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order) + unsafe_mul!(vec(d), a', work1; rows2=na, cols2=ma, cols3=nc^order) end end @@ -326,11 +341,14 @@ function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri mb, nb = size(b) mc, nc = size(c) if mc <= nc - _mul!(work1, 1, a, 1, ma, na, b, 1, nb) + #OLD _mul!(work1, 1, a, 1, ma, na, b, 1, nb) + unsafe_mul!(work1, a, b) kron_a_mul_b!(vec(d), c, order, work1, ma, work1, work2) else kron_a_mul_b!(work1, b, order, c, mb, work1, work2) - _mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order) + #OLD _mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order) + unsafe_mul!(vec(d), a', work1; rows2=a, cols2=ma, cols3=nc^order) + end end @@ -389,7 +407,8 @@ function a_mul_b_kron_c_d!(e::AbstractMatrix, a::AbstractMatrix, b::AbstractMatr na == mb || throw(DimensionMismatch("The number of columns of a, $(size(a,2)), doesn't match the number of rows of b, $(size(b,1))")) nb == mc*md^(order-1) || throw(DimensionMismatch("The number of columns of b, $(size(b,2)), doesn't match the number of rows of c, $(size(c,1)), and d, $(size(d,1)) for order, $order")) (ma == me && nc*nd^(order-1) == ne) || throw(DimensionMismatch("Dimension mismatch for e: $(size(e)) while ($ma, $(nc*nd^(order-1))) was expected")) - _mul!(work1, 1, a, 1, ma, mb, vec(b), 1, nb) + # OLD _mul!(work1, 1, a, 1, ma, mb, b, 1, nb) + unsafe_mul!(work1, a, b; rows2=ma, cols2=mb, cols3=nb) p = mc*md^(order - 2) q = ma for i = 0:order - 2 diff --git a/test/runtests.jl b/test/runtests.jl index 436523d..8edf80d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,8 +4,8 @@ using Test import KroneckerTools: kron_mul_elem!, kron_mul_elem_t!, a_mul_kron_b!, a_mul_b_kron_c!, a_mul_b_kron_ct!, at_mul_b_kron_c!, a_mul_b_kron_c_d!, kron_a_mul_b!, kron_at_mul_b!, kron_at_kron_b_mul_c!, - QuasiTriangular.QuasiUpperTriangular, - _mul! + QuasiTriangular.QuasiUpperTriangular, + ext_gemm!, ext_trmm!, unsafe_mul! Random.seed!(123) #a = rand(2,3) @@ -45,57 +45,102 @@ for m in [1, 3] end =# -@testset "KroneckerTools" verbose=true begin - @testset "ExtendedMul" begin - # a * b, 3 methods - ma = 2 - na = 4 - a = randn(ma, na) - mb = 4 - nb = 3 - b = randn(mb, nb) - c = randn(ma, nb) - _mul!(c, 1, a, 1, ma, na, b, 1, nb) - @test c ≈ a * b - c = randn(ma, nb) - _mul!(c, 1, a, 1, ma, na, b) - @test c ≈ a * b - c = randn(ma, nb) - _mul!(c, 1, a, b, 1, nb) - @test c ≈ a * b +@testset verbose=true "KroneckerTools" begin + @testset verbose=true "ExtendedMul" begin + @testset "Few Sanity Checks" begin + #TODO: this is an infinitesimal sample of the allowed possibilities + # undersized c output + n = 5 + a = randn(n + 1, n) + b = randn(n, n + 1) + c = randn(n, n) + @test_throws AssertionError unsafe_mul!(c, a, b) - # a' * b, 2 methods - mb = 2 - nb = 3 - b = randn(mb, nb) - c = randn(na, nb) - _mul!(c, 1, a', 1, na, ma, b, 1, nb) - @test c ≈ transpose(a) * b - c = randn(na, nb) - _mul!(c, 1, a', b, 1, nb) - @test c ≈ a' * b + # undersized b + a = randn(n * 2, n) + b = randn(n - 1, n * 3) + c = randn(n * 2, n * 3) + @test_throws AssertionError unsafe_mul!(c, a, b) + end - # a * b', 2 methods - nb = 4 - mb = 3 - b = randn(mb, nb) - c = randn(ma, mb) - _mul!(c, 1, a, 1, ma, na, b', 1, mb) - @test c ≈ a * transpose(b) - c = randn(ma, mb) - _mul!(c, 1, a, 1, ma, na, b') - @test c ≈ a * transpose(b) + # Sometimes res variables are used because B is mutated + @testset "Extended trmm" begin + # Simplest case + a = triu(rand(5, 5)) + b = (rand(5, 5)) + res1 = a * b + res2 = ext_trmm!('L', 'U', 'N', 'N', 5, 5, 1, a, b, 1) + @test res1 ≈ res2 + end - # a' * b', 1 method - mb = 4 - nb = 2 - b = randn(mb, nb) - c = randn(na, mb) - _mul!(c, 1, a', 1, na, ma, b', 1, mb) - @test c ≈ transpose(a) * transpose(b) + @testset "unsafe_mul! Matrix" begin + ## trying a different combination of default arguments each time wdw + # a * b + ma = 2 + na = 4 + a = randn(ma, na) + mb = 4 + nb = 3 + b = randn(mb, nb) + c = randn(ma, nb) + unsafe_mul!(c, a, b; rows2 = ma, cols2 = na, cols3 = nb) + @test c ≈ a * b + + # a' * b + mb = 2 + nb = 3 + b = randn(mb, nb) + c = randn(na, nb) + unsafe_mul!(c, a', b; offset1 = 1, offset2 = 1, cols2 = ma, cols3 = nb) + @test c ≈ transpose(a) * b + c = randn(na, nb) + unsafe_mul!(c, a', b) + @test c ≈ a' * b + + # a * b' + nb = 4 + mb = 3 + b = randn(mb, nb) + c = randn(ma, mb) + unsafe_mul!(c, a, b'; offset1 = 1, offset3 = 1, rows2 = ma) + @test c ≈ a * transpose(b) + c = randn(ma, mb) + unsafe_mul!(c, a, b') + @test c ≈ a * transpose(b) + + # a' * b' + mb = 4 + nb = 2 + b = randn(mb, nb) + c = randn(na, mb) + unsafe_mul!(c, a', b') + @test c ≈ transpose(a) * transpose(b) + end + + @testset "unsafe_mul! QuasiUpper" begin + # a * b, a is Quasi + n = 5 + a = QuasiUpperTriangular(triu(randn(n, n))) + b = randn(n, n * 2) + c = randn(n, n * 2) + unsafe_mul!(c, a, b) + @test c ≈ a * b + + # a * b', b is Quasi + n = 4 + a = randn(n, n) + b = QuasiUpperTriangular(triu(randn(n, n))) + c = randn(n, n) + unsafe_mul!(c, a, b') + @test c ≈ a * transpose(b) + # write output to a vector + c = randn(n * n) + unsafe_mul!(c, a, b') + @test c ≈ vec(a * transpose(b)) + end end - @testset "Main kronecker operations" begin + @testset "Kron operations" begin order = 2 ma = 2 na = 4 @@ -121,7 +166,7 @@ end w2 = Vector{Float64}(undef, ma * mc^order) a_mul_b_kron_c!(d, a, b, c, order, w1, w2) cc = c - for i = 2:order + for i in 2:order cc = kron(cc, c) end @test d ≈ a * b_orig * cc @@ -129,7 +174,7 @@ end b = copy(b_orig) a_mul_kron_b!(d, b, c, order) cc = c - for i = 2:order + for i in 2:order cc = kron(cc, c) end @test d ≈ b_orig * cc @@ -178,89 +223,87 @@ end work2 = rand(na * na * mb) kron_at_kron_b_mul_c!(d, a, order, b, c, work1, work2) kron_at_kron_b_mul_c!(d, a, order, b, c, work1, work2) - @test d ≈ kron(kron(a', a'), b) * c - end - @testset "kron_a_mul_b" begin - order = 2 - ma = 2 - na = 4 - q = 2 - a = rand(ma, na) - b = rand(q * na^order) - c = rand(q * ma^order) - work1 = rand(q * max(ma, na)^order) - work2 = similar(work1) - q = 2 - kron_a_mul_b!(c, a, order, b, q, work1, work2) - kron_a_mul_b!(c, a, order, b, q, work1, work2) - @test c ≈ kron(kron(a, a), Matrix{Float64}(I(q))) * b + @testset "kron_a_mul_b" begin + order = 2 + ma = 2 + na = 4 + q = 2 + a = rand(ma, na) + b = rand(q * na^order) + c = rand(q * ma^order) + work1 = rand(q * max(ma, na)^order) + work2 = similar(work1) + q = 2 + kron_a_mul_b!(c, a, order, b, q, work1, work2) + kron_a_mul_b!(c, a, order, b, q, work1, work2) + @test c ≈ kron(kron(a, a), Matrix{Float64}(I(q))) * b - # test2 - b = rand(q * ma^order) - c = rand(q * na^order) - kron_at_mul_b!(c, a, order, b, q, work1, work2) - kron_at_mul_b!(c, a, order, b, q, work1, work2) - @test c ≈ kron(kron(a', a'), Matrix{Float64}(I(q))) * b - end + # test2 + b = rand(q * ma^order) + c = rand(q * na^order) + kron_at_mul_b!(c, a, order, b, q, work1, work2) + kron_at_mul_b!(c, a, order, b, q, work1, work2) + @test c ≈ kron(kron(a', a'), Matrix{Float64}(I(q))) * b + end - @testset "a_mul_b_kron_c" begin - order = 2 - mc = 2 - nc = 3 - c = rand(mc, nc) - ma = 2 - na = 4 - a = rand(ma, na) - nb = nc^order - b = rand(na, nb) - d = rand(ma, mc^order) - work1 = rand(ma * max(mc, nc)^order) - work2 = similar(work1) - a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) - a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) - @test d ≈ a * b * kron(c', c') + @testset "a_mul_b_kron_c" begin + order = 2 + mc = 2 + nc = 3 + c = rand(mc, nc) + ma = 2 + na = 4 + a = rand(ma, na) + nb = nc^order + b = rand(na, nb) + d = rand(ma, mc^order) + work1 = rand(ma * max(mc, nc)^order) + work2 = similar(work1) + a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) + a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) + @test d ≈ a * b * kron(c', c') - # test4 - nb = mc^order - b = rand(ma, nb) - d = rand(na, nc^order) - work1 = rand(na * max(mc, nc)^order) - work2 = similar(work1) - at_mul_b_kron_c!(d, a, b, c, order, work1, work2) - at_mul_b_kron_c!(d, a, b, c, order, work1, work2) - @test d ≈ a' * b * kron(c, c) + # test4 + nb = mc^order + b = rand(ma, nb) + d = rand(na, nc^order) + work1 = rand(na * max(mc, nc)^order) + work2 = similar(work1) + at_mul_b_kron_c!(d, a, b, c, order, work1, work2) + at_mul_b_kron_c!(d, a, b, c, order, work1, work2) + @test d ≈ a' * b * kron(c, c) - # Doubtful usefulness, probably matching different combination of sizes - order = 2 - mc = 2 - nc = 3 - c = rand(mc, nc) - ma = 4 - na = 4 - a = rand(ma, na) - nb = nc^order - b = rand(na, nb) - d = rand(ma, mc^order) - work1 = rand(ma * max(mc, nc)^order) - work2 = similar(work1) - a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) - a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) - @test d ≈ a * b * kron(c', c') + order = 2 + mc = 2 + nc = 3 + c = rand(mc, nc) + ma = 4 + na = 4 + a = rand(ma, na) + nb = nc^order + b = rand(na, nb) + d = rand(ma, mc^order) + work1 = rand(ma * max(mc, nc)^order) + work2 = similar(work1) + a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) + a_mul_b_kron_ct!(d, a, b, c, order, work1, work2) + @test d ≈ a * b * kron(c', c') - #test4a - nb = mc^order - b = rand(ma, nb) - d = rand(na, nc^order) - work1 = rand(na * max(mc, nc)^order) - work2 = similar(work1) - at_mul_b_kron_c!(d, a, b, c, order, work1, work2) - at_mul_b_kron_c!(d, a, b, c, order, work1, work2) - @test d ≈ a' * b * kron(c, c) + #test4a + nb = mc^order + b = rand(ma, nb) + d = rand(na, nc^order) + work1 = rand(na * max(mc, nc)^order) + work2 = similar(work1) + at_mul_b_kron_c!(d, a, b, c, order, work1, work2) + at_mul_b_kron_c!(d, a, b, c, order, work1, work2) + @test d ≈ a' * b * kron(c, c) + end end - @testset "QuasiUpperTriangular" begin + @testset "Kron with QuasiUpper" begin order = 2 ma = 4 na = 4 @@ -281,7 +324,7 @@ end kron_at_mul_b!(c, a, order, b, q, work1, work2) @test c ≈ kron(kron(a', a'), Matrix{Float64}(I(q))) * b - # Testing the same thing with views + # Views of QuasiUpper order = 2 ma = 4 na = 4 @@ -305,7 +348,7 @@ end @test c ≈ kron(kron(a', a'), Matrix{Float64}(I(q))) * b end - @testset "Views" begin + @testset "Kron with views" begin order = 2 mc = 2 nc = 3