Skip to content

Commit

Permalink
reimplement ExtendedMul interface to use kwargs
Browse files Browse the repository at this point in the history
- Major reorganization
- Add proper QuasiUpperTriangular support in ExtendedMul directly
  • Loading branch information
Omar-Elrefaei committed Mar 8, 2023
1 parent d484fd5 commit b0bcb3f
Show file tree
Hide file tree
Showing 3 changed files with 346 additions and 230 deletions.
230 changes: 144 additions & 86 deletions src/ExtendedMul.jl
Original file line number Diff line number Diff line change
@@ -1,105 +1,163 @@
using LinearAlgebra: BlasInt
using LinearAlgebra.BLAS
import LinearAlgebra.BLAS: gemm!, @blasfunc, libblas

export _mul!

# Direct call into BLAS `dgemm_` with explicitly supplied dimensions.
# `ma`, `nb` and `na` are forwarded as the M, N and K arguments; `a`, `b` and `c` may be
# arrays or `Ref`s pointing at the first element to use.
# The leading dimensions are derived from `ma`/`na` only (never from the arrays), each
# clamped to at least 1.
function gemm!(ta::Char, tb::Char, alpha::Float64,
               a::Union{Ref{Float64}, AbstractVecOrMat{Float64}}, ma::Int64, na::Int64,
               b::Union{Ref{Float64}, AbstractVecOrMat{Float64}}, nb::Int64,
               beta::Float64, c::Union{Ref{Float64}, AbstractVecOrMat{Float64}})
    lda = max(ma, 1)
    ldb = max(na, 1)
    ldc = max(ma, 1)
    return ccall((@blasfunc(dgemm_), libblas), Cvoid,
                 (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
                  Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
                  Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
                  Ref{BlasInt}),
                 ta, tb, ma, nb,
                 na, alpha, a, lda,
                 b, ldb, beta, c,
                 ldc)
end
import LinearAlgebra.BLAS: @blasfunc, libblas

export unsafe_mul!

# Raw `dgemm_` call: C = alpha * op(A) * op(B) + beta * C on Float64 buffers.
#   ta, tb   : BLAS transposition flags for A and B
#   ma, na   : rows and columns of op(A) (forwarded as M and K)
#   nb       : columns of op(B)          (forwarded as N)
#   offset_* : 1-based linear index of the first element used in each buffer
# Matrices contribute their own column stride as leading dimension. Vectors are treated
# as densely packed matrices, so their leading dimension is the row count of the array
# *as stored*, which depends on the transposition flag.
# Returns `c`. No argument validation is performed here.
function ext_gemm!(ta::Char, tb::Char, ma::Int, nb::Int, na::Int, alpha,
                   a::StridedVecOrMat{Float64}, b::StridedVecOrMat{Float64}, beta,
                   c::StridedVecOrMat{Float64}, offset_a, offset_b, offset_c)
    # op(A) is ma × na: stored A has ma rows when untransposed, na rows otherwise.
    stored_rows_a = ta in ('N', 'n') ? ma : na
    # op(B) is na × nb: stored B has na rows when untransposed, nb rows otherwise.
    stored_rows_b = tb in ('N', 'n') ? na : nb

    # BLAS rejects leading dimensions smaller than 1, hence the clamp.
    lda = max(1, ndims(a) == 2 ? strides(a)[2] : stored_rows_a)
    ldb = max(1, ndims(b) == 2 ? strides(b)[2] : stored_rows_b)
    ldc = max(1, ndims(c) == 2 ? strides(c)[2] : ma)

    ccall((@blasfunc("dgemm_"), libblas), Cvoid,
          (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64},
           Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64},
           Ptr{Float64}, Ref{BlasInt}),
          ta, tb, ma, nb, na, alpha, Ref(a, offset_a), lda,
          Ref(b, offset_b), ldb, beta, Ref(c, offset_c), ldc)
    return c
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64})
nb = (ndims(b) > 1) ? size(b, 2) : 1
_mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
# c = a * b on plain strided buffers.
# offset1/offset2/offset3 are the starting linear indices into c/a/b respectively;
# rows2 × cols2 is the shape used for `a`, cols3 the number of columns used from `b`.
# Arguments are checked by `blas_check` and the product is delegated to `ext_gemm!`,
# whose result is returned.
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, b::StridedVecOrMat;
                     offset1::Int = 1, offset2::Int = 1, offset3::Int = 1,
                     rows2::Int = size(a, 1), cols2::Int = size(a, 2),
                     cols3::Int = size(b, 2))
    blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
    alpha, beta = 1, 0
    return ext_gemm!('N', 'N', rows2, cols3, cols2, alpha, a, b, beta, c,
                     offset2, offset3, offset1)
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
ma = size(a, 1)
na = (ndims(a) > 1) ? size(a, 2) : 1
_mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
# c = a * b' where the right operand arrives wrapped in an `Adjoint`.
# The wrapper is peeled off and the transposition is handed to BLAS via the 'T' flag.
# offset1/offset2/offset3 index into c/a/(parent of bAdj); rows2 × cols2 is the shape
# used for `a`, cols3 the number of columns of `bAdj`.
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat,
                     bAdj::Adjoint{Float64, <:StridedVecOrMat}; offset1::Int = 1,
                     offset2::Int = 1, offset3::Int = 1, rows2::Int = size(a, 1),
                     cols2::Int = size(a, 2), cols3::Int = size(bAdj, 2))
    b = parent(bAdj)
    blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
    alpha, beta = 1, 0
    return ext_gemm!('N', 'T', rows2, cols3, cols2, alpha, a, b, beta, c,
                     offset2, offset3, offset1)
end

# a' * b, 2 methods (At_mul_B!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'T', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(na, 1),
Ref(b, offset_b), max(na, 1), 0.0, Ref(c, offset_c),
max(ma, 1))
# c = a' * b where the left operand arrives wrapped in an `Adjoint`.
# The wrapper is peeled off and the transposition is handed to BLAS via the 'T' flag.
# rows2 × cols2 is the shape of `aAdj` (the transposed view), not of its parent.
function unsafe_mul!(c::StridedVecOrMat, aAdj::Adjoint{Float64, <:StridedArray},
                     b::StridedVecOrMat; offset1::Int = 1, offset2::Int = 1,
                     offset3::Int = 1, rows2::Int = size(aAdj, 1),
                     cols2::Int = size(aAdj, 2), cols3::Int = size(b, 2))
    a = parent(aAdj)
    blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
    alpha, beta = 1, 0
    return ext_gemm!('T', 'N', rows2, cols3, cols2, alpha, a, b, beta, c,
                     offset2, offset3, offset1)
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
na = size(a.parent, 1)
ma = (ndims(a.parent) > 1) ? size(a.parent, 2) : 1
_mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
# c = a' * b' where both operands arrive wrapped in an `Adjoint`.
# Both wrappers are peeled off; BLAS performs the two transpositions ('T', 'T').
# rows2 × cols2 is the shape of `aAdj`, cols3 the number of columns of `bAdj`.
function unsafe_mul!(c::StridedVecOrMat, aAdj::Adjoint{Float64, <:StridedVecOrMat},
                     bAdj::Adjoint{Float64, <:StridedVecOrMat}; offset1::Int = 1,
                     offset2::Int = 1, offset3::Int = 1, rows2::Int = size(aAdj, 1),
                     cols2::Int = size(aAdj, 2), cols3::Int = size(bAdj, 2))
    a, b = parent(aAdj), parent(bAdj)
    blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
    alpha, beta = 1, 0
    return ext_gemm!('T', 'T', rows2, cols3, cols2, alpha, a, b, beta, c,
                     offset2, offset3, offset1)
end

# a * b', 2 methods (A_mul_Bt!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64}, offset_b::Int64, nb::Int64)
if typeof(b) <: Adjoint{Float64, QuasiUpperTriangular{Float64, Matrix{Float64}}}
# Argument sanity checks shared by the `unsafe_mul!` methods.
#   c, a, b      : output and input buffers (only their lengths are inspected)
#   offset_*     : 1-based linear start index into the matching buffer
#   ma, na       : rows / columns requested from `a`
#   nb           : columns requested from `b` (its row count is taken to be `na`)
# Raises an `AssertionError` when an offset is out of range or a buffer is too short.
# The checks live in a `@boundscheck` block. This function only validates: it never
# touches the contents of `a`, `b` or `c`, and always returns `nothing`.
function blas_check(c, a, b, offset_c, offset_a, offset_b, ma, na, nb)
    @boundscheck begin
        # Make sure offset is a sane value
        @assert (length(a) >= offset_a >= 1)
        @assert (length(b) >= offset_b >= 1)
        @assert (length(c) >= offset_c >= 1)
        # Assert that there is enough data in each variable
        _mb = na
        @assert ma * na<=length(a) - offset_a + 1 "You're asking from A more than it has"
        @assert _mb * nb<=length(b) - offset_b + 1 "You're asking from B more than it has"
        @assert ma * nb<=length(c) - offset_c + 1 "You're assigning into C more than it can take"

        #TODO: if A is a vector, ma is supplied and na is not. Do not default to na=1
    end
    return nothing
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64})
nb = size(b.parent, 1)
_mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
## QuasiUpperTriangular

# B = alpha*A*B such that A is triangular
# Raw `dtrmm_` call operating in place on `b`, starting at linear index `offset_b`.
#   side, uplo, ta, diag : forwarded unchanged to BLAS
#   mb, nb               : rows and columns of the B block that is overwritten
#   a                    : strided matrix holding the triangular factor
# Returns `b`.
function ext_trmm!(side::Char, uplo::Char, ta::Char, diag::Char, mb::Int, nb::Int, alpha,
                   a::StridedMatrix{Float64}, b::StridedVecOrMat{Float64}, offset_b)
    # intentionally assume the triangular data is not in vector form
    #TODO: Reconsider ma/na vs mb/nb. trmm docs usage might conflict a little with ours

    # `a` is a real matrix, so its column stride is the authoritative leading dimension;
    # this equals mb ('L') or nb ('R') for a dense square factor and stays correct for
    # views into a larger parent. BLAS rejects leading dimensions below 1.
    lda = max(1, stride(a, 2))
    ldb = max(1, ndims(b) == 2 ? stride(b, 2) : mb)

    ccall((@blasfunc("dtrmm_"), libblas), Cvoid,
          (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
           Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}),
          side, uplo, ta, diag, mb, nb, alpha, a, lda, Ref(b, offset_b), ldb)
    return b
end

# a' * b', 1 method (At_mul_Bt!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64}, offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'T', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(na, 1),
Ref(b, offset_b), max(ma, 1), 0.0, Ref(c, offset_c),
max(ma, 1))
# A is QuasiUpper
# c = a * b where `a` is quasi upper triangular.
# Step 1: copy the used part of `b` into `c` and multiply in place by the upper
#         triangle of `a.data` via `ext_trmm!`.
# Step 2: add the sub-diagonal contributions c[i, j] += a[i, i-1] * b[i-1, j].
# offset1/offset2/offset3 are linear start indices into c/a/b. Returns `c`.
function unsafe_mul!(c::StridedVecOrMat, a::QuasiUpperTriangular, b::StridedVecOrMat;
                     offset1::Int = 1, offset2::Int = 1, offset3::Int = 1,
                     rows2::Int = size(a, 1), cols2::Int = size(a, 2),
                     cols3::Int = size(b, 2))
    blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)

    rows3 = cols2
    copyto!(c, offset1, b, offset3, rows3 * cols3)
    alpha = 1.0
    ext_trmm!('L', 'U', 'N', 'N', rows3, cols3, alpha, a.data, c, offset1)

    @inbounds for i in 2:rows2
        x = a[i, i - 1]
        # Row i of c receives row i-1 of b, so both start positions must move with i.
        indb = offset3 + i - 2
        indc = offset1 + i - 1
        @simd for j in 1:cols3
            c[indc] += x * b[indb]
            # step to the same row in the next column
            indb += rows2
            indc += rows2
        end
    end
    return c
end

# B is QuasiUpper
# c = a * b where `b` is quasi upper triangular.
# Step 1: copy the used part of `a` into `c` and multiply in place, from the right, by
#         the upper triangle of `b.data` via `ext_trmm!`.
# Step 2: add the sub-diagonal contributions c[:, i-1] += b[i, i-1] * a[:, i].
# offset1/offset2/offset3 are linear start indices into c/a/b. Returns `c`, like the
# other `unsafe_mul!` methods.
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, b::QuasiUpperTriangular;
                     offset1::Int = 1, offset2::Int = 1, offset3::Int = 1,
                     rows2::Int = size(a, 1), cols2::Int = size(a, 2),
                     cols3::Int = size(b, 2))
    blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)

    copyto!(c, offset1, a, offset2, rows2 * cols2)
    alpha = 1.0
    ext_trmm!('R', 'U', 'N', 'N', rows2, cols2, alpha, b.data, c, offset1)

    # Running cursors: `inda` starts at column 2 of a, `indc` at column 1 of c; both
    # advance one element at a time, i.e. one full column per outer iteration.
    inda = offset2 + rows2
    indc = offset1
    @inbounds for i in 2:cols2
        x = b[i, i - 1]
        @simd for j in 1:rows2
            c[indc] += x * a[inda]
            inda += 1
            indc += 1
        end
    end
    return c
end

# B is an Adjoint of QuasiUpper
# c = a * b' with `b` quasi upper triangular.
# The used part of `a` is copied into `c`, multiplied in place from the right by the
# transposed upper triangle of `b.data`, and finally each sub-diagonal entry b[j, j-1]
# adds its share: c[:, j] += b[j, j-1] * a[:, j-1]. Returns `c`.
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat,
                     bAdj::Adjoint{Float64, <:QuasiUpperTriangular}; offset1::Int = 1,
                     offset2::Int = 1, offset3::Int = 1, rows2::Int = size(a, 1),
                     cols2::Int = size(a, 2), cols3::Int = size(bAdj, 2))
    b = bAdj.parent
    blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)

    copyto!(c, offset1, a, offset2, rows2 * cols2)
    alpha = 1.0
    ext_trmm!('R', 'U', 'T', 'N', rows2, cols2, alpha, b.data, c, offset1)

    @inbounds for j in 2:cols2
        scale = alpha * b[j, j - 1]
        # linear index just before column j-1 of a, and just before column j of c
        src = offset2 + (j - 2) * rows2 - 1
        dst = offset1 + (j - 1) * rows2 - 1
        @simd for i in 1:rows2
            c[dst + i] += scale * a[src + i]
        end
    end
    return c
end
51 changes: 33 additions & 18 deletions src/KroneckerTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ module KroneckerTools
using LinearAlgebra
using LinearAlgebra.BLAS
using QuasiTriangular
import QuasiTriangular: A_mul_B!, At_mul_B!, A_mul_Bt!
import Base.convert

include("ExtendedMul.jl")
export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!, at_mul_b_kron_c!, a_mul_b_kron_ct!
export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!,
at_mul_b_kron_c!, a_mul_b_kron_ct!, kron_a_mul_b!, kron_at_mul_b!

"""
Content:
Expand Down Expand Up @@ -51,19 +51,23 @@ function kron_mul_elem!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix, b
begin
if p == 1 && q == 1
# a*b
_mul!(c, offset_c, a, b, offset_b, 1)
#OLD _mul!(c, offset_c, a, b, offset_b, 1)
unsafe_mul!(c, a, b)
elseif q == 1
# (I_p ⊗ a)*b = vec(a*[b_1 b_2 ... b_p])
_mul!(c, offset_c, a, b, offset_b, p)
#OLD _mul!(c, offset_c, a, b, offset_b, p)
unsafe_mul!(c, a, b; offset1=offset_c, offset3=offset_b, cols3=p)
elseif p == 1
# (a ⊗ I_q)*b = (b'*(a' ⊗ I_q))' = vec(reshape(b,q,m)*a')
_mul!(c, offset_c, b, offset_b, q, n, a')
#OLD _mul!(c, offset_c, b, offset_b, q, n, a')
unsafe_mul!(c, b, a'; offset1=offset_c, offset2=offset_b, rows2=q, cols2=n)
else
# (I_p ⊗ a ⊗ I_q)*b = vec([(a ⊗ I_q)*b_1 (a ⊗ I_q)*b_2 ... (a ⊗ I_q)*b_p])
mq = m*q
nq = n*q
for i=1:p
_mul!(c, offset_c, b, offset_b, q, n, a')
#OLD _mul!(c, offset_c, b, offset_b, q, n, a')
unsafe_mul!(c, b, a'; offset1=offset_c, offset2=offset_b, rows2=q, cols2=n)
offset_b += nq
offset_c += mq
end
Expand All @@ -89,19 +93,24 @@ function kron_mul_elem_t!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix,
begin
if p == 1 && q == 1
# a'*b
_mul!(c, offset_c, a', b, offset_b, 1)
#OLD _mul!(c, offset_c, a', b, offset_b, 1)
unsafe_mul!(c, a', b; offset1=offset_c, offset3=offset_b, cols3=1)
elseif q == 1
# (I_p ⊗ a')*b = vec(a'*[b_1 b_2 ... b_p])
_mul!(c, offset_c, a', b, offset_b, p)
#OLD _mul!(c, offset_c, a', b, offset_b, p)
unsafe_mul!(c, a', b; offset1=offset_c, offset3=offset_b, cols3=p)

elseif p == 1
# (a' ⊗ I_q)*b = (b'*(a ⊗ I_q))' = vec(reshape(b,q,m)*a)
_mul!(c, offset_c, b, offset_b, q, m, a)
#OLD _mul!(c, offset_c, b, offset_b, q, m, a)
unsafe_mul!(c, b, a; offset1=offset_c, offset2=offset_b, rows2=q, cols2=m)
else
# (I_p ⊗ a' ⊗ I_q)*b = vec([(a' ⊗ I_q)*b_1 (a' ⊗ I_q)*b_2 ... (a' ⊗ I_q)*b_p])
mq = m*q
nq = n*q
for i=1:p
_mul!(c, offset_c, b, offset_b, q, m, a)
#OLD _mul!(c, offset_c, b, offset_b, q, m, a)
unsafe_mul!(c, b, a; offset1=offset_c, offset2=offset_b, rows2=q, cols2=m)
offset_b += mq
offset_c += nq
end
Expand Down Expand Up @@ -207,7 +216,8 @@ computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using data from c at offset_c
function kron_at_kron_b_mul_c!(d::AbstractVector, offset_d::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_c::Int64, work1::AbstractVector, work2::AbstractVector, offset_w::Int64)
mb,nb = size(b)
if order == 0
_mul!(d,offset_d,b,1,mb,nb,c,offset_c,1)
#OLD _mul!(d, offset_d, b, 1, mb, nb, c, offset_c, 1)
unsafe_mul!(d, b, c; offset1=offset_d, offset3=offset_c, cols3=1)
else
ma, na = size(a)
# length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))"))
Expand All @@ -234,7 +244,8 @@ computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and w
function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector)
mb,nb = size(b)
if order == 0
_mul!(d,1,b,1,mb,nb,c,1,1)
#OLD _mul!(d,1,b,1,mb,nb,c,1,1)
unsafe_mul!(d, b, c)
else
ma, na = size(a)
# length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))"))
Expand Down Expand Up @@ -289,7 +300,6 @@ function kron_at_mul_b!(c::AbstractVector, a::AbstractMatrix, order::Int64, b::A
end
copyto!(c,1,work2,1,s)
end

function kron_a_mul_b!(c::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractVector, q::Int64, work1::AbstractVector, work2::AbstractVector)
ma,na = size(a)
# length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))"))
Expand All @@ -313,11 +323,13 @@ function at_mul_b_kron_c!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
mb, nb = size(b)
mc, nc = size(c)
if mc <= nc
_mul!(work1, 1, a', 1, na, ma, b, 1, nb)
#OLD _mul!(work1, 1, a', 1, na, ma, b, 1, nb)
unsafe_mul!(work1, a', b; rows2=na, cols2=ma, cols3=nb)
kron_at_mul_b!(vec(d), c, order, work1, na, vec(d), work2)
else
kron_at_mul_b!(work1, c, order, b, na, work1, work2)
_mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order)
#OLD _mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order)
unsafe_mul!(vec(d), a', work1; rows2=na, cols2=ma, cols3=nc^order)
end
end

Expand All @@ -326,11 +338,13 @@ function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
mb, nb = size(b)
mc, nc = size(c)
if mc <= nc
_mul!(work1, 1, a, 1, ma, na, b, 1, nb)
#OLD _mul!(work1, 1, a, 1, ma, na, b, 1, nb)
unsafe_mul!(work1, a, b)
kron_a_mul_b!(vec(d), c, order, work1, ma, work1, work2)
else
kron_a_mul_b!(work1, b, order, c, mb, work1, work2)
_mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order)
#OLD _mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order)
unsafe_mul!(vec(d), a', work1; rows2=a, cols2=ma, cols3=nc^order)
end
end

Expand Down Expand Up @@ -389,7 +403,8 @@ function a_mul_b_kron_c_d!(e::AbstractMatrix, a::AbstractMatrix, b::AbstractMatr
na == mb || throw(DimensionMismatch("The number of columns of a, $(size(a,2)), doesn't match the number of rows of b, $(size(b,1))"))
nb == mc*md^(order-1) || throw(DimensionMismatch("The number of columns of b, $(size(b,2)), doesn't match the number of rows of c, $(size(c,1)), and d, $(size(d,1)) for order, $order"))
(ma == me && nc*nd^(order-1) == ne) || throw(DimensionMismatch("Dimension mismatch for e: $(size(e)) while ($ma, $(nc*nd^(order-1))) was expected"))
_mul!(work1, 1, a, 1, ma, mb, vec(b), 1, nb)
# OLD _mul!(work1, 1, a, 1, ma, mb, b, 1, nb)
unsafe_mul!(work1, a, b; rows2=ma, cols2=mb, cols3=nb)
p = mc*md^(order - 2)
q = ma
for i = 0:order - 2
Expand Down
Loading

0 comments on commit b0bcb3f

Please sign in to comment.