Skip to content

Commit

Permalink
Allow Y to be a vector
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Feb 12, 2019
1 parent d745d7e commit 38541dc
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ isdiagtype(::Number) = true

matsize(A) = (size(A, 1), size(A, 2))

@inline unsafe_column(A::AbstractMatrix, k) = @inbounds view(A, :, k)
@inline unsafe_column(A::AbstractVecOrMat, k) = @inbounds view(A, :, k)

rmul_or_fill!((Y, β)::Tuple{AbstractVecOrMat, Number}) = rmul_or_fill!(Y, β)
function rmul_or_fill!(Y::AbstractVecOrMat, β::Number)
Expand Down
40 changes: 20 additions & 20 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,30 +245,30 @@ with `m = $(length(Yβ))` and `n = $(length(rhs))`. Note that `m` and `n` must
end
end

ifisa Tuple{Vararg{AbstractMatrix}}
n1, n4 = size(Yβ[1])
ifisa Tuple{Vararg{AbstractVecOrMat}}
n1, n4 = matsize(Yβ[1])
if !all(let dim = (n1, n4)
Y -> size(Y) == dim
Y -> matsize(Y) == dim
end,
Yβ)
throw(ArgumentError("""
Matrices `Y1`, ..., `Yn` passed to `fmul_shared!((Y1, ..., Yn), ...)` do not
have uniform `size`."""))
end
elseifisa Tuple{Vararg{Tuple{AbstractMatrix,Number}}}
n1, n4 = size(Yβ[1][1])
elseifisa Tuple{Vararg{Tuple{AbstractVecOrMat,Number}}}
n1, n4 = matsize(Yβ[1][1])
if !all(let dim = (n1, n4)
((Y, _),) -> size(Y) == dim
((Y, _),) -> matsize(Y) == dim
end,
Yβ)
throw(ArgumentError("""
Matrices `Y1`, ..., `Yn` passed to `fmul_shared!(((Y1, β1), ..., (Yn, βn)), ...)`
do not have uniform `size`."""))
end
elseifisa AbstractMatrix
n1, n4 = size(Yβ)
elseifisa Tuple{AbstractMatrix,Number}
n1, n4 = size(Yβ[1])
elseifisa AbstractVecOrMat
n1, n4 = matsize(Yβ)
elseifisa Tuple{AbstractVecOrMat,Number}
n1, n4 = matsize(Yβ[1])
else
throw(ArgumentError("""
Unsupported type for first argument of `fmul_shared!`:
Expand Down Expand Up @@ -352,11 +352,11 @@ end
return (Diagonal(asdiag(D, size(S, 1))), Base.tail(DSX)...)
end

@inline canonicalize_Yβ(Yβ::Tuple{AbstractMatrix,Number}) =
@inline canonicalize_Yβ(Y::AbstractMatrix) = (Y, false)
@inline canonicalize_Yβ(Yβ::Tuple{AbstractVecOrMat,Number}) =
@inline canonicalize_Yβ(Y::AbstractVecOrMat) = (Y, false)

@inline canonicalize_Yβ(Yβ::Tuple{Vararg{Tuple{AbstractMatrix,Number}}}) =
@inline canonicalize_Yβ(Ys::Tuple{Vararg{AbstractMatrix}}) =
@inline canonicalize_Yβ(Yβ::Tuple{Vararg{Tuple{AbstractVecOrMat,Number}}}) =
@inline canonicalize_Yβ(Ys::Tuple{Vararg{AbstractVecOrMat}}) =
map(canonicalize_Yβ, Ys)

@inline function is_shared_simd3(triplets)
Expand All @@ -378,12 +378,12 @@ end
all(((_, S),) -> isnzshared(t1[2], S), middle)
end

@inline function preprocess_Yβ(Yβ::Tuple{AbstractMatrix,Number})
@inline function preprocess_Yβ(Yβ::Tuple{AbstractVecOrMat,Number})
rmul_or_fill!(Yβ)
return Yβ[1]
end

@inline function preprocess_Yβ(Yβs::Tuple{Vararg{Tuple{AbstractMatrix,Number}}})
@inline function preprocess_Yβ(Yβs::Tuple{Vararg{Tuple{AbstractVecOrMat,Number}}})
rmul_or_fill_many!(Yβs...)
return map(first, Yβs)
end
Expand All @@ -400,7 +400,7 @@ end
) where {N}

Y = preprocess_Yβ(Yβ)
Y :: Union{AbstractMatrix,Tuple{Vararg{AbstractMatrix}}}
Y :: Union{AbstractVecOrMat,Tuple{Vararg{AbstractVecOrMat}}}

lane = VecRange{N}(0)

Expand Down Expand Up @@ -451,15 +451,15 @@ end
return Y
end

@inline function update_Y!(Y::AbstractMatrix, diags, accs, col, k)
@inline function update_Y!(Y::AbstractVecOrMat, diags, accs, col, k)
prods = map(diags, accs) do diag, acc
@inbounds diag[col] * acc
end
@inbounds Y[col, k] += +(prods...)
return
end

@inline function update_Y!(Ys::Tuple{Vararg{AbstractMatrix}},
@inline function update_Y!(Ys::Tuple{Vararg{AbstractVecOrMat}},
diags, accs, col, k)
map(Ys, diags, accs) do Y, diag, acc
@inbounds Y[col, k] += diag[col] * acc
Expand All @@ -485,7 +485,7 @@ compute_vaccs(::Tuple{}, ::Tuple{}, ::Tuple{}, _, _) = ()
) where {N}

Y = preprocess_Yβ(Yβ)
Y :: Union{AbstractMatrix,Tuple{Vararg{AbstractMatrix}}}
Y :: Union{AbstractVecOrMat,Tuple{Vararg{AbstractVecOrMat}}}

lane = VecRange{N}(0)

Expand Down
19 changes: 16 additions & 3 deletions test/test_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fmul_shared_test_params = let params = []
randn!(nonzeros(S2))
randn!(nonzeros(S3))
push!(params, (
label = "default",
label = "Xn :: Matrix",
D1 = Diagonal(randn(m)),
D2 = Diagonal(randn(m)),
D3 = Diagonal(randn(m)),
Expand All @@ -64,6 +64,19 @@ fmul_shared_test_params = let params = []
X3 = randn(m, n),
))

push!(params, (
label = "Xn :: Vector",
D1 = Diagonal(randn(m)),
D2 = Diagonal(randn(m)),
D3 = Diagonal(randn(m)),
S1 = S1,
S2 = S2,
S3 = S3,
X1 = randn(m),
X2 = randn(m),
X3 = randn(m),
))

S1 = sprandn(m, m, 0.3)
S2 = spshared(S1)
S3 = spshared(S1)
Expand All @@ -85,8 +98,8 @@ fmul_shared_test_params = let params = []
params
end

@testset "is_shared_simd" begin
@unpack D1, D2, D3, S1, S2, S3, X1, X2, X3 = fmul_shared_test_params[1]
@testset "is_shared_simd $(p.label)" for p in fmul_shared_test_params[1:2]
@unpack D1, D2, D3, S1, S2, S3, X1, X2, X3 = p

@test SparseXX.is_shared_simd3(((D1, S1', X1), (D2, S2', X2)))
@test SparseXX.is_shared_simd2(((D1, S1'), (D2, S2'), X1))
Expand Down

0 comments on commit 38541dc

Please sign in to comment.