Skip to content

Commit

Permalink
test: add tests for some code branches
Browse files Browse the repository at this point in the history
  • Loading branch information
Omar-Elrefaei committed Mar 9, 2023
1 parent f1a7b60 commit 3db577f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
5 changes: 1 addition & 4 deletions src/KroneckerTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ function kron_mul_elem!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix, b

begin
if p == 1 && q == 1
#FIXME: UNTESTED
# a*b
#OLD _mul!(c, offset_c, a, b, offset_b, 1)
unsafe_mul!(c, a, b, offset1=offset_c, offset3=offset_b)
Expand Down Expand Up @@ -93,12 +92,10 @@ function kron_mul_elem_t!(c::AbstractVector, offset_c::Int64, a::AbstractMatrix,

begin
if p == 1 && q == 1
#FIXME: UNTESTED
# a'*b
#OLD _mul!(c, offset_c, a', b, offset_b, 1)
unsafe_mul!(c, a', b; offset1=offset_c, offset3=offset_b, cols3=1)
elseif q == 1
#FIXME: UNTESTED
# (I_p ⊗ a')*b = vec(a'*[b_1 b_2 ... b_p])
#OLD _mul!(c, offset_c, a', b, offset_b, p)
unsafe_mul!(c, a', b; offset1=offset_c, offset3=offset_b, cols3=p)
Expand Down Expand Up @@ -134,6 +131,7 @@ We use vec(a*(b ⊗ b ⊗ ... ⊗ b)) = (b' ⊗ b' ⊗ ... ⊗ b' ⊗ I)vec(a)
"""
function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, order::Int64, work1::AbstractVector, work2::AbstractVector)
#FIXME: UNTESTED, this implementation with work vectors is untested
ma, na = size(a)
mb, nb = size(b)
mc, nc = size(c)
Expand Down Expand Up @@ -248,7 +246,6 @@ computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and w
function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector)
mb,nb = size(b)
if order == 0
#FIXME: UNTESTED
#OLD _mul!(d,1,b,1,mb,nb,c,1,1)
unsafe_mul!(d, b, c; rows2=mb, cols2=nb, cols3=1)
else
Expand Down
56 changes: 56 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ end
kron_mul_elem!(c, a, b, p, q)
@test c kron(kron(I(p), a), I(q)) * b_orig

order = 1
ma = 2
na = 4
q = 1
p = na^(order - 1)
a = rand(ma, na)
b_orig = rand(q * na^order)
c = rand(ma * p * q)
b = copy(b_orig)
kron_mul_elem!(c, a, b, p, q)
@test c kron(kron(I(p), a), I(q)) * b_orig

ma = 2
na = 2
a = randn(ma, na)
Expand Down Expand Up @@ -225,6 +237,50 @@ end
kron_at_kron_b_mul_c!(d, a, order, b, c, work1, work2)
@test d kron(kron(a', a'), b) * c

# 3 more calls to cover some branches
order = 2
ma = 2
na = 4
a = randn(ma, na)
mb = 1
nb = 8
b = randn(mb, nb)
c = randn(ma * ma * nb)
d = randn(na * na * mb)
work1 = rand(na * na * mb)
work2 = rand(na * na * mb)
kron_at_kron_b_mul_c!(d, a, order, b, c, work1, work2)
kron_at_kron_b_mul_c!(d, a, order, b, c, work1, work2)
@test d kron(kron(a', a'), b) * c

order = 1
ma = 2
na = 4
a = randn(ma, na)
mb = 1
nb = 8
b = randn(mb, nb)
c = randn(ma^order * nb)
d = randn(na^order * mb)
work1 = rand(na^order * mb)
work2 = rand(na^order * mb)
kron_at_kron_b_mul_c!(d, a, order, b, c, work1, work2)
@test d kron(a', b) * c

order = 0
ma = 2
na = 4
a = randn(ma, na)
mb = 2
nb = 8
b = randn(mb, nb)
c = randn(ma^order * nb)
d = randn(na^order * mb)
work1 = rand(na^order * mb)
work2 = rand(na^order * mb)
kron_at_kron_b_mul_c!(d, a, order, b, c, work1, work2)
@test d b * c

@testset "kron_a_mul_b" begin
order = 2
ma = 2
Expand Down

0 comments on commit 3db577f

Please sign in to comment.