Skip to content

Commit

Permalink
Make FastLapackInterface.jl an extension as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Feb 6, 2025
1 parent 3bff586 commit 3beafc8
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 111 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Expand All @@ -35,6 +34,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -52,6 +52,7 @@ LinearSolveCUDAExt = "CUDA"
LinearSolveCUDSSExt = "CUDSS"
LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveFastLapackInterfaceExt = "FastLapackInterface"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand Down Expand Up @@ -126,6 +127,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Expand All @@ -150,4 +152,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "FastLapackInterface"]
16 changes: 11 additions & 5 deletions docs/src/solvers/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ use `Krylov_GMRES()`.

### RecursiveFactorization.jl

!!! note

Using this solver requires adding the package RecursiveFactorization.jl, i.e. `using RecursiveFactorization`

```@docs
RFLUFactorization
```
Expand Down Expand Up @@ -123,7 +127,13 @@ FastLapackInterface.jl is a package that allows for a lower-level interface to t
calls to allow for preallocating workspaces to decrease the overhead of the wrappers.
LinearSolve.jl provides a wrapper to these routines in a way where an initialized solver
has a non-allocating LU factorization. In theory, this post-initialized solve should always
be faster than the Base.LinearAlgebra version.
be faster than the `LinearAlgebra` standard-library version. In practice, with the way we wrap the solvers,
we do not observe a performance benefit, and in fact benchmarks tend to show that this wrapper inhibits
performance.

!!! note

Using this solver requires adding the package FastLapackInterface.jl, i.e. `using FastLapackInterface`

```@docs
FastLUFactorization
Expand Down Expand Up @@ -157,10 +167,6 @@ KrylovJL

### MKL.jl

!!! note

Using this solver requires adding the package MKL_jll.jl, i.e. `using MKL_jll`

```@docs
MKLLUFactorization
```
Expand Down
82 changes: 82 additions & 0 deletions ext/LinearSolveFastLapackInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module LinearSolveFastLapackInterfaceExt

using LinearSolve, LinearAlgebra
using FastLapackInterface

# Pairs a FastLapackInterface workspace with the factorization object built
# through it, so both can be stored together in one LinearSolve cache slot.
struct WorkspaceAndFactors{W, F}
    workspace::W # FastLapackInterface workspace (e.g. LUWs, QRWYWs, QRpWs)
    factors::F   # factorization computed (or prototyped) from that workspace
end

# Build the cached workspace/factorization pair for `FastLUFactorization`.
# The LU workspace is sized from `A`, and a factorization prototype is created
# so the cache value has a concrete type before the first solve.
function LinearSolve.init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    fact_prototype = LinearSolve.ArrayInterface.lu_instance(convert(AbstractMatrix, A))
    return WorkspaceAndFactors(LUWs(A), fact_prototype)
end

# Solve the linear system through a preallocated FastLapackInterface LU
# workspace. The factorization is recomputed only when the cache is marked
# fresh; otherwise the stored factors are reused for a non-allocating solve.
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::FastLUFactorization; kwargs...)
    Amat = convert(AbstractMatrix, cache.A)
    wf = LinearSolve.@get_cacheval(cache, :FastLUFactorization)
    if cache.isfresh
        # Refactorize through the stored workspace. This fails if `A` changed
        # *size* relative to the workspace; resizing could be supported instead.
        LinearSolve.@set! wf.factors = LinearAlgebra.LU(LAPACK.getrf!(wf.workspace, Amat)...)
        cache.cacheval = wf
        cache.isfresh = false
    end
    sol = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, sol, nothing, cache)
end

# Cache initialization for unpivoted QR: a blocked compact-WY workspace sized
# from `A`, plus a QR prototype so the cache value has a concrete type.
function LinearSolve.init_cacheval(alg::FastQRFactorization{NoPivot}, A::AbstractMatrix, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    qr_prototype = LinearSolve.ArrayInterface.qr_instance(convert(AbstractMatrix, A))
    return WorkspaceAndFactors(QRWYWs(A; blocksize = alg.blocksize), qr_prototype)
end
# Cache initialization for column-pivoted QR: a pivoted-QR workspace sized
# from `A`, plus a QR prototype so the cache value has a concrete type.
function LinearSolve.init_cacheval(::FastQRFactorization{ColumnNorm}, A::AbstractMatrix, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    qr_prototype = LinearSolve.ArrayInterface.qr_instance(convert(AbstractMatrix, A))
    return WorkspaceAndFactors(QRpWs(A), qr_prototype)
end

# Fallback for operator-like `A`: materialize it to an `AbstractMatrix` and
# dispatch to the matrix-specialized methods above.
function LinearSolve.init_cacheval(alg::FastQRFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    # The recursive call must be qualified: this extension module only defines
    # methods of `LinearSolve.init_cacheval` and LinearSolve does not export
    # that name, so an unqualified `init_cacheval` is an `UndefVarError` here.
    return LinearSolve.init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
        maxiters, abstol, reltol, verbose, assumptions)
end

# Solve the linear system through a preallocated FastLapackInterface QR
# workspace. The pivoting strategy is picked at compile time from the type
# parameter `P`; refactorization happens only when the cache is fresh.
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::FastQRFactorization{P};
        kwargs...) where {P}
    Amat = convert(AbstractMatrix, cache.A)
    wf = LinearSolve.@get_cacheval(cache, :FastQRFactorization)
    if cache.isfresh
        # Re-run the factorization through the stored workspace. As with LU,
        # this fails if `A` changed *size* since the workspace was created;
        # resizing the workspace could be supported instead.
        if P === NoPivot
            LinearSolve.@set! wf.factors = LinearAlgebra.QRCompactWY(LAPACK.geqrt!(wf.workspace, Amat)...)
        else
            LinearSolve.@set! wf.factors = LinearAlgebra.QRPivoted(LAPACK.geqp3!(wf.workspace, Amat)...)
        end
        cache.cacheval = wf
        cache.isfresh = false
    end
    sol = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, sol, nothing, cache)
end


end
1 change: 0 additions & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using SciMLOperators
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using Setfield
using UnPack
using FastLapackInterface
using DocStringExtensions
using EnumX
using Markdown
Expand Down
25 changes: 25 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,31 @@ function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror =
RFLUFactorization(pivot, thread; throwerror)
end

# There's no options like pivot here.
# But I'm not sure it makes sense as a GenericFactorization
# since it just uses `LAPACK.getrf!`.
"""
`FastLUFactorization()`

The FastLapackInterface.jl version of the LU factorization. Notably,
this version does not allow for choice of pivoting method.

!!! note

    Using this solver requires adding the package FastLapackInterface.jl,
    i.e. `using FastLapackInterface`.
"""
struct FastLUFactorization <: AbstractDenseFactorization end

"""
`FastQRFactorization()`
The FastLapackInterface.jl version of the QR factorization.
"""
struct FastQRFactorization{P} <: AbstractDenseFactorization
pivot::P
blocksize::Int
end

# is 36 or 16 better here? LinearAlgebra and FastLapackInterface use 36,
# but QRFactorization uses 16.
FastQRFactorization() = FastQRFactorization(NoPivot(), 36)

"""
```julia
MKLPardisoFactorize(; nprocs::Union{Int, Nothing} = nothing,
Expand Down
102 changes: 0 additions & 102 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1026,108 +1026,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::DiagonalFactorization;
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

## FastLAPACKFactorizations

# Pairs a FastLapackInterface workspace with the factorization object built
# through it, so both can be stored together in one cache slot.
struct WorkspaceAndFactors{W, F}
    workspace::W # FastLapackInterface workspace (e.g. LUWs, QRWYWs, QRpWs)
    factors::F   # factorization computed (or prototyped) from that workspace
end

# There's no options like pivot here.
# But I'm not sure it makes sense as a GenericFactorization
# since it just uses `LAPACK.getrf!`.
"""
`FastLUFactorization()`

The FastLapackInterface.jl version of the LU factorization. Notably,
this version does not allow for choice of pivoting method.
"""
struct FastLUFactorization <: AbstractDenseFactorization end

# Cache setup for FastLUFactorization: an LU workspace sized from `A` plus a
# factorization prototype so the cache value has a concrete type up front.
function init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    ws = LUWs(A)
    return WorkspaceAndFactors(ws, ArrayInterface.lu_instance(convert(AbstractMatrix, A)))
end

# Solve with a preallocated LU workspace: refactorize only when the cache is
# fresh, otherwise reuse the stored factors for a non-allocating solve.
function SciMLBase.solve!(cache::LinearCache, alg::FastLUFactorization; kwargs...)
    A = cache.A
    # `A` may be an operator wrapper; LAPACK needs a plain matrix.
    A = convert(AbstractMatrix, A)
    ws_and_fact = @get_cacheval(cache, :FastLUFactorization)
    if cache.isfresh
        # we will fail here if A is a different *size* than in a previous version of the same cache.
        # it may instead be desirable to resize the workspace.
        @set! ws_and_fact.factors = LinearAlgebra.LU(LAPACK.getrf!(ws_and_fact.workspace,
            A)...)
        cache.cacheval = ws_and_fact
        cache.isfresh = false
    end
    y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

"""
`FastQRFactorization()`
The FastLapackInterface.jl version of the QR factorization.
"""
struct FastQRFactorization{P} <: AbstractDenseFactorization
pivot::P
blocksize::Int
end

# is 36 or 16 better here? LinearAlgebra and FastLapackInterface use 36,
# but QRFactorization uses 16.
FastQRFactorization() = FastQRFactorization(NoPivot(), 36)

# Cache setup for unpivoted QR: blocked compact-WY workspace sized from `A`.
function init_cacheval(alg::FastQRFactorization{NoPivot}, A::AbstractMatrix, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    ws = QRWYWs(A; blocksize = alg.blocksize)
    return WorkspaceAndFactors(ws,
        ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
end
# Cache setup for column-pivoted QR: pivoted-QR workspace sized from `A`.
function init_cacheval(::FastQRFactorization{ColumnNorm}, A::AbstractMatrix, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    ws = QRpWs(A)
    return WorkspaceAndFactors(ws,
        ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
end

# Fallback for operator-like `A`: materialize it to an `AbstractMatrix` and
# dispatch to the specialized methods above. (The unqualified recursive call
# is fine here: this method lives inside the LinearSolve module itself.)
function init_cacheval(alg::FastQRFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    return init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
end

# Solve with a preallocated QR workspace; the pivoting strategy is selected
# at compile time from the type parameter `P`. Refactorizes only when fresh.
function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P};
        kwargs...) where {P}
    A = cache.A
    # `A` may be an operator wrapper; LAPACK needs a plain matrix.
    A = convert(AbstractMatrix, A)
    ws_and_fact = @get_cacheval(cache, :FastQRFactorization)
    if cache.isfresh
        # we will fail here if A is a different *size* than in a previous version of the same cache.
        # it may instead be desirable to resize the workspace.
        if P === NoPivot
            @set! ws_and_fact.factors = LinearAlgebra.QRCompactWY(LAPACK.geqrt!(
                ws_and_fact.workspace,
                A)...)
        else
            @set! ws_and_fact.factors = LinearAlgebra.QRPivoted(LAPACK.geqp3!(
                ws_and_fact.workspace,
                A)...)
        end
        cache.cacheval = ws_and_fact
        cache.isfresh = false
    end
    y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

## SparspakFactorization is here since it's MIT licensed, not GPL

"""
Expand Down
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
using SciMLOperators, RecursiveFactorization, Sparspak
using SciMLOperators, RecursiveFactorization, Sparspak, FastLapackInterface
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
using Test
import Random
Expand Down

0 comments on commit 3beafc8

Please sign in to comment.