-
-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
199 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,5 @@ | |
Manifest.toml | ||
|
||
*.swp | ||
.vscode | ||
wip |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
module LinearSolveBlockDiagonalsExt | ||
|
||
using LinearSolve, BlockDiagonals | ||
|
||
function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, u, Pl, Pr, | ||
maxiters::Int, abstol, reltol, verbose, assumptions; zeroinit = true) | ||
@assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2." | ||
# We need to perform this check even when `zeroinit == true`, since the type of the | ||
# cache is dependent on whether we are able to use the specialized dispatch. | ||
bsizes = blocksizes(A) | ||
usize = first(first(bsizes)) | ||
uniform_blocks = true | ||
for bsize in bsizes | ||
if bsize[1] != usize || bsize[2] != usize | ||
uniform_blocks = false | ||
break | ||
end | ||
end | ||
# Can't help but perform dynamic dispatch here | ||
return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, u, Pl, Pr, maxiters, | ||
abstol, reltol, verbose, assumptions; zeroinit) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
module LinearSolveNNlibExt | ||
|
||
using LinearSolve, NNlib | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
""" | ||
SimpleGMRES(; restart::Int = 20, blocksize::Int = 0) | ||
A simple GMRES implementation for square non-Hermitian linear systems. | ||
This implementation handles Block Diagonal Matrices with Uniformly Sized Square Blocks with | ||
specialized dispatches. | ||
## Arguments | ||
* `restart::Int = 20`: the number of iterations before restarting. Must be a strictly | ||
positive integer. | ||
* `blocksize::Int = 0`: If blocksize is `> 0`, the solver assumes that the matrix has a | ||
uniformly sized block diagonal structure with square blocks of size `blocksize`. Misusing | ||
this option will lead to incorrect results. | ||
* If this is set `≤ 0` and during runtime we get a Block Diagonal Matrix, then we will | ||
check if the specialized dispatch can be used. | ||
!!! warning | ||
Most users should be using the `KrylovJL_GMRES` solver instead of this implementation. | ||
""" | ||
struct SimpleGMRES{UBD} <: AbstractKrylovSubspaceMethod | ||
restart::Int | ||
blocksize::Int | ||
|
||
function SimpleGMRES(; restart::Int = 20, blocksize::Int = 0) | ||
@assert restart≥1 "restart must be greater than or equal to 1" | ||
return new{blocksize > 0}(restart, blocksize) | ||
end | ||
end | ||
|
||
struct SimpleGMRESCache{UBD, T, QType, HType, xType, rType, βe₁Type, AType, bType, βType} | ||
M::Int | ||
N::Int | ||
maxiters::Int | ||
blocksize::Int | ||
ϵ::T | ||
Q::QType | ||
H::HType | ||
x::xType | ||
r::rType | ||
βe₁::βe₁Type | ||
A::AType | ||
b::bType | ||
β::βType | ||
abstol::T | ||
|
||
function SimpleGMRESCache{UBD}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b, β, | ||
abstol) where {UBD} | ||
return new{UBD, typeof(ϵ), typeof(Q), typeof(H), typeof(x), typeof(r), typeof(βe₁), | ||
typeof(A), typeof(b), typeof(β)}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, | ||
A, b, β, abstol) | ||
end | ||
end | ||
|
||
_no_preconditioner(::Nothing) = true | ||
_no_preconditioner(::IdentityOperator) = true | ||
_no_preconditioner(::UniformScaling) = true | ||
_no_preconditioner(_) = false | ||
|
||
function init_cacheval(alg::SimpleGMRES{false}, args...; kwargs...) | ||
return _init_cacheval(Val(false), alg, args...; kwargs...) | ||
end | ||
|
||
# TODO: We can check if `A` is a block diagonal matrix with uniformly sized square blocks | ||
# and use the specialized dispatch | ||
function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int, | ||
abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true) | ||
if zeroinit | ||
return SimpleGMRESCache{false}(0, 0, maxiters, alg.blocksize, zero(eltype(u)), | ||
similar(b, 0, 0), similar(b, 0, 0), u, similar(b, 0), similar(b, 0), | ||
A, b, zero(eltype(u)), abstol) | ||
end | ||
|
||
@assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead." | ||
N = LinearAlgebra.checksquare(A) | ||
T = eltype(u) | ||
M = min(maxiters, alg.restart) | ||
ϵ = eps(T) | ||
|
||
# Initialize the Cache | ||
## Use `b` since `A` might be an operator | ||
Q = similar(b, length(b), M + 1) | ||
H = similar(b, M + 1, M) | ||
fill!(H, zero(T)) | ||
|
||
mul!(@view(Q[:, 1]), A, u, T(-1), T(0)) # r0 <- A u | ||
axpy!(T(1), b, @view(Q[:, 1])) # r0 <- r0 - b | ||
β = norm(@view(Q[:, 1]), 2) | ||
Q[:, 1] ./= β | ||
|
||
x = u | ||
r = similar(b) | ||
βe₁ = similar(b, M + 1) | ||
fill!(βe₁, 0) | ||
βe₁[1:1] .= β # Avoid the scalar indexing error | ||
|
||
return SimpleGMRESCache{false}(M, N, maxiters, alg.blocksize, ϵ, Q, H, x, r, βe₁, A, b, | ||
β, abstol) | ||
end | ||
|
||
default_alias_A(::SimpleGMRES, ::Any, ::Any) = false | ||
default_alias_b(::SimpleGMRES, ::Any, ::Any) = false | ||
|
||
function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...) | ||
if cache.isfresh | ||
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, | ||
cache.maxiters, cache.abstol, cache.reltol, cache.verbose, | ||
cache.assumptions; zeroinit = false) | ||
cache.cacheval = solver | ||
cache.isfresh = false | ||
end | ||
return SciMLBase.solve!(cache.cacheval) | ||
end | ||
|
||
function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T} | ||
@unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol = cache | ||
norm2 = Base.Fix2(norm, 2) | ||
res_norm = β | ||
|
||
# FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES | ||
# version | ||
for _ in 1:(maxiters ÷ M + 1) | ||
for j in 1:M | ||
Qⱼ₊₁ = @view(Q[:, j + 1]) | ||
mul!(Qⱼ₊₁, A, @view(Q[:, j])) # Q(:,j+1) <- A Q(:, j) | ||
for i in 1:j | ||
H[i, j] = dot(@view(Q[:, i]), Qⱼ₊₁) | ||
axpy!(-H[i, j], @view(Q[:, i]), Qⱼ₊₁) | ||
end | ||
H[j + 1, j] = norm2(Qⱼ₊₁) | ||
H[j + 1, j] > ϵ && (Qⱼ₊₁ ./= H[j + 1, j]) | ||
|
||
# FIXME: Figure out a way to avoid the allocation | ||
# Using views doesn't work very well with LinearSolve | ||
y = @view(H[1:(j + 1), 1:j]) \ @view(βe₁[1:(j + 1)]) | ||
|
||
# Update the solution | ||
mul!(x, @view(Q[:, 1:j]), y) | ||
mul!(r, A, x, T(-1), T(0)) | ||
axpy!(T(1), b, r) | ||
res_norm = norm2(r) | ||
|
||
if res_norm < abstol | ||
return SciMLBase.build_linear_solution(nothing, x, r, nothing; | ||
retcode = ReturnCode.Success) | ||
end | ||
end | ||
|
||
# Restart | ||
Q[:, 1] = r ./ res_norm | ||
fill!(H, zero(T)) | ||
end | ||
|
||
return SciMLBase.build_linear_solution(nothing, x, r, nothing; | ||
retcode = ReturnCode.MaxIters) | ||
end |