Skip to content

Commit

Permalink
lse implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
louisponet committed Sep 14, 2022
1 parent 7cb57ae commit 08f34a8
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 9 deletions.
6 changes: 6 additions & 0 deletions docs/src/LAPACK.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,9 @@ LinearAlgebra.LAPACK.hetrf_rook!(::BunchKaufmanWs, ::AbstractChar, ::AbstractMat
```@docs
LinearAlgebra.LAPACK.pstrf!(::CholeskyPivotedWs, ::AbstractChar, ::AbstractMatrix, ::Real)
```

## LSE
```@docs
LinearAlgebra.LAPACK.gglse!(::LSEWs, ::AbstractMatrix, ::AbstractVector, ::AbstractMatrix, ::AbstractVector)
```

3 changes: 0 additions & 3 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
The goal of `FastLapackInterface` is to eliminate any temporary allocations when using certain [`LAPACK functions`](@ref LAPACK) compared to Base julia. This is achieved by providing some [`Workspaces`](@ref WorkSpaces) that can then be used during the computation of [`LAPACK functions`](@ref LAPACK).
Eliminating most of the allocations not only improves the computation time of the functions, but dramatically improves `GC` impact when performing multithreaded workloads.

!!! note
For now the target functionality is limited to [`QR`](@ref QR-id), [`Schur`](@ref Schur-id), [`LU`](@ref LU-id), [`Eigen`](@ref Eigen-id), [`Bunch-Kaufman`](@ref BunchKaufman-id), and [`CholeskyPivoted`](@ref Cholesky-id) related decompositions.

```@meta
DocTestSetup = quote
using LinearAlgebra, FastLapackInterface
Expand Down
5 changes: 5 additions & 0 deletions docs/src/workspaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ BunchKaufmanWs
```@docs
CholeskyPivotedWs
```

## [LSE](@id LSE-id)
```@docs
LSEWs
```
2 changes: 2 additions & 0 deletions src/FastLapackInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ include("bunch_kaufman.jl")
export BunchKaufmanWs
include("cholesky.jl")
export CholeskyPivotedWs
include("lse.jl")
export LSEWs

# Uniform interface
include("workspace.jl")
Expand Down
122 changes: 122 additions & 0 deletions src/lse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import LinearAlgebra.LAPACK: gglse!

"""
LSEWs
Workspace for the least squares solving function [`LAPACK.geqrf!`](@ref).
# Examples
```jldoctest
julia> A = [1.2 2.3 6.2
6.2 3.3 8.8
9.1 2.1 5.5]
3×3 Matrix{Float64}:
1.2 2.3 6.2
6.2 3.3 8.8
9.1 2.1 5.5
julia> B = [2.7 3.1 7.7
4.1 8.1 1.8]
2×3 Matrix{Float64}:
2.7 3.1 7.7
4.1 8.1 1.8
julia> c = [0.2, 7.2, 2.9]
3-element Vector{Float64}:
0.2
7.2
2.9
julia> d = [3.9, 2.1]
2-element Vector{Float64}:
3.9
2.1
julia> ws = LSEWs(A, B)
LSEWs{Float64}
work: 101-element Vector{Float64}
X: 3-element Vector{Float64}
julia> LAPACK.gglse!(ws, A, c, B, d)
([0.19723156207005318, 0.0683561362406917, 0.40981438442398854], 13.750943845251626)
```
"""
struct LSEWs{T} <: Workspace
work::Vector{T}
X::Vector{T}
end

LSEWs(A::AbstractMatrix) = LSEWs(A, A)
LSEWs(A::AbstractMatrix, B::AbstractMatrix) = resize!(LSEWs(Vector{eltype(A)}(undef, 1), Vector{eltype(A)}(undef, size(A,2))), A, B)
for (gglse, elty) in ((:dgglse_, :Float64),
(:sgglse_, :Float32),
(:zgglse_, :ComplexF64),
(:cgglse_, :ComplexF32))
@eval begin
function Base.resize!(ws::LSEWs, A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty}; work=true, blocksize=32)
require_one_based_indexing(A)
chkstride1(A)
m, n = size(A)
p = size(B,1)
resize!(ws.X, n)
if work
resize!(ws.work, p + min(m, n) + max(m,n)*blocksize)
end
return ws

end

function gglse!(ws::LSEWs{$elty}, A::AbstractMatrix{$elty}, c::AbstractVector{$elty},
B::AbstractMatrix{$elty}, d::AbstractVector{$elty}; resize=true, blocksize=32)
require_one_based_indexing(A, c, B, d)
chkstride1(A, c, B, d)
m, n = size(A)
p = size(B, 1)
if size(B, 2) != n
throw(DimensionMismatch("B has second dimension $(size(B,2)), needs $n"))
end
if length(c) != m
throw(DimensionMismatch("c has length $(length(c)), needs $m"))
end
if length(d) != p
throw(DimensionMismatch("d has length $(length(d)), needs $p"))
end
if n > m + p
throw(DimensionMismatch("Rows of A + rows of B needs to be larger than columns of A and B."))
end
nws = length(ws.X)
if nws != n
if resize
resize!(ws.X, n)
worksize = p + min(m, n) + max(m,n)*blocksize
if length(ws.work) < worksize
resize!(ws.work, worksize)
end
else
throw(WorkspaceSizeError(nws, n))
end
end

info = Ref{BlasInt}()
ccall((@blasfunc($gglse), liblapack), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt},
Ptr{BlasInt}),
m, n, p, A, max(1,stride(A,2)), B, max(1,stride(B,2)), c, d, ws.X,
ws.work, length(ws.work), info)
chklapackerror(info[])
ws.X, dot(view(c, n - p + 1:m), view(c, n - p + 1:m))
end
end
end

"""
gglse!(ws, A, c, B, d) -> (ws.X,res)
Solves the equation `A * x = c` where `x` is subject to the equality
constraint `B * x = d`. Uses the formula `||c - A*x||^2 = 0` to solve.
Uses preallocated [`LSEWs`](@ref) to store `X` and work buffers.
Returns `ws.X` and the residual sum-of-squares.
"""
gglse!(ws::LSEWs, A::AbstractMatrix, c::AbstractVector, B::AbstractMatrix, d::AbstractVector)
13 changes: 7 additions & 6 deletions src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ QRWs{Float64}
julia> t = QR(LAPACK.geqrf!(ws, A)...)
QR{Float64, Matrix{Float64}, Vector{Float64}}
Q factor:
2×2 QRPackedQ{Float64, Matrix{Float64}, Vector{Float64}}:
2×2 LinearAlgebra.QRPackedQ{Float64, Matrix{Float64}, Vector{Float64}}:
-0.190022 -0.98178
-0.98178 0.190022
R factor:
Expand Down Expand Up @@ -127,10 +127,10 @@ QRWYWs{Float64, Matrix{Float64}}
work: 4-element Vector{Float64}
T: 2×2 Matrix{Float64}
julia> t = QRCompactWY(LAPACK.geqrt!(ws, A)...)
QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}
julia> t = LinearAlgebra.QRCompactWY(LAPACK.geqrt!(ws, A)...)
LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}
Q factor:
2×2 QRCompactWYQ{Float64, Matrix{Float64}, Matrix{Float64}}:
2×2 LinearAlgebra.QRCompactWYQ{Float64, Matrix{Float64}, Matrix{Float64}}:
-0.190022 -0.98178
-0.98178 0.190022
R factor:
Expand Down Expand Up @@ -229,15 +229,16 @@ julia> A = [1.2 2.3
6.2 3.3
julia> ws = QRPivotedWs(A)
QRPivotedWs{Float64}
QRPivotedWs{Float64, Float64}
work: 100-element Vector{Float64}
rwork: 0-element Vector{Float64}
τ: 2-element Vector{Float64}
jpvt: 2-element Vector{Int64}
julia> t = QRPivoted(LAPACK.geqp3!(ws, A)...)
QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}
Q factor:
2×2 QRPackedQ{Float64, Matrix{Float64}, Vector{Float64}}:
2×2 LinearAlgebra.QRPackedQ{Float64, Matrix{Float64}, Vector{Float64}}:
-0.190022 -0.98178
-0.98178 0.190022
R factor:
Expand Down
4 changes: 4 additions & 0 deletions src/workspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ Workspace(::typeof(LAPACK.hetrf_rook!), A::AbstractMatrix) = BunchKaufmanWs(A)

Workspace(::typeof(LAPACK.pstrf!), A::AbstractMatrix) = CholeskyPivotedWs(A)

Workspace(::typeof(LAPACK.gglse!), A::AbstractMatrix) = LSEWs(A)

"""
decompose!(ws, args...)
Expand Down Expand Up @@ -101,6 +103,8 @@ function decompose!(ws::CholeskyPivotedWs, A::Union{Hermitian, Symmetric}, tol=1
return LAPACK.pstrf!(ws, A.uplo, A.data, tol; kwargs...)
end

decompose!(ws::LSEWs, args...; kwargs...) = LAPACK.gglse!(ws, args...; kwargs...)

"""
factorize!(ws, args...)
Expand Down
41 changes: 41 additions & 0 deletions test/lse_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using Test
using FastLapackInterface
using LinearAlgebra.LAPACK

@testset "LSEWs" begin
n = 8
m = 4
p = 6
for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "$T" begin
A0 = rand(T, m, n)
B0 = rand(T, p, n)
c0 = rand(T, m)
d0 = rand(T, p)
ws = LSEWs(A0, B0)
@testset "gglse!" begin
A = copy(A0)
B = copy(B0)
c = copy(c0)
d = copy(d0)
X1, err1 = LAPACK.gglse!(copy(A0), copy(c0), copy(B0), copy(d0))
X2, err2 = LAPACK.gglse!(ws, copy(A0), copy(c0), copy(B0), copy(d0))

@test isapprox(X1, X2)
@test isapprox(err1, err2)
# using Workspace, factorize!
ws = Workspace(LAPACK.gglse!, copy(A0))
X2, err2 = factorize!(ws, copy(A0), copy(c0), copy(B0), copy(d0))

@test isapprox(X1, X2)
for div in (-1, 1)
@test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, rand(T, m, n+div), copy(c0), rand(T, p+div, n+div), rand(T, p+div); resize=false)
factorize!(ws, rand(T, m, n+div), copy(c0), rand(T, p+div, n+div), rand(T, p+div))
@test size(ws.X , 1) == n+div
@test size(ws.work , 1) >= p+div + min(m, n+div) + max(m,n+div)*32
end
end
end
end
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using FastLapackInterface
using LinearAlgebra
using Test

include("lse_test.jl")
include("lu_test.jl")
include("schur_test.jl")
include("qr_test.jl")
Expand Down

0 comments on commit 08f34a8

Please sign in to comment.