diff --git a/src/qr.jl b/src/qr.jl index 62ba585..b52c687 100644 --- a/src/qr.jl +++ b/src/qr.jl @@ -213,53 +213,52 @@ mutable struct QRWYWs{R<:Number,MT<:StridedMatrix{R}} <: Workspace T::MT end +function Base.resize!(ws::QRWYWs, A::StridedMatrix; blocksize=36) + require_one_based_indexing(A) + chkstride1(A) + m, n = BlasInt.(size(A)) + @assert n > 0 ArgumentError("Not a Matrix") + m1 = min(m, n) + nb = min(m1, blocksize) + ws.T = similar(ws.T, nb, m1) + resize!(ws.work, nb*n) + return ws +end + +QRWYWs(A::StridedMatrix{T}; kwargs...) where {T <: LinearAlgebra.BlasFloat} = + resize!(QRWYWs(T[], Matrix{T}(undef, 0, 0)), A; kwargs...) + for (geqrt, elty) in ((:dgeqrt_, :Float64), (:sgeqrt_, :Float32), (:zgeqrt_, :ComplexF64), (:cgeqrt_, :ComplexF32)) - @eval begin - function Base.resize!(ws::QRWYWs, A::StridedMatrix{$elty}; blocksize=36) - require_one_based_indexing(A) - chkstride1(A) - m, n = BlasInt.(size(A)) - @assert n > 0 ArgumentError("Not a Matrix") - m1 = min(m, n) - nb = min(m1, blocksize) - ws.T = similar(ws.T, nb, m1) - resize!(ws.work, nb*n) - return ws - end - QRWYWs(A::StridedMatrix{$elty}; kwargs...) = - resize!(QRWYWs($elty[], Matrix{$elty}(undef, 0, 0)), A; kwargs...) - - function geqrt!(ws::QRWYWs, A::AbstractMatrix{$elty}; resize=true) - require_one_based_indexing(A) - chkstride1(A) - m, n = size(A) - minmn = min(m, n) - nb = size(ws.T, 1) - if nb > minmn - if resize - resize!(ws, A) - nb = size(ws.T, 1) - else - throw(ArgumentError("Allocated workspace block size $nb > $minmn too large.\nUse resize!(ws, A).")) - end + @eval function geqrt!(ws::QRWYWs, A::AbstractMatrix{$elty}; resize=true) + require_one_based_indexing(A) + chkstride1(A) + m, n = size(A) + minmn = min(m, n) + nb = size(ws.T, 1) + if nb > minmn + if resize + resize!(ws, A) + nb = size(ws.T, 1) + else + throw(ArgumentError("Allocated workspace block size $nb > $minmn too large.\nUse resize!(ws, A).")) end - - lda = max(1, stride(A, 2)) - work = ws.work - info = Ref{BlasInt}() - ccall((@blasfunc($geqrt), liblapack), Cvoid, - (Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, - Ptr{BlasInt}), - m, n, nb, A, - lda, ws.T, max(1, stride(ws.T, 2)), ws.work, - info) - chklapackerror(info[]) - return A, ws.T end + + lda = max(1, stride(A, 2)) + work = ws.work + info = Ref{BlasInt}() + ccall((@blasfunc($geqrt), liblapack), Cvoid, + (Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, + Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, + Ptr{BlasInt}), + m, n, nb, A, + lda, ws.T, max(1, stride(ws.T, 2)), ws.work, + info) + chklapackerror(info[]) + return A, ws.T end end @@ -345,7 +344,7 @@ for (geqp3, elty) in ((:dgeqp3_, :Float64), resize!(ws.work, BlasInt(real(ws.work[1]))) return ws end - QRPivotedWs(A::StridedMatrix{$elty}) = + QRPivotedWs(A::StridedMatrix{$elty}) = resize!(QRPivotedWs(Vector{$elty}(undef, 1), $elty[], BlasInt[]), A) function geqp3!(ws::QRPivotedWs{$elty}, A::AbstractMatrix{$elty}; resize=true)