Skip to content

Commit

Permalink
rename ExtendedMul functions to a unified _mul! interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Omar-Elrefaei committed Feb 17, 2023
1 parent f53daf4 commit d484fd5
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 61 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
style = "sciml"
74 changes: 38 additions & 36 deletions src/ExtendedMul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using LinearAlgebra: BlasInt
using LinearAlgebra.BLAS
import LinearAlgebra.BLAS: gemm!, @blasfunc, libblas

export A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_B!
export _mul!

function gemm!(ta::Char, tb::Char, alpha::Float64,
a::Union{Ref{Float64}, AbstractVecOrMat{Float64}}, ma::Int64, na::Int64,
Expand All @@ -19,33 +19,33 @@ function gemm!(ta::Char, tb::Char, alpha::Float64,
max(ma, 1))
end

# A_mul_B!
function A_mul_B!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
# a * b, 3 methods (A_mul_B!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
gemm!('N', 'N', 1.0, Ref(a, offset_a), ma, na, Ref(b, offset_b),
nb, 0.0, Ref(c, offset_c))
end

function A_mul_B!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64})
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64})
nb = (ndims(b) > 1) ? size(b, 2) : 1
A_mul_B!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
_mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
end

function A_mul_B!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
ma = size(a, 1)
na = (ndims(a) > 1) ? size(a, 2) : 1
A_mul_B!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
_mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
end

# At_mul_B!
function At_mul_B!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
# a' * b, 2 methods (At_mul_B!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Expand All @@ -57,18 +57,20 @@ function At_mul_B!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
max(ma, 1))
end

function At_mul_B!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
na = size(a, 1)
ma = (ndims(a) > 1) ? size(a, 2) : 1
At_mul_B!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
na = size(a.parent, 1)
ma = (ndims(a.parent) > 1) ? size(a.parent, 2) : 1
_mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
end

# A_mul_Bt!
function A_mul_Bt!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
# a * b', 2 methods (A_mul_Bt!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64}, offset_b::Int64, nb::Int64)
if typeof(b) <: Adjoint{Float64, QuasiUpperTriangular{Float64, Matrix{Float64}}}
end
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Expand All @@ -80,17 +82,17 @@ function A_mul_Bt!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
max(ma, 1))
end

function A_mul_Bt!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64})
nb = size(b, 1)
A_mul_Bt!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64})
nb = size(b.parent, 1)
_mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
end

# At_mul_Bt!
function At_mul_Bt!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
# a' * b', 1 method (At_mul_Bt!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64}, offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Expand Down
30 changes: 15 additions & 15 deletions src/KroneckerTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ function kron_mul_elem!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix, b
begin
if p == 1 && q == 1
# a*b
A_mul_B!(c, offset_c, a, b, offset_b, 1)
_mul!(c, offset_c, a, b, offset_b, 1)
elseif q == 1
# (I_p ⊗ a)*b = vec(a*[b_1 b_2 ... b_p])
A_mul_B!(c, offset_c, a, b, offset_b, p)
_mul!(c, offset_c, a, b, offset_b, p)
elseif p == 1
# (a ⊗ I_q)*b = (b'*(a' ⊗ I_q))' = vec(reshape(b,q,n)*a')
A_mul_Bt!(c, offset_c, b, offset_b, q, n, a)
_mul!(c, offset_c, b, offset_b, q, n, a')
else
# (I_p ⊗ a ⊗ I_q)*b = vec([(a ⊗ I_q)*b_1 (a ⊗ I_q)*b_2 ... (a ⊗ I_q)*b_p])
mq = m*q
nq = n*q
for i=1:p
A_mul_Bt!(c, offset_c, b, offset_b, q, n, a)
_mul!(c, offset_c, b, offset_b, q, n, a')
offset_b += nq
offset_c += mq
end
Expand All @@ -89,19 +89,19 @@ function kron_mul_elem_t!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix,
begin
if p == 1 && q == 1
# a'*b
At_mul_B!(c, offset_c, a, b, offset_b, 1)
_mul!(c, offset_c, a', b, offset_b, 1)
elseif q == 1
# (I_p ⊗ a')*b = vec(a'*[b_1 b_2 ... b_p])
At_mul_B!(c, offset_c, a, b, offset_b, p)
_mul!(c, offset_c, a', b, offset_b, p)
elseif p == 1
# (a' ⊗ I_q)*b = (b'*(a ⊗ I_q))' = vec(reshape(b,q,m)*a)
A_mul_B!(c, offset_c, b, offset_b, q, m, a)
_mul!(c, offset_c, b, offset_b, q, m, a)
else
# (I_p ⊗ a' ⊗ I_q)*b = vec([(a' ⊗ I_q)*b_1 (a' ⊗ I_q)*b_2 ... (a' ⊗ I_q)*b_p])
mq = m*q
nq = n*q
for i=1:p
A_mul_B!(c, offset_c, b, offset_b, q, m, a)
_mul!(c, offset_c, b, offset_b, q, m, a)
offset_b += mq
offset_c += nq
end
Expand Down Expand Up @@ -207,7 +207,7 @@ computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using data from c at offset_c
function kron_at_kron_b_mul_c!(d::AbstractVector, offset_d::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_c::Int64, work1::AbstractVector, work2::AbstractVector, offset_w::Int64)
mb,nb = size(b)
if order == 0
A_mul_B!(d,offset_d,b,1,mb,nb,c,offset_c,1)
_mul!(d,offset_d,b,1,mb,nb,c,offset_c,1)
else
ma, na = size(a)
# length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))"))
Expand All @@ -234,7 +234,7 @@ computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and w
function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector)
mb,nb = size(b)
if order == 0
A_mul_B!(d,1,b,1,mb,nb,c,1,1)
_mul!(d,1,b,1,mb,nb,c,1,1)
else
ma, na = size(a)
# length(work) == naorder*mb || throw(DimensionMismatch("The dimension of vector , $(length(c)) doesn't correspond to order, ($order) and the dimension of the matrices a, $(size(a)), and b, $(size(b))"))
Expand Down Expand Up @@ -313,11 +313,11 @@ function at_mul_b_kron_c!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
mb, nb = size(b)
mc, nc = size(c)
if mc <= nc
At_mul_B!(work1, 1, a, 1, na, ma, b, 1, nb)
_mul!(work1, 1, a', 1, na, ma, b, 1, nb)
kron_at_mul_b!(vec(d), c, order, work1, na, vec(d), work2)
else
kron_at_mul_b!(work1, c, order, b, na, work1, work2)
At_mul_B!(vec(d), 1, a, 1, na, ma, work1, 1, nc^order)
_mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order)
end
end

Expand All @@ -326,11 +326,11 @@ function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
mb, nb = size(b)
mc, nc = size(c)
if mc <= nc
A_mul_B!(work1, 1, a, 1, ma, na, b, 1, nb)
_mul!(work1, 1, a, 1, ma, na, b, 1, nb)
kron_a_mul_b!(vec(d), c, order, work1, ma, work1, work2)
else
kron_a_mul_b!(work1, b, order, c, mb, work1, work2)
At_mul_B!(vec(d), 1, a, 1, na, ma, work1, 1, nc^order)
_mul!(vec(d), 1, a', 1, na, ma, work1, 1, nc^order)
end
end

Expand Down Expand Up @@ -389,7 +389,7 @@ function a_mul_b_kron_c_d!(e::AbstractMatrix, a::AbstractMatrix, b::AbstractMatr
na == mb || throw(DimensionMismatch("The number of columns of a, $(size(a,2)), doesn't match the number of rows of b, $(size(b,1))"))
nb == mc*md^(order-1) || throw(DimensionMismatch("The number of columns of b, $(size(b,2)), doesn't match the number of rows of c, $(size(c,1)), and d, $(size(d,1)) for order, $order"))
(ma == me && nc*nd^(order-1) == ne) || throw(DimensionMismatch("Dimension mismatch for e: $(size(e)) while ($ma, $(nc*nd^(order-1))) was expected"))
A_mul_B!(work1, 1, a, 1, ma, mb, vec(b), 1, nb)
_mul!(work1, 1, a, 1, ma, mb, vec(b), 1, nb)
p = mc*md^(order - 2)
q = ma
for i = 0:order - 2
Expand Down
39 changes: 29 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
using LinearAlgebra
using Random
using Test
import KroneckerTools: kron_mul_elem!, kron_mul_elem_t!, a_mul_kron_b!,
a_mul_b_kron_c!, a_mul_b_kron_ct!, at_mul_b_kron_c!,
a_mul_b_kron_c_d!, kron_a_mul_b!, kron_at_mul_b!,
kron_at_kron_b_mul_c!, QuasiTriangular.QuasiUpperTriangular,
A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt!
import KroneckerTools: kron_mul_elem!, kron_mul_elem_t!, a_mul_kron_b!, a_mul_b_kron_c!,
a_mul_b_kron_ct!, at_mul_b_kron_c!, a_mul_b_kron_c_d!, kron_a_mul_b!,
kron_at_mul_b!, kron_at_kron_b_mul_c!,
QuasiTriangular.QuasiUpperTriangular,
_mul!

Random.seed!(123)
#a = rand(2,3)
Expand Down Expand Up @@ -45,34 +45,53 @@ for m in [1, 3]
end
=#

@testset "KroneckerTools" verbose = true begin
@testset "KroneckerTools" verbose=true begin
@testset "ExtendedMul" begin
# a * b, 3 methods
ma = 2
na = 4
a = randn(ma, na)
mb = 4
nb = 3
b = randn(mb, nb)
c = randn(ma, nb)
A_mul_B!(c, 1, a, 1, ma, na, b, 1, nb)
_mul!(c, 1, a, 1, ma, na, b, 1, nb)
@test c ≈ a * b
c = randn(ma, nb)
_mul!(c, 1, a, 1, ma, na, b)
@test c ≈ a * b
c = randn(ma, nb)
_mul!(c, 1, a, b, 1, nb)
@test c ≈ a * b

# a' * b, 2 methods
mb = 2
nb = 3
b = randn(mb, nb)
c = randn(na, nb)
At_mul_B!(c, 1, a, 1, na, ma, b, 1, nb)
_mul!(c, 1, a', 1, na, ma, b, 1, nb)
@test c ≈ transpose(a) * b
c = randn(na, nb)
_mul!(c, 1, a', b, 1, nb)
@test c ≈ a' * b

# a * b', 2 methods
nb = 4
mb = 3
b = randn(mb, nb)
c = randn(ma, mb)
A_mul_Bt!(c, 1, a, 1, ma, na, b, 1, mb)
_mul!(c, 1, a, 1, ma, na, b', 1, mb)
@test c ≈ a * transpose(b)
c = randn(ma, mb)
_mul!(c, 1, a, 1, ma, na, b')
@test c ≈ a * transpose(b)

# a' * b', 1 method
mb = 4
nb = 2
b = randn(mb, nb)
c = randn(na, mb)
At_mul_Bt!(c, 1, a, 1, na, ma, b, 1, mb)
_mul!(c, 1, a', 1, na, ma, b', 1, mb)
@test c ≈ transpose(a) * transpose(b)
end

Expand Down

0 comments on commit d484fd5

Please sign in to comment.