Commit
improve defaults and add simplelu
ChrisRackauckas committed Jan 22, 2022
1 parent 3d12bd8 commit f336a16
Showing 4 changed files with 171 additions and 18 deletions.
5 changes: 4 additions & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

-Copyright (c) 2021 Jonathan <[email protected]> and contributors
+Copyright (c) 2021 SciML and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@@ -19,3 +19,6 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

+SimpleLU.jl is derived from https://github.com/JuliaGNI/SimpleSolvers.jl under
+an MIT license.
4 changes: 3 additions & 1 deletion src/LinearSolve.jl
@@ -36,6 +36,7 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false

include("common.jl")
include("factorization.jl")
+include("simplelu.jl")
include("iterative_wrappers.jl")
include("preconditioners.jl")
include("default.jl")
@@ -45,7 +46,8 @@ const IS_OPENBLAS = Ref(true)
isopenblas() = IS_OPENBLAS[]

export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
-    GenericLUFactorization, RFLUFactorization, UMFPACKFactorization, KLUFactorization
+    GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
+    UMFPACKFactorization, KLUFactorization
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES
48 changes: 32 additions & 16 deletions src/default.jl
@@ -11,10 +11,14 @@ function defaultalg(A,b)
# whether MKL or OpenBLAS is being used
if (A === nothing && !isgpu(b)) || A isa Matrix
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
-            ArrayInterface.can_setindex(b) && (length(b) <= 100 ||
-                (isopenblas() && length(b) <= 500)
-            )
-            alg = RFLUFactorization()
+            ArrayInterface.can_setindex(b)
+            if length(b) <= 10
+                alg = GenericLUFactorization()
+            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
+                alg = RFLUFactorization()
+            else
+                alg = LUFactorization()
+            end
else
alg = LUFactorization()
end
@@ -58,12 +62,18 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
# it makes sense according to the benchmarks, which is dependent on
# whether MKL or OpenBLAS is being used
if A isa Matrix
-        if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
-            ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
-                (isopenblas() && size(A,1) <= 500)
-            )
-            alg = RFLUFactorization()
-            SciMLBase.solve(cache, alg, args...; kwargs...)
+        if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
+            ArrayInterface.can_setindex(b)
+            if length(b) <= 10
+                alg = GenericLUFactorization()
+                SciMLBase.solve(cache, alg, args...; kwargs...)
+            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
+                alg = RFLUFactorization()
+                SciMLBase.solve(cache, alg, args...; kwargs...)
+            else
+                alg = LUFactorization()
+                SciMLBase.solve(cache, alg, args...; kwargs...)
+            end
else
alg = LUFactorization()
SciMLBase.solve(cache, alg, args...; kwargs...)
@@ -110,12 +120,18 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
# it makes sense according to the benchmarks, which is dependent on
# whether MKL or OpenBLAS is being used
if A isa Matrix
-        if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
-            ArrayInterface.can_setindex(b) && (size(A,1) <= 100 ||
-                (isopenblas() && size(A,1) <= 500)
-            )
-            alg = RFLUFactorization()
-            init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
+        if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
+            ArrayInterface.can_setindex(b)
+            if length(b) <= 10
+                alg = GenericLUFactorization()
+                init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
+            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
+                alg = RFLUFactorization()
+                init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
+            else
+                alg = LUFactorization()
+                init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
+            end
else
alg = LUFactorization()
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
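The three hunks above encode the same size-tiered default for dense matrices: GenericLUFactorization() for very small systems (length(b) <= 10), RFLUFactorization() up to 100 unknowns (or 500 when OpenBLAS is detected), and LUFactorization() beyond that. A rough usage sketch of how the default resolves, illustrative only and assuming the exported LinearProblem/solve interface with well-conditioned random test systems:

using LinearSolve, LinearAlgebra

# Size tiers mirrored from the branches above:
#   n <= 10                          -> GenericLUFactorization()
#   n <= 100 (or <= 500 on OpenBLAS) -> RFLUFactorization()
#   otherwise                        -> LUFactorization()
for n in (8, 80, 800)
    A = rand(n, n) + n * I             # keep the system well conditioned
    b = rand(n)
    sol = solve(LinearProblem(A, b))   # alg === nothing triggers defaultalg
    @assert norm(A * sol.u - b) < 1e-8 * n
end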
132 changes: 132 additions & 0 deletions src/simplelu.jl
@@ -0,0 +1,132 @@
## From https://github.com/JuliaGNI/SimpleSolvers.jl/blob/master/src/linear/lu_solver.jl

mutable struct LUSolver{T}
n::Int
A::Matrix{T}
b::Vector{T}
x::Vector{T}
pivots::Vector{Int}
perms::Vector{Int}
info::Int

LUSolver{T}(n) where {T} = new(n, zeros(T, n, n), zeros(T, n), zeros(T, n), zeros(Int, n), zeros(Int, n), 0)
end

function LUSolver(A::Matrix{T}) where {T}
n = LinearAlgebra.checksquare(A)
lu = LUSolver{eltype(A)}(n)
lu.A .= A
lu
end

function LUSolver(A::Matrix{T}, b::Vector{T}) where {T}
n = LinearAlgebra.checksquare(A)
@assert n == length(b)
lu = LUSolver{eltype(A)}(n)
lu.A .= A
lu.b .= b
lu
end
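A minimal construction sketch for the workspace above; the outer constructors copy A (and optionally b) into the preallocated buffers, while x, pivots, and perms start zeroed:

A = [4.0 3.0; 6.0 3.0]
b = [10.0, 12.0]

lu1 = LUSolver(A)       # n = 2, A copied, b and x left at zero
lu2 = LUSolver(A, b)    # also copies the right-hand side
@assert lu2.n == 2 && lu2.x == zeros(2) && lu2.info == 0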

function simplelu_factorize!(lu::LUSolver{T}, pivot=true) where {T}
A = lu.A

begin
@inbounds for i in eachindex(lu.perms)
lu.perms[i] = i
end

@inbounds for k = 1:lu.n
# find index max
kp = k
if pivot
amax = real(zero(T))
for i = k:lu.n
absi = abs(A[i,k])
if absi > amax
kp = i
amax = absi
end
end
end
lu.pivots[k] = kp
lu.perms[k], lu.perms[kp] = lu.perms[kp], lu.perms[k]

if A[kp,k] != 0
if k != kp
# Interchange
for i = 1:lu.n
tmp = A[k,i]
A[k,i] = A[kp,i]
A[kp,i] = tmp
end
end
# Scale first column
Akkinv = inv(A[k,k])
for i = k+1:lu.n
A[i,k] *= Akkinv
end
elseif lu.info == 0
lu.info = k
end
# Update the rest
for j = k+1:lu.n
for i = k+1:lu.n
A[i,j] -= A[i,k]*A[k,j]
end
end
end

lu.info
end
end
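The loop above is an in-place LU factorization with optional partial pivoting: multipliers are stored below the diagonal of lu.A, the upper triangle holds U, and the accumulated row order is kept in lu.perms. A small consistency check, a sketch that assumes the definitions above are in scope:

using LinearAlgebra

A = [2.0 1.0 1.0;
     4.0 3.0 3.0;
     8.0 7.0 9.0]
slv = LUSolver(A)                  # factorization works on an internal copy
simplelu_factorize!(slv, true)

L = UnitLowerTriangular(slv.A)
U = UpperTriangular(slv.A)
@assert slv.info == 0
@assert L * U ≈ A[slv.perms, :]    # row-permuted A equals L*U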

function simplelu_solve!(lu::LUSolver{T}) where {T}
local s::T

@inbounds for i = 1:lu.n
lu.x[i] = lu.b[lu.perms[i]]
end

@inbounds for i = 2:lu.n
s = 0
for j = 1:i-1
s += lu.A[i,j] * lu.x[j]
end
lu.x[i] -= s
end

lu.x[lu.n] /= lu.A[lu.n,lu.n]
@inbounds for i = lu.n-1:-1:1
s = 0
for j = i+1:lu.n
s += lu.A[i,j] * lu.x[j]
end
lu.x[i] -= s
lu.x[i] /= lu.A[i,i]
end

lu.b .= lu.x

lu.x
end
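Factorization plus substitution solves Ax = b: the right-hand side is permuted, forward substitution applies the unit lower factor, back substitution applies U, and the result is copied back into lu.b. A quick round trip, with the exact answer [2.0, 3.0]:

using LinearAlgebra

A = [3.0 1.0; 1.0 2.0]
b = [9.0, 8.0]

slv = LUSolver(A, b)
simplelu_factorize!(slv)      # pivot = true by default
x = simplelu_solve!(slv)

@assert x ≈ A \ b             # [2.0, 3.0]
@assert slv.b ≈ x             # solution is also written back into b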

### Wrapper

struct SimpleLUFactorization <: AbstractFactorization
pivot::Bool
SimpleLUFactorization(pivot=true) = new(pivot)
end

function SciMLBase.solve(cache::LinearCache, alg::SimpleLUFactorization; kwargs...)
if cache.isfresh
cache.cacheval.A = cache.A
simplelu_factorize!(cache.cacheval, alg.pivot)
end
cache.cacheval.b = cache.b
cache.cacheval.x = cache.u
y = simplelu_solve!(cache.cacheval)
SciMLBase.build_linear_solution(alg,y,nothing,cache)
end

init_cacheval(alg::SimpleLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = LUSolver(convert(AbstractMatrix,A))
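With the wrapper above, the method plugs into the common LinearSolve.jl interface. A minimal usage sketch; the random, diagonally shifted test system is illustrative only:

using LinearSolve, LinearAlgebra

A = rand(50, 50) + 50I
b = rand(50)
prob = LinearProblem(A, b)

sol = solve(prob, SimpleLUFactorization())        # pivoting enabled by default
@assert norm(A * sol.u - b) < 1e-10

sol2 = solve(prob, SimpleLUFactorization(false))  # pivoting disabled
@assert norm(A * sol2.u - b) < 1e-8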
