From cdd20bd2d4b632b837014f7bef90369f9a41328c Mon Sep 17 00:00:00 2001 From: louisponet Date: Sun, 21 Aug 2022 12:07:07 +0200 Subject: [PATCH 1/2] Bumped version, fixed resizes, added more tests --- .github/workflows/ci.yml | 2 +- Project.toml | 2 +- src/FastLapackInterface.jl | 2 +- src/bunch_kaufman.jl | 25 ++-- src/cholesky.jl | 13 +- src/eigen.jl | 246 ++++++++++++++++--------------------- src/exceptions.jl | 8 +- src/lu.jl | 6 +- src/qr.jl | 75 +++++------ src/schur.jl | 83 +++++++------ test/bunch_kaufman_test.jl | 7 +- test/eigen_test.jl | 39 +++--- test/lu_test.jl | 8 +- test/qr_test.jl | 26 ++-- test/schur_test.jl | 12 ++ 15 files changed, 283 insertions(+), 271 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 028b663..e93a89e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: matrix: version: - '1.6' - - '1.8.0-rc1' + - '1.8.0' - 'nightly' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 5b10a8e..f6389f1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FastLapackInterface" uuid = "29a986be-02c6-4525-aec4-84b980013641" authors = ["Louis Ponet, Michel Juillard"] -version = "1.2.4" +version = "1.2.5" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FastLapackInterface.jl b/src/FastLapackInterface.jl index f4d9343..baceac7 100644 --- a/src/FastLapackInterface.jl +++ b/src/FastLapackInterface.jl @@ -16,7 +16,7 @@ else end abstract type Workspace end - +include("exceptions.jl") include("lu.jl") export LUWs include("qr.jl") diff --git a/src/bunch_kaufman.jl b/src/bunch_kaufman.jl index 6edfa82..a5a2a64 100644 --- a/src/bunch_kaufman.jl +++ b/src/bunch_kaufman.jl @@ -82,20 +82,22 @@ for (sytrfs, elty) in ((:csytrf_,:csytrf_rook_, :chetrf_,:chetrf_rook_),:ComplexF32)) @eval begin - function Base.resize!(ws::BunchKaufmanWs, A::AbstractMatrix{$elty}) + function Base.resize!(ws::BunchKaufmanWs, A::AbstractMatrix{$elty}; work = true) chkstride1(A) n = checksquare(A) if n == 0 return ws end resize!(ws.ipiv, n) - info = Ref{BlasInt}() - ccall((@blasfunc($(sytrfs[1])), liblapack), Cvoid, - (Ref{UInt8}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong), - 'U', n, A, stride(A,2), ws.ipiv, ws.work, -1, info, 1) - chkargsok(info[]) - resize!(ws.work, BlasInt(real(ws.work[1]))) + if work + info = Ref{BlasInt}() + ccall((@blasfunc($(sytrfs[1])), liblapack), Cvoid, + (Ref{UInt8}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, + Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong), + 'U', n, A, stride(A,2), ws.ipiv, ws.work, -1, info, 1) + chkargsok(info[]) + resize!(ws.work, BlasInt(real(ws.work[1]))) + end return ws end function BunchKaufmanWs(A::AbstractMatrix{$elty}) @@ -110,11 +112,12 @@ for (sytrfs, elty) in return A, ws.ipiv, zero(BlasInt) end chkuplo(uplo) - if n > length(ws.ipiv) + nws = length(ws.ipiv) + if n != nws if resize - resize!(ws, A) + resize!(ws, A, work=nws= vu throw(ArgumentError("lower boundary, $vl, must be less than upper boundary, $vu")) end - m = Ref{BlasInt}() + + chkstride1(A) + n = checksquare(A) + nws = length(ws.w) + if nws != n + if resize + resize!(ws, A, vecs = size(ws.Z, 1) > 1 || jobz == 'V', work = n > nws) + else + throw(WorkspaceSizeError(nws, n)) + end + end + # If WS was created without support for vectors if jobz == 'N' ldz = 1 elseif jobz == 'V' ldz = n - if size(ws.Z, 1) < ldz + nws = size(ws.Z, 1) + if nws != ldz if resize - resize!(ws, A, vecs=true) - else - throw(ArgumentError("Workspace does not support eigenvectors.\nUse resize!(ws, A, vecs=true).")) - end - elseif size(ws.Z, 1) > ldz - if resize - # Only resize Z because w we resize below ws.Z = similar(ws.Z, ldz, ldz) else - throw(ArgumentError("Workspace too large.")) + throw(ArgumentError("Workspace does not support eigenvectors.\nUse resize!(ws, A, vecs=true).")) end end end - if length(ws.w) < n - if resize - resize!(ws, A, vecs = size(ws.Z, 1) > 1) - else - throw(ArgumentError("Workspace too small.\nUse resize!(ws, A).")) - end - elseif length(ws.w) > n - if resize - resize!(ws.w, n) - else - throw(ArgumentError("Workspace too large.")) - end - end + m = Ref{BlasInt}() info = Ref{BlasInt}() if eltype(A) <: Complex ccall((@blasfunc($syevr), liblapack), Cvoid, @@ -528,7 +523,7 @@ for (ggev, elty, relty) in (:zggev_, :ComplexF64, :Float64), (:cggev_, :ComplexF32, :Float32)) @eval begin - function Base.resize!(ws::GeneralizedEigenWs, A::AbstractMatrix{$elty}; lvecs=false,rvecs=false) + function Base.resize!(ws::GeneralizedEigenWs, A::AbstractMatrix{$elty}; lvecs=false,rvecs=false,work=true) require_one_based_indexing(A) chkstride1(A) n = checksquare(A) @@ -536,7 +531,7 @@ for (ggev, elty, relty) in ldb = lda resize!(ws.αr, n) cmplx = eltype(A) <: Complex - if cmplx + if cmplx && work resize!(ws.αi, 8n) else resize!(ws.αi, n) @@ -549,36 +544,36 @@ for (ggev, elty, relty) in ldvr = rvecs ? n : 1 ws.vl = zeros($elty, lvecs ? n : 0, n) ws.vr = zeros($elty, rvecs ? n : 0, n) - - info = Ref{BlasInt}() - - if cmplx - ccall((@blasfunc($ggev), liblapack), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$relty}, - Ptr{BlasInt}, Clong, Clong), - jobvl, jobvr, n, A, - lda, A, ldb, ws.αr, - ws.β, ws.vl, ldvl, ws.vr, - ldvr, ws.work, -1, ws.αi, - info, 1, 1) - else - ccall((@blasfunc($ggev), liblapack), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, - Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, - Ptr{BlasInt}, Clong, Clong), - jobvl, jobvr, n, A, - lda, A, ldb, ws.αr, - ws.αi, ws.β, ws.vl, ldvl, - ws.vr, ldvr, ws.work, -1, - info, 1, 1) + if work + info = Ref{BlasInt}() + if cmplx + ccall((@blasfunc($ggev), liblapack), Cvoid, + (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ptr{$elty}, + Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$relty}, + Ptr{BlasInt}, Clong, Clong), + jobvl, jobvr, n, A, + lda, A, ldb, ws.αr, + ws.β, ws.vl, ldvl, ws.vr, + ldvr, ws.work, -1, ws.αi, + info, 1, 1) + else + ccall((@blasfunc($ggev), liblapack), Cvoid, + (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ptr{$elty}, + Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, + Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, + Ptr{BlasInt}, Clong, Clong), + jobvl, jobvr, n, A, + lda, A, ldb, ws.αr, + ws.αi, ws.β, ws.vl, ldvl, + ws.vr, ldvr, ws.work, -1, + info, 1, 1) + end + chklapackerror(info[]) + resize!(ws.work, BlasInt(ws.work[1])) end - chklapackerror(info[]) - resize!(ws.work, BlasInt(ws.work[1])) return ws end function GeneralizedEigenWs(A::AbstractMatrix{$elty}; kwargs...) @@ -596,56 +591,27 @@ for (ggev, elty, relty) in lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) - lb = length(ws.β) - if lb < n + nws = length(ws.β) + if nws != n if resize - resize!(ws, A, lvecs = size(ws.vl, 1) > 0, rvecs = size(ws.vr, 1) > 0 ) + resize!(ws, A, lvecs = size(ws.vl, 1) > 0 || jobvl == 'V', rvecs = size(ws.vr, 1) > 0 || jobvl == 'V', work=n > nws) else - throw(ArgumentError("Workspace too small.\nUse resize!(ws, A).")) - end - elseif lb > n - if resize - resize!(ws.β, n) - resize!(ws.αr, n) - if eltype(A) <: AbstractFloat - # Otherwise it's just a work buffer - resize!(ws.αi, n) - end - else - throw(ArgumentError("Workspace too large.")) + throw(WorkspaceSizeError(nws, n)) end end ldvl = size(ws.vl, 1) ldvr = size(ws.vr, 1) - - if jobvl == 'V' - if ldvl < n - if resize - resize!(ws, A, lvecs = true, rvecs = ldvr != 0) - else - throw(ArgumentError("Workspace was created without support for left eigenvectors or too small,\n use resize!(ws, A, lvecs=true).")) - end - elseif ldvl > n - if resize - ws.vl = similar(ws.vl, n, n) - else - throw(ArgumentError("ws.vl is too large, needs to be of size $n x $n.")) - end - end - end - if jobvr == 'V' - if ldvr < n - if resize - resize!(ws, A, rvecs = true, lvecs = ldvl != 0) - else - throw(ArgumentError("Workspace was created without support for right eigenvectors or too small,\n use resize!(ws, A, rvecs=true).")) - end - elseif ldvr > n - if resize - ws.vr = similar(ws.vr, n, n) - else - throw(ArgumentError("ws.vr is too large, needs to be of size $n x $n.")) + + for (job, v, str1, str2) in ((jobvl, :vl, "left", "lvecs"),(jobvr, :vr, "right", "rvecs")) + ldv = size(getfield(ws, v), 1) + if job == 'V' + if ldv != n + if resize + setfield!(ws, v, similar(A, n, n)) + else + throw(ArgumentError("Workspace was created without support for $str1 eigenvectors or too small,\n use resize!(ws, A, $str2=true).")) + end end end end diff --git a/src/exceptions.jl b/src/exceptions.jl index fe584b3..6d2086d 100644 --- a/src/exceptions.jl +++ b/src/exceptions.jl @@ -1,7 +1,13 @@ struct SingularException <: Exception end struct DggesException <: Exception - error_nbr::Int64 + error_nbr::Int end Base.showerror(io::IO, e::DggesException) = print(io, "dgges error ", e.error_nbr) + +struct WorkspaceSizeError <: Exception + nws::Int + n::Int +end +Base.showerror(io::IO, e::WorkspaceSizeError) = print(io, "Workspace has the wrong size: expected $(e.n), got $(e.nws).\nUse resize!(ws, A).") diff --git a/src/lu.jl b/src/lu.jl index 6f86a31..ee5301a 100644 --- a/src/lu.jl +++ b/src/lu.jl @@ -45,11 +45,13 @@ for (getrf, elty) in ((:dgetrf_, :Float64), (:cgetrf_, :ComplexF32)) @eval begin function getrf!(ws::LUWs, A::AbstractMatrix{$elty}; resize=true) - if min(size(A)...) <= length(ws.ipiv) + nws = length(ws.ipiv) + n = min(size(A)...) + if n != nws if resize resize!(ws, A) else - throw(ArgumentError("Allocated Workspace is too small.")) + throw(WorkspaceSizeError(nws, n)) end end require_one_based_indexing(A) diff --git a/src/qr.jl b/src/qr.jl index fd3ff78..3f8e00e 100644 --- a/src/qr.jl +++ b/src/qr.jl @@ -49,17 +49,19 @@ for (geqrf, elty) in ((:dgeqrf_, :Float64), (:zgeqrf_, :ComplexF64), (:cgeqrf_, :ComplexF32)) @eval begin - function Base.resize!(ws::QRWs, A::StridedMatrix{$elty}) + function Base.resize!(ws::QRWs, A::StridedMatrix{$elty}; work=true) m, n = size(A) lda = max(1, stride(A, 2)) resize!(ws.τ, min(m, n)) - info = Ref{BlasInt}() - ccall((@blasfunc($geqrf), liblapack), Cvoid, - (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}), - m, n, A, lda, ws.τ, ws.work, -1, info) - chklapackerror(info[]) - resize!(ws.work, BlasInt(real(ws.work[1]))) + if work + info = Ref{BlasInt}() + ccall((@blasfunc($geqrf), liblapack), Cvoid, + (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}), + m, n, A, lda, ws.τ, ws.work, -1, info) + chklapackerror(info[]) + resize!(ws.work, BlasInt(real(ws.work[1]))) + end return ws end QRWs(A::StridedMatrix{$elty}) = @@ -69,11 +71,13 @@ for (geqrf, elty) in ((:dgeqrf_, :Float64), require_one_based_indexing(A) chkstride1(A) m, n = size(A) - if length(ws) < min(m, n) + nws = length(ws) + minn = min(m, n) + if nws != minn if resize - resize!(ws, A) + resize!(ws, A; work = minn > nws) else - throw(ArgumentError("Workspace is too small, use resize!(ws, A).")) + throw(WorkspaceSizeError(nws, minn)) end end lda = max(1, stride(A, 2)) @@ -145,7 +149,7 @@ mutable struct QRWYWs{R<:Number,MT<:StridedMatrix{R}} <: Workspace T::MT end -function Base.resize!(ws::QRWYWs, A::StridedMatrix; blocksize=36) +function Base.resize!(ws::QRWYWs, A::StridedMatrix; blocksize=36, work=true) require_one_based_indexing(A) chkstride1(A) m, n = BlasInt.(size(A)) @@ -153,7 +157,9 @@ function Base.resize!(ws::QRWYWs, A::StridedMatrix; blocksize=36) m1 = min(m, n) nb = min(m1, blocksize) ws.T = similar(ws.T, nb, m1) - resize!(ws.work, nb*n) + if work + resize!(ws.work, nb*n) + end return ws end @@ -170,14 +176,14 @@ for (geqrt, elty) in ((:dgeqrt_, :Float64), m, n = size(A) minmn = min(m, n) nb = size(ws.T, 1) - if nb > minmn + if nb != minmn if resize - resize!(ws, A) - nb = size(ws.T, 1) + resize!(ws, A, work = nb < minmn) else - throw(ArgumentError("Allocated workspace block size $nb > $minmn too large.\nUse resize!(ws, A).")) + throw(WorkspaceSizeError(nb, minmn)) end end + nb = size(ws.T, 1) lda = max(1, stride(A, 2)) work = ws.work @@ -260,20 +266,22 @@ for (geqp3, elty) in ((:dgeqp3_, :Float64), (:zgeqp3_, :ComplexF64), (:cgeqp3_, :ComplexF32)) @eval begin - function Base.resize!(ws::QRPivotedWs, A::StridedMatrix{$elty}) + function Base.resize!(ws::QRPivotedWs, A::StridedMatrix{$elty}; work=true) require_one_based_indexing(A) chkstride1(A) m, n = size(A) RldA = max(1, stride(A, 2)) resize!(ws.jpvt, n) resize!(ws.τ, min(m, n)) - info = Ref{BlasInt}() - ccall((@blasfunc($geqp3), liblapack), Cvoid, - (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}), - m, n, A, RldA, ws.jpvt, ws.τ, ws.work, -1, info) - chklapackerror(info[]) - resize!(ws.work, BlasInt(real(ws.work[1]))) + if work + info = Ref{BlasInt}() + ccall((@blasfunc($geqp3), liblapack), Cvoid, + (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}), + m, n, A, RldA, ws.jpvt, ws.τ, ws.work, -1, info) + chklapackerror(info[]) + resize!(ws.work, BlasInt(real(ws.work[1]))) + end return ws end QRPivotedWs(A::StridedMatrix{$elty}) = @@ -281,18 +289,13 @@ for (geqp3, elty) in ((:dgeqp3_, :Float64), function geqp3!(ws::QRPivotedWs{$elty}, A::AbstractMatrix{$elty}; resize=true) m, n = size(A) - if length(ws.τ) != min(m, n) - if resize - resize!(ws, A) - else - throw(ArgumentError("Workspace is too small, use resize!(ws, A).")) - end - end - if length(ws.jpvt) != n + nws = length(ws.jpvt) + minn = min(m, n) + if nws != n || minn != length(ws.τ) if resize - resize!(ws, A) + resize!(ws, A; work = n > nws) else - throw(ArgumentError("Workspace is too small, use resize!(ws, A).")) + throw(WorkspaceSizeError(nws, minn)) end end lda = stride(A, 2) @@ -335,7 +338,7 @@ for (ormqr, orgqr, elty) in ((:dormqr_, :dorgqr_, :Float64), @eval function ormqr!(ws::Union{QRWs{$elty}, QRPivotedWs{$elty}}, side::AbstractChar, trans::AbstractChar, A::AbstractMatrix{$elty}, - C::AbstractVecOrMat{$elty}; resize=true) + C::AbstractVecOrMat{$elty}) require_one_based_indexing(A, C) chktrans(trans) chkside(side) diff --git a/src/schur.jl b/src/schur.jl index d45a6f2..a6d148e 100644 --- a/src/schur.jl +++ b/src/schur.jl @@ -91,7 +91,7 @@ Base.length(ws::SchurWs) = length(ws.wr) for (gees, elty) in ((:dgees_, :Float64), (:sgees_, :Float32)) @eval begin - function Base.resize!(ws::SchurWs, A::AbstractMatrix{$elty}) + function Base.resize!(ws::SchurWs, A::AbstractMatrix{$elty}; work=true) require_one_based_indexing(A) chkstride1(A) n = checksquare(A) @@ -100,20 +100,22 @@ for (gees, elty) in ((:dgees_, :Float64), ws.vs = zeros($elty, n, n) resize!(ws.bwork, n) resize!(ws.eigen_values, n) - info = Ref{BlasInt}() - ccall((@blasfunc($gees), liblapack), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, Ref{BlasInt}, - Ptr{$elty}, Ref{BlasInt}, Ptr{Cvoid}, Ptr{$elty}, - Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{Cvoid}, Ptr{BlasInt}, Clong, Clong), - 'V', 'N', C_NULL, n, - A, max(1, stride(A, 2)), C_NULL, ws.wr, - ws.wi, ws.vs, max(size(ws.vs, 1), 1), ws.work, - -1, C_NULL, info, 1, 1) + if work + info = Ref{BlasInt}() + ccall((@blasfunc($gees), liblapack), Cvoid, + (Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, Ref{BlasInt}, + Ptr{$elty}, Ref{BlasInt}, Ptr{Cvoid}, Ptr{$elty}, + Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ref{BlasInt}, Ptr{Cvoid}, Ptr{BlasInt}, Clong, Clong), + 'V', 'N', C_NULL, n, + A, max(1, stride(A, 2)), C_NULL, ws.wr, + ws.wi, ws.vs, max(size(ws.vs, 1), 1), ws.work, + -1, C_NULL, info, 1, 1) - chklapackerror(info[]) + chklapackerror(info[]) - resize!(ws.work, BlasInt(real(ws.work[1]))) + resize!(ws.work, BlasInt(real(ws.work[1]))) + end return ws end @@ -123,15 +125,16 @@ for (gees, elty) in ((:dgees_, :Float64), function gees!(ws::SchurWs{$elty}, jobvs::AbstractChar, A::AbstractMatrix{$elty}; select::Union{Nothing,Function} = nothing, - resize=false) + resize=true) require_one_based_indexing(A) chkstride1(A) n = checksquare(A) - if n > length(ws) + nws = length(ws) + if n != nws if resize - resize!(ws, A) + resize!(ws, A; work = n > nws) else - throw(ArgumentError("Allocated workspace has length $(length(ws)), but needs length $n.")) + throw(WorkspaceSizeError(nws, n)) end end info = Ref{BlasInt}() @@ -269,7 +272,7 @@ Base.length(ws::GeneralizedSchurWs) = length(ws.αr) for (gges, elty) in ((:dgges_, :Float64), (:sgges_, :Float32)) @eval begin - function Base.resize!(ws::GeneralizedSchurWs, A::AbstractMatrix{$elty}) + function Base.resize!(ws::GeneralizedSchurWs, A::AbstractMatrix{$elty}; work=true) chkstride1(A) n = checksquare(A) resize!(ws.αr, n) @@ -279,23 +282,25 @@ for (gges, elty) in ((:dgges_, :Float64), resize!(ws.eigen_values, n) ws.vsl = zeros($elty, n, n) ws.vsr = zeros($elty, n, n) - info = Ref{BlasInt}() - ccall((@blasfunc($gges), liblapack), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, - Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{Cvoid}, Ptr{$elty}, Ptr{$elty}, - Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{Cvoid}, - Ref{BlasInt}, Clong, Clong, Clong), - 'V', 'V', 'N', C_NULL, - n, A, max(1, stride(A, 2)), A, - max(1, stride(A, 2)), C_NULL, ws.αr, ws.αi, - ws.β, ws.vsl, n, ws.vsr, - n, ws.work, -1, C_NULL, - info, 1, 1, 1) + if work + info = Ref{BlasInt}() + ccall((@blasfunc($gges), liblapack), Cvoid, + (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, + Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ref{BlasInt}, Ptr{Cvoid}, Ptr{$elty}, Ptr{$elty}, + Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{Cvoid}, + Ref{BlasInt}, Clong, Clong, Clong), + 'V', 'V', 'N', C_NULL, + n, A, max(1, stride(A, 2)), A, + max(1, stride(A, 2)), C_NULL, ws.αr, ws.αi, + ws.β, ws.vsl, n, ws.vsr, + n, ws.work, -1, C_NULL, + info, 1, 1, 1) - chklapackerror(info[]) - resize!(ws.work, BlasInt(real(ws.work[1]))) + chklapackerror(info[]) + resize!(ws.work, BlasInt(real(ws.work[1]))) + end return ws end GeneralizedSchurWs(A::AbstractMatrix{$elty}) = @@ -311,11 +316,12 @@ for (gges, elty) in ((:dgges_, :Float64), if n != m throw(DimensionMismatch("dimensions of A, ($n,$n), and B, ($m,$m), must match")) end - if n > length(ws) + nws = length(ws) + if n != nws if resize - resize!(ws, A) + resize!(ws, A; work = n > nws) else - throw(ArgumentError("Allocated workspace has length $(length(ws)), but needs length $n.")) + throw(WorkspaceSizeError(nws, n)) end end @@ -358,8 +364,7 @@ for (gges, elty) in ((:dgges_, :Float64), @inbounds for i in axes(A, 1) ws.eigen_values[i] = complex(ws.αr[i], ws.αi[i]) end - return A, B, ws.eigen_values, ws.β, view(ws.vsl, 1:(jobvsl == 'V' ? n : 0), :), - view(ws.vsr, 1:(jobvsr == 'V' ? n : 0), :) + return A, B, ws.eigen_values, ws.β, ws.vsl, ws.vsr end end end diff --git a/test/bunch_kaufman_test.jl b/test/bunch_kaufman_test.jl index 22ebdfd..f83e9e2 100644 --- a/test/bunch_kaufman_test.jl +++ b/test/bunch_kaufman_test.jl @@ -23,8 +23,11 @@ using LinearAlgebra.LAPACK A1, ipiv1, inf1 = decompose!(ws, 'U', copy(A)) @test A1 == A2 @test ipiv1 == ipiv2 - decompose!(ws, Symmetric(rand(T, n+1, n+1))) - @test length(ws.ipiv) == n+1 + for div in (-1,1) + @test_throws FastLapackInterface.WorkspaceSizeError decompose!(ws, Symmetric(rand(T, n+div, n+div)); resize=false) + decompose!(ws, Symmetric(rand(T, n+div, n+div))) + @test length(ws.ipiv) == n+div + end end @testset "$(T) sytrf_rook!" begin A = rand(T, n, n) diff --git a/test/eigen_test.jl b/test/eigen_test.jl index d76bd95..f934f0d 100644 --- a/test/eigen_test.jl +++ b/test/eigen_test.jl @@ -50,12 +50,11 @@ using LinearAlgebra.LAPACK @test_throws ArgumentError factorize!(ws, 'P', 'V', 'V', 'E', copy(A0); resize=false) factorize!(ws, 'P', 'V', 'V', 'E', copy(A0)) @test size(ws.iwork, 1) != 0 - @test_throws ArgumentError factorize!(ws, 'P', 'N', 'N', 'N', rand(n+1, n+1); resize=false) - A2, WR2, WI2, VL2, VR2, ilo2, ihi2, scale2, abnrm2, rconde2, rcondv2 = - factorize!(ws, 'P', 'N', 'N', 'N', rand(n+1, n+1)) - - @test length(WR2) == n+1 - + for div in (-1, 1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, 'P', 'N', 'N', 'N',rand(n+div, n+div); resize=false) + factorize!(ws, 'P', 'N', 'N', 'N',rand(n+div, n+div)) + @test length(ws.W) == n+div + end show(devnull, "text/plain", ws) end @@ -121,18 +120,17 @@ end show(devnull, "text/plain", ws) ws = Workspace(LAPACK.syevr!, copy(A0); vecs = false) - @test_throws ArgumentError factorize!(ws, 'N', 'A', 'U', randn(n+1, n+1), 0.0, 0.0, 0, 0, 1e-6; resize=false) - @test_throws ArgumentError factorize!(ws, 'V', 'A', 'U', copy(A0), 0.0, 0.0, 0, 0, 1e-6; resize=false) - w2, Z2 = factorize!(ws, 'V', 'A', 'U', randn(n+1, n+1), 0.0, 0.0, 0, 0, 1e-6) - @test length(ws.w) == n+1 - @test size(ws.Z, 1) == n+1 + for div in (-1, 1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, 'N', 'A', 'U', randn(n+div, n+div), 0.0, 0.0, 0, 0, 1e-6; resize=false) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, 'V', 'A', 'U', randn(n+div, n+div), 0.0, 0.0, 0, 0, 1e-6; resize=false) + w2, Z2 = factorize!(ws, 'V', 'A', 'U', randn(n+div, n+div), 0.0, 0.0, 0, 0, 1e-6) + @test length(ws.w) == n+div + @test size(ws.Z, 1) == n+div + end w2, Z2 = factorize!(ws, 'V', 'A', 'U', copy(A0), 0.0, 0.0, 0, 0, 1e-16) @test length(w2) == n @test sum(abs.(Matrix(Eigen(w2, Z2)) .- A0)) < 1e-12 - - @test_throws ArgumentError factorize!(ws, 'V', 'I', 'U', randn(n+1, n+1), 0.0, 0.0, 10, 5, 1e-6) - @test_throws ArgumentError factorize!(ws, 'V', 'V', 'U', randn(n+1, n+1), 2.0, 1.0, 0, 0, 1e-6) end @testset "Complex, square" begin @@ -184,12 +182,13 @@ end @test_throws ArgumentError LAPACK.ggev!(ws, 'V', 'V', copy(A0), copy(B0); resize=false) LAPACK.ggev!(ws, 'V', 'V', copy(A0), copy(B0)) @test size(ws.vl, 1) == n - @test_throws ArgumentError LAPACK.ggev!(ws, 'V', 'V', randn(n+1,n+1), randn(n+1, n+1), resize=false) - LAPACK.ggev!(ws, 'V', 'V', randn(n+1,n+1), randn(n+1, n+1)) - @test size(ws.vl, 1) == n+1 - @test size(ws.vr, 1) == n+1 - αr1, αi1, β1, vl1, vr1 = LAPACK.ggev!(ws, 'V', 'V', copy(A0), copy(B0)) - @test length(αr1) == n + + for div in (-1,1) + @test_throws FastLapackInterface.WorkspaceSizeError LAPACK.ggev!(ws, 'V', 'V', randn(n+div,n+div), randn(n+div, n+div), resize=false) + LAPACK.ggev!(ws, 'V', 'V', randn(n+div,n+div), randn(n+div, n+div)) + @test size(ws.vl, 1) == n+div + @test size(ws.vr, 1) == n+div + end end @testset "Complex, square" begin diff --git a/test/lu_test.jl b/test/lu_test.jl index 0dbc4d7..84fccd8 100644 --- a/test/lu_test.jl +++ b/test/lu_test.jl @@ -28,8 +28,12 @@ m = 2 @test UpperTriangular(reshape(res.U, n, n)) ≈ F.U show(devnull, "text/plain", linws) - resize!(linws, rand(elty, 5,5)) - @test length(linws.ipiv) == 5 + for div in (-1, 1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(linws, rand(elty, n+div, n+div); resize=false) + factorize!(linws, rand(elty, n+div, n+div)) + @test length(linws.ipiv) == n+div + end + # res = LU(LAPACK.getrf!(collect(copy(A)'), linws)...) # @test UpperTriangular(reshape(res.U, n, n)) ≈ F.U diff --git a/test/qr_test.jl b/test/qr_test.jl index 2e7761c..c528354 100644 --- a/test/qr_test.jl +++ b/test/qr_test.jl @@ -18,10 +18,11 @@ using LinearAlgebra.LAPACK qr2 = QR(factorize!(ws, copy(A0))...) @test isapprox(Matrix(qr1), Matrix(qr2)) - - @test_throws ArgumentError factorize!(ws, rand(n+1, n+1); resize=false) - factorize!(ws, rand(n+1, n+1)) - @test size(ws.τ , 1) == n+1 + for div in (-1, 1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, rand(n+div, n+div); resize=false) + factorize!(ws, rand(n+div, n+div)) + @test size(ws.τ , 1) == n+div + end end @testset "ormqr!" begin @@ -67,10 +68,13 @@ end qr1 = LinearAlgebra.QRCompactWY(factorize!(ws, A)...) @test isapprox(Matrix(qr1), Matrix(qr2)) show(devnull, "text/plain", ws) - @test_throws ArgumentError factorize!(ws, rand(n-1, n-1); resize=false) - factorize!(ws, rand(n-1, n-1)) - @test size(ws.T, 1) == n-1 + for div in (-1, 1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, rand(n+div, n+div); resize=false) + factorize!(ws, rand(n+div, n+div)) + @test size(ws.T , 1) == n+div + end end + end end @@ -93,9 +97,11 @@ end q1 = QRPivoted(factorize!(ws, copy(A))...) @test isapprox(Matrix(q1), Matrix(q2)) - @test_throws ArgumentError factorize!(ws, rand(n+1, n+1); resize=false) - factorize!(ws, rand(n+1, n+1)) - @test size(ws.τ , 1) == n+1 + for div in (-1, 1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, rand(n+div, n+div); resize=false) + factorize!(ws, rand(n+div, n+div)) + @test size(ws.τ, 1) == n+div + end show(devnull, "text/plain", ws) end @testset "orgqr!" begin diff --git a/test/schur_test.jl b/test/schur_test.jl index c2b8da8..e7c3519 100644 --- a/test/schur_test.jl +++ b/test/schur_test.jl @@ -20,6 +20,12 @@ using LinearAlgebra.LAPACK @test isapprox(vs1, vs2) @test isapprox(wr1, wr2) show(devnull, "text/plain", ws) + for div in (-1,1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, 'V', randn(n+div, n+div); resize=false) + factorize!(ws, 'V', randn(n+div, n+div)) + @test length(ws.wr) == n+div + end + end @testset "Real, square, select" begin @@ -66,6 +72,12 @@ end @test isapprox(vsl1, vsl2) @test isapprox(vsr1, vsr2) show(devnull, "text/plain", ws) + + for div in (-1,1) + @test_throws FastLapackInterface.WorkspaceSizeError factorize!(ws, 'V', 'V', randn(n+div, n+div), randn(n+div, n+div); resize=false) + factorize!(ws, 'V', 'V', randn(n+div, n+div), randn(n+div, n+div)) + @test length(ws.αr) == n+div + end end #TODO: This should be tested with something realistic From dab895717f6cd670d0e497a58a32025244040310 Mon Sep 17 00:00:00 2001 From: louisponet Date: Sun, 21 Aug 2022 12:09:12 +0200 Subject: [PATCH 2/2] do not return view --- src/qr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qr.jl b/src/qr.jl index 3f8e00e..0c48b00 100644 --- a/src/qr.jl +++ b/src/qr.jl @@ -400,7 +400,7 @@ for (ormqr, orgqr, elty) in ((:dormqr_, :dorgqr_, :Float64), info) chklapackerror(info[]) if n < size(A,2) - return view(A, :, 1:n) + return A[:, 1:n] else return A end