Skip to content

Commit

Permalink
reimplement ExtendedMul interface to use kwargs
Browse files Browse the repository at this point in the history
- Major reorganization
- Add proper QuasiUpperTriangular support in ExtendedMul directy
  • Loading branch information
Omar-Elrefaei committed Mar 8, 2023
1 parent d484fd5 commit 8462e82
Show file tree
Hide file tree
Showing 3 changed files with 351 additions and 231 deletions.
230 changes: 144 additions & 86 deletions src/ExtendedMul.jl
Original file line number Diff line number Diff line change
@@ -1,105 +1,163 @@
using LinearAlgebra: BlasInt
using LinearAlgebra.BLAS
import LinearAlgebra.BLAS: gemm!, @blasfunc, libblas

export _mul!

function gemm!(ta::Char, tb::Char, alpha::Float64,
a::Union{Ref{Float64}, AbstractVecOrMat{Float64}}, ma::Int64, na::Int64,
b::Union{Ref{Float64}, AbstractVecOrMat{Float64}}, nb::Int64,
beta::Float64, c::Union{Ref{Float64}, AbstractVecOrMat{Float64}})
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
ta, tb, ma, nb,
na, alpha, a, max(ma, 1),
b, max(na, 1), beta, c,
max(ma, 1))
end
import LinearAlgebra.BLAS: @blasfunc, libblas

export unsafe_mul!

function ext_gemm!(ta::Char, tb::Char, ma::Int, nb::Int, na::Int, alpha,
a::StridedVecOrMat{Float64}, b::StridedVecOrMat{Float64}, beta,
c::StridedVecOrMat{Float64}, offset_a, offset_b, offset_c)
lda = ndims(a) == 2 ? strides(a)[2] : ma
ldb = ndims(b) == 2 ? strides(b)[2] : na
ldc = ndims(c) == 2 ? strides(c)[2] : ma

# a * b, 3 methods (A_mul_B!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
gemm!('N', 'N', 1.0, Ref(a, offset_a), ma, na, Ref(b, offset_b),
nb, 0.0, Ref(c, offset_c))
ccall((@blasfunc("dgemm_"), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64},
Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64},
Ptr{Float64}, Ref{BlasInt}), ta, tb, ma, nb, na, alpha, Ref(a, offset_a), lda,
Ref(b, offset_b), ldb, beta, Ref(c, offset_c), ldc)
return c
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64})
nb = (ndims(b) > 1) ? size(b, 2) : 1
_mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, b::StridedVecOrMat;
offset1::Int = 1, offset2::Int = 1, offset3::Int = 1,
rows2::Int = size(a, 1), cols2::Int = size(a, 2),
cols3::Int = size(b, 2))
blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
ext_gemm!('N', 'N', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1)
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
ma = size(a, 1)
na = (ndims(a) > 1) ? size(a, 2) : 1
_mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat,
bAdj::Adjoint{Float64, <:StridedVecOrMat}; offset1::Int = 1,
offset2::Int = 1, offset3::Int = 1, rows2::Int = size(a, 1),
cols2::Int = size(a, 2), cols3::Int = size(bAdj, 2))
b = bAdj.parent
blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
ext_gemm!('N', 'T', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1)
end

# a' * b, 2 methods (At_mul_B!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'T', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(na, 1),
Ref(b, offset_b), max(na, 1), 0.0, Ref(c, offset_c),
max(ma, 1))
function unsafe_mul!(c::StridedVecOrMat, aAdj::Adjoint{Float64, <:StridedArray},
b::StridedVecOrMat; offset1::Int = 1, offset2::Int = 1,
offset3::Int = 1, rows2::Int = size(aAdj, 1),
cols2::Int = size(aAdj, 2), cols3::Int = size(b, 2))
a = aAdj.parent
blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
ext_gemm!('T', 'N', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1)
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64},
b::AbstractVecOrMat{Float64}, offset_b::Int64, nb::Int64)
na = size(a.parent, 1)
ma = (ndims(a.parent) > 1) ? size(a.parent, 2) : 1
_mul!(c, offset_c, a, 1, ma, na, b, offset_b, nb)
function unsafe_mul!(c::StridedVecOrMat, aAdj::Adjoint{Float64, <:StridedVecOrMat},
bAdj::Adjoint{Float64, <:StridedVecOrMat}; offset1::Int = 1,
offset2::Int = 1, offset3::Int = 1, rows2::Int = size(aAdj, 1),
cols2::Int = size(aAdj, 2), cols3::Int = size(bAdj, 2))
a = aAdj.parent
b = bAdj.parent
blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)
ext_gemm!('T', 'T', rows2, cols3, cols2, 1, a, b, 0, c, offset2, offset3, offset1)
end

# a * b', 2 methods (A_mul_Bt!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64}, offset_b::Int64, nb::Int64)
if typeof(b) <: Adjoint{Float64, QuasiUpperTriangular{Float64, Matrix{Float64}}}
function blas_check(c, a, b, offset_c, offset_a, offset_b, ma, na, nb)
@boundscheck begin
# Make sure offset is a sane value
@assert (length(a) >= offset_a >= 1)
@assert (length(b) >= offset_b >= 1)
@assert (length(c) >= offset_c >= 1)
# Assert that there is enough data in each variable
_mb = na
@assert ma * na<=length(a) - offset_a + 1 "You're asking from A more than it has"
@assert _mb * nb<=length(b) - offset_b + 1 "You're asking from B more than it has"
@assert ma * nb<=length(c) - offset_c + 1 "You're assigning into C more than it can take"

#TODO: if A is a vector, ma is supplied and na is not. Do not default to na=1
end
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'N', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(ma, 1),
Ref(b, offset_b), max(nb, 1), 0.0, Ref(c, offset_c),
max(ma, 1))
end

function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::AbstractVecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64})
nb = size(b.parent, 1)
_mul!(c, offset_c, a, offset_a, ma, na, b, 1, nb)
## QuasiUpperTriangular

# B = alpha*A*B such that A is triangular
function ext_trmm!(side::Char, uplo::Char, ta::Char, diag::Char, mb::Int, nb::Int, alpha,
a::StridedMatrix{Float64}, b::StridedVecOrMat{Float64}, offset_b)
# intentionally assume the triangular data is not in vector form
#TODO: Reconsider ma/na vs mb/nb. trmm docs usage might conflict a little with ours

# following trmm docs
lda = uppercase(side) == 'L' ? mb : nb
ldb = ndims(b) == 2 ? strides(b)[2] : mb

ccall((@blasfunc("dtrmm_"), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}), side,
uplo, ta, diag, mb, nb, alpha, a, lda, Ref(b, offset_b), ldb)
return b
end

# a' * b', 1 method (At_mul_Bt!)
function _mul!(c::AbstractVecOrMat{Float64}, offset_c::Int64,
a::Adjoint{Float64}, offset_a::Int64, ma::Int64, na::Int64,
b::Adjoint{Float64}, offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'T', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(na, 1),
Ref(b, offset_b), max(ma, 1), 0.0, Ref(c, offset_c),
max(ma, 1))
# A is QuasiUpper
function unsafe_mul!(c::StridedVecOrMat, a::QuasiUpperTriangular, b::StridedVecOrMat;
offset1::Int = 1, offset2::Int = 1, offset3::Int = 1,
rows2::Int = size(a, 1), cols2::Int = size(a, 2),
cols3::Int = size(b, 2))
blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)

rows3 = cols2
copyto!(c, offset1, b, offset3, rows3 * cols3)
alpha = 1.0
ext_trmm!('L', 'U', 'N', 'N', rows3, cols3, alpha, a.data, c, offset1)

@inbounds for i in 2:rows2
x = a[i, i - 1]
indb = offset3
indc = offset1 + 1
@simd for j in 1:cols3
c[indc] += x * b[indb]
indb += rows2
indc += rows2
end
end
end

# B is QuasiUpper
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat, b::QuasiUpperTriangular;
offset1::Int = 1, offset2::Int = 1, offset3::Int = 1,
rows2::Int = size(a, 1), cols2::Int = size(a, 2),
cols3::Int = size(b, 2))
blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)

copyto!(c, offset1, a, offset2, rows2 * cols2)
alpha = 1.0
ext_trmm!('R', 'U', 'N', 'N', rows2, cols2, alpha, b.data, c, offset1)

inda = offset2 + rows2
indc = offset1
@inbounds for i in 2:cols2
x = b[i, i - 1]
@simd for j in 1:rows2
c[indc] += x * a[inda]
inda += 1
indc += 1
end
end
end

# B is an Adjoint of QuasiUpper
function unsafe_mul!(c::StridedVecOrMat, a::StridedVecOrMat,
bAdj::Adjoint{Float64, <:QuasiUpperTriangular}; offset1::Int = 1,
offset2::Int = 1, offset3::Int = 1, rows2::Int = size(a, 1),
cols2::Int = size(a, 2), cols3::Int = size(bAdj, 2))
b = bAdj.parent
blas_check(c, a, b, offset1, offset2, offset3, rows2, cols2, cols3)

copyto!(c, offset1, a, offset2, rows2 * cols2)
alpha = 1.0
ext_trmm!('R', 'U', 'T', 'N', rows2, cols2, alpha, b.data, c, offset1)

inda = offset2
indc = offset1 + rows2
@inbounds for j in 2:cols2
x = alpha * b[j, j - 1]
@simd for i in 1:rows2
c[indc] += x * a[inda]
inda += 1
indc += 1
end
end
c
end
Loading

0 comments on commit 8462e82

Please sign in to comment.