Avoid unnecessary allocations in resize! by using views. #40

Open
manuelbb-upb opened this issue Feb 23, 2024 · 1 comment

@manuelbb-upb

While looking into #39, I noticed that for a QRWYWs workspace the T matrix is re-allocated quite often, even when computing the QR factorization of a matrix smaller than the one specified at instantiation.

This can be avoided by using views.
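To illustrate why views help here: a view into a larger, preallocated matrix aliases the parent's memory instead of copying it, and its column stride is the parent's leading dimension. That is what lets the proof-of-concept below hand LAPACK a smaller block while passing stride(ws.T, 2) as LDT. The buffer sizes in this sketch are arbitrary.

```julia
# Preallocate a buffer sized for the largest expected problem (sizes are arbitrary).
buf = zeros(36, 50)

# Take a view of the top-left nb × k block; no new matrix is allocated.
nb, k = 5, 5
Tv = @view buf[1:nb, 1:k]

# Writes through the view land in the parent buffer.
Tv[2, 3] = 1.0

# The column stride of the view equals the parent's leading dimension,
# which is the value LAPACK expects as LDT.
ldt = stride(Tv, 2)

# The view starts at the same memory address as the parent.
same_start = pointer(Tv) == pointer(buf)
```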
Below is a proof-of-concept re-implementation, based on the code currently in FastLapackInterface:

import LinearAlgebra: BlasInt, BlasFloat, require_one_based_indexing, chkstride1 
import LinearAlgebra.BLAS: @blasfunc
import LinearAlgebra.LAPACK: chklapackerror, liblapack

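# Workspace for dgeqrt: `work` is the LAPACK WORK array, `T` stores the
# triangular block reflector factors (possibly larger than the current problem needs).
mutable struct QRWYWs{R<:Number,MT<:StridedMatrix{R}}
    work::Vector{R}
    T::MT
end

function QRWYWs(A::StridedMatrix{T}; kwargs...) where {T <: BlasFloat}
    resize!(QRWYWs(T[], Matrix{T}(undef, 0, 0)), A; kwargs...)
end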

function Base.resize!(ws::QRWYWs, A::StridedMatrix; blocksize=36, work=true)
    require_one_based_indexing(A)
    chkstride1(A)
    m, n = BlasInt.(size(A))
    @assert n >= 0 ArgumentError("Not a Matrix")
    m1 = min(m, n)
    nb = min(m1, blocksize)
    # dgeqrt needs T of size nb × min(m, n) and WORK of length nb * n
    ws.T = similar(ws.T, nb, m1)
    if work
        resize!(ws.work, nb*n)
    end
    return ws
end

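function geqrt!(ws::QRWYWs, A::StridedMatrix{Float64}; resize=true, blocksize=36)
    require_one_based_indexing(A)
    chkstride1(A)
    m, n = size(A)
    minmn = min(m, n)
    nb = min(minmn, blocksize)
    # Only reallocate if the stored buffers are too small for A; both
    # dimensions of T and the length of work have to be checked.
    if size(ws.T, 1) < nb || size(ws.T, 2) < minmn || length(ws.work) < nb * n
        if resize
            resize!(ws, A; work = true, blocksize = blocksize)
        else
            #throw(WorkspaceSizeError(nb, minmn))
            error("Cannot resize.")
        end
    end
    # Use the top-left nb × minmn block of the (possibly larger) stored T.
    T = @view(ws.T[1:nb, 1:minmn])
    if nb > 0
        lda = max(1, stride(A, 2))
        # The view shares memory with ws.T, so LDT is the leading dimension of ws.T.
        ldt = max(1, stride(ws.T, 2))
        work = ws.work
        info = Ref{BlasInt}()
        ccall(
            (@blasfunc(dgeqrt_), liblapack), Cvoid,
            (Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
            Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
            Ptr{BlasInt}),
            m, n, nb, A,
            lda, T, ldt, work,
            info
        )
        chklapackerror(info[])
    end
    return A, T
end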
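For reference, the LAPACK documentation for DGEQRT requires NB ≤ min(M, N), a T array of size LDT × min(M, N) with LDT ≥ NB, and a WORK array of length NB·N. The helpers below (hypothetical names, not part of FastLapackInterface) make these requirements explicit. For example, a workspace sized for a 50×50 matrix has a work length of 36·50 = 1800, which is shorter than the documented 36·60 = 2160 for a 50×60 input, so checking only size(ws.T, 1) is not enough.

```julia
"""
    dgeqrt_sizes(m, n; blocksize=36)

Hypothetical helper: return `(nb, tcols, worklen)`, the block size and the
minimum sizes of T (nb × tcols) and WORK (worklen) documented for dgeqrt.
"""
function dgeqrt_sizes(m::Integer, n::Integer; blocksize::Integer=36)
    minmn = min(m, n)
    nb = min(minmn, blocksize)   # NB must not exceed min(M, N)
    return (nb, minmn, nb * n)   # T is LDT × min(M, N), WORK has NB*N entries
end

"""
    fits(T, work, m, n; blocksize=36)

Hypothetical check: true if the preallocated `T` and `work` are large enough
for an m × n input, so a view can be used instead of reallocating.
"""
function fits(T::AbstractMatrix, work::AbstractVector, m, n; blocksize=36)
    nb, tcols, worklen = dgeqrt_sizes(m, n; blocksize=blocksize)
    return size(T, 1) >= nb && size(T, 2) >= tcols && length(work) >= worklen
end

# Buffers sized for a 50 × 50 matrix, as in the benchmark below.
nb, tcols, worklen = dgeqrt_sizes(50, 50)
T = zeros(nb, tcols)
work = zeros(worklen)

small_ok = fits(T, work, 50, 10)  # narrower input: existing buffers suffice
wide_ok = fits(T, work, 50, 60)   # wider input: WORK too short, resize first
```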

At the moment this only handles Float64, and I have not yet thought through the keyword arguments (blocksize).
But it does seem to save allocations:

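import FastLapackInterface as FLA
using LinearAlgebra: LAPACK
let n = 50
    A = rand(n, n)
    ws1 = QRWYWs(A)
    ws2 = FLA.QRWYWs(A)
    for j in (0, 5, 10, 25, 50, 60)
        B = rand(n, j)
        B_old = copy(B)   # separate copy so both versions factorize the same input
        @show size(B)
        println("new:")
        @time geqrt!(ws1, B)
        println("old:")
        # FastLapackInterface extends LAPACK.geqrt! for its own workspace type
        @time LAPACK.geqrt!(ws2, B_old)
    end
end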

gives

size(B) = (50, 0)
new:
  0.000001 seconds
old:
  0.000003 seconds (1 allocation: 64 bytes)
size(B) = (50, 5)
new:
  0.000049 seconds
old:
  0.000006 seconds (1 allocation: 256 bytes)
size(B) = (50, 10)
new:
  0.000016 seconds
old:
  0.000015 seconds (1 allocation: 896 bytes)
size(B) = (50, 25)
new:
  0.000048 seconds
old:
  0.000038 seconds (1 allocation: 5.062 KiB)
size(B) = (50, 50)
new:
  0.000098 seconds
old:
  0.000085 seconds (1 allocation: 14.188 KiB)
size(B) = (50, 60)
new:
  0.000112 seconds
old:
  0.000090 seconds (2 allocations: 46.125 KiB)
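Since @time also reports compilation on a first call, a more targeted check is @allocated inside a function after a warm-up call. The sketch below uses two hypothetical stand-ins rather than the geqrt! code above: one allocates a fresh nb × k block per call (like the old behaviour), the other writes into a view of a preallocated buffer (like the proposed one).

```julia
# Hypothetical stand-in for the old behaviour: allocate a fresh block per call.
function fill_fresh(nb, k)
    T = Matrix{Float64}(undef, nb, k)
    fill!(T, 0.0)
    return sum(T)
end

# Hypothetical stand-in for the new behaviour: reuse a buffer through a view.
function fill_view!(buf, nb, k)
    T = view(buf, 1:nb, 1:k)
    fill!(T, 0.0)
    return sum(T)
end

function measure(buf, nb, k)
    # Warm up so that compilation is not counted.
    fill_fresh(nb, k)
    fill_view!(buf, nb, k)
    a_fresh = @allocated fill_fresh(nb, k)
    a_view = @allocated fill_view!(buf, nb, k)
    return a_fresh, a_view
end

buf = zeros(36, 50)
a_fresh, a_view = measure(buf, 10, 10)
```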

Is this worth a PR?

@MichelJuillard
Member

Yes @manuelbb-upb, please prepare a PR. Best
