Skip to content

Commit

Permalink
Merge pull request #424 from JuliaArrays/lu_instance
Browse files Browse the repository at this point in the history
Cover more lu instance types
  • Loading branch information
ChrisRackauckas authored Nov 22, 2023
2 parents 43dafbb + c42fd7e commit c40655d
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 165 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
julia-version: [1,1.6]
julia-version: [1]
os: [ubuntu-latest]
package:
- {user: JuliaDiff, repo: SparseDiffTools.jl, group: Core}
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ jobs:
- Core
version:
- '1'
- '1.6'
- '1.7'
- '1.8'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.5.1"
version = "7.6.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -11,11 +11,11 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Adapt = "3"
LinearAlgebra = "1.6"
LinearAlgebra = "1.9"
Requires = "1"
SparseArrays = "1.6"
SuiteSparse = "1.6"
julia = "1.6"
SparseArrays = "1.9"
SuiteSparse = "1.9"
julia = "1.9"

[extensions]
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
Expand Down
226 changes: 70 additions & 156 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,8 @@ julia> ArrayInterface.map_tuple_type(sqrt, Tuple{1,4,16})
```
"""
function map_tuple_type end
if VERSION >= v"1.8"
@inline function map_tuple_type(f, @nospecialize(T::Type))
ntuple(i -> f(fieldtype(T, i)), Val{fieldcount(T)}())
end
else
function map_tuple_type(f::F, ::Type{T}) where {F, T <: Tuple}
if @generated
t = Expr(:tuple)
for i in 1:fieldcount(T)
push!(t.args, :(f($(fieldtype(T, i)))))
end
Expr(:block, Expr(:meta, :inline), t)
else
Tuple(f(fieldtype(T, i)) for i in 1:fieldcount(T))
end
end
@inline function map_tuple_type(f, @nospecialize(T::Type))
ntuple(i -> f(fieldtype(T, i)), Val{fieldcount(T)}())
end

"""
Expand All @@ -78,50 +64,22 @@ julia> ArrayInterface.flatten_tuples((1, (2, (3,))))
```
"""
function flatten_tuples end
if VERSION >= v"1.8"
function flatten_tuples(t::Tuple)
fields = _new_field_positions(t)
ntuple(Val{nfields(fields)}()) do k
i, j = getfield(fields, k)
i = length(t) - i
@inbounds j === 0 ? getfield(t, i) : getfield(getfield(t, i), j)
end
end
_new_field_positions(::Tuple{}) = ()
@nospecialize
function _new_field_positions(x::Tuple)
(_fl1(x, x[1])..., _new_field_positions(Base.tail(x))...)
end
_fl1(x::Tuple, x1::Tuple) = ntuple(Base.Fix1(tuple, length(x) - 1), Val(length(x1)))
_fl1(x::Tuple, x1) = ((length(x) - 1, 0),)
@specialize
else
@inline function flatten_tuples(t::Tuple)
if @generated
texpr = Expr(:tuple)
for i in 1:fieldcount(t)
p = fieldtype(t, i)
if p <: Tuple
for j in 1:fieldcount(p)
push!(texpr.args, :(@inbounds(getfield(getfield(t, $i), $j))))
end
else
push!(texpr.args, :(@inbounds(getfield(t, $i))))
end
end
Expr(:block, Expr(:meta, :inline), texpr)
else
_flatten(t)
end
end
_flatten(::Tuple{}) = ()
@inline function _flatten(t::Tuple{Any, Vararg{Any}})
(getfield(t, 1), _flatten(Base.tail(t))...)
end
@inline function _flatten(t::Tuple{Tuple, Vararg{Any}})
(getfield(t, 1)..., _flatten(Base.tail(t))...)
function flatten_tuples(t::Tuple)
fields = _new_field_positions(t)
ntuple(Val{nfields(fields)}()) do k
i, j = getfield(fields, k)
i = length(t) - i
@inbounds j === 0 ? getfield(t, i) : getfield(getfield(t, i), j)
end
end
_new_field_positions(::Tuple{}) = ()
@nospecialize
function _new_field_positions(x::Tuple)
(_fl1(x, x[1])..., _new_field_positions(Base.tail(x))...)
end
_fl1(x::Tuple, x1::Tuple) = ntuple(Base.Fix1(tuple, length(x) - 1), Val(length(x1)))
_fl1(x::Tuple, x1) = ((length(x) - 1, 0),)
@specialize

"""
parent_type(::Type{T}) -> Type
Expand Down Expand Up @@ -299,11 +257,7 @@ ismutable(::Type{BigFloat}) = false
ismutable(::Type{BigInt}) = false
function ismutable(::Type{T}) where {T}
if parent_type(T) <: T
@static if VERSION v"1.7.0-DEV.1208"
return Base.ismutabletype(T)
else
return T.mutable
end
return Base.ismutabletype(T)
else
return ismutable(parent_type(T))
end
Expand Down Expand Up @@ -440,32 +394,17 @@ matrix_colors(A::Bidiagonal) = _cycle(1:2, Base.size(A, 2))
matrix_colors(A::Union{Tridiagonal, SymTridiagonal}) = _cycle(1:3, Base.size(A, 2))
_cycle(repetend, len) = repeat(repetend, div(len, length(repetend)) + 1)[1:len]

@static if VERSION > v"1.9-"
"""
bunchkaufman_instance(A, pivot = LinearAlgebra.RowMaximum()) -> bunchkaufman_factorization_instance
"""
bunchkaufman_instance(A, pivot = LinearAlgebra.RowMaximum()) -> bunchkaufman_factorization_instance
Returns an instance of the Bunch-Kaufman factorization object with the correct type
cheaply.
"""
function bunchkaufman_instance(A::Matrix{T}) where T
return bunchkaufman(similar(A, 0, 0), check = false)
end
function bunchkaufman_instance(A::SparseMatrixCSC)
bunchkaufman(sparse(similar(A, 1, 1)), check = false)
end
else
"""
bunchkaufman_instance(A, pivot = LinearAlgebra.RowMaximum()) -> bunchkaufman_factorization_instance
Returns an instance of the Bunch-Kaufman factorization object with the correct type
cheaply.
"""
function bunchkaufman_instance(A::Matrix{T}) where T
return bunchkaufman(similar(A, 0, 0))
end
function bunchkaufman_instance(A::SparseMatrixCSC)
bunchkaufman(sparse(similar(A, 1, 1)))
end
Returns an instance of the Bunch-Kaufman factorization object with the correct type
cheaply.
"""
function bunchkaufman_instance(A::Matrix{T}) where T
return bunchkaufman(similar(A, 0, 0), check = false)
end
function bunchkaufman_instance(A::SparseMatrixCSC)
bunchkaufman(sparse(similar(A, 1, 1)), check = false)
end

"""
Expand All @@ -482,32 +421,16 @@ Returns the number.
"""
bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false)

@static if VERSION < v"1.8beta"
const DEFAULT_CHOLESKY_PIVOT = Val(false)
else
const DEFAULT_CHOLESKY_PIVOT = LinearAlgebra.NoPivot()
end
const DEFAULT_CHOLESKY_PIVOT = Val(false)

@static if VERSION > v"1.9-"
"""
cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorization_instance
"""
cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorization_instance
Returns an instance of the Cholesky factorization object with the correct type
cheaply.
"""
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
return cholesky(similar(A, 0, 0), pivot, check = false)
end
else
"""
cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorization_instance
Returns an instance of the Cholesky factorization object with the correct type
cheaply.
"""
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
return cholesky(similar(A, 0, 0), pivot)
end
Returns an instance of the Cholesky factorization object with the correct type
cheaply.
"""
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
return cholesky(similar(A, 0, 0), pivot, check = false)
end

function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT)
Expand All @@ -521,23 +444,13 @@ Returns the number.
"""
cholesky_instance(a::Number, pivot = DEFAULT_CHOLESKY_PIVOT) = a

@static if VERSION > v"1.9-"
"""
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false)
Slow fallback which gets the instance via factorization. Should get
specialized for new matrix types.
"""
cholesky_instance(a::Any, pivot = DEFAULT_CHOLESKY_PIVOT) = cholesky(a, pivot, check = false)
else
"""
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false)
Slow fallback which gets the instance via factorization. Should get
specialized for new matrix types.
"""
cholesky_instance(a::Any, pivot = DEFAULT_CHOLESKY_PIVOT) = cholesky(a, pivot)
end
"""
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false)
Slow fallback which gets the instance via factorization. Should get
specialized for new matrix types.
"""
cholesky_instance(a::Any, pivot = DEFAULT_CHOLESKY_PIVOT) = cholesky(a, pivot, check = false)

"""
ldlt_instance(A) -> ldlt_factorization_instance
Expand Down Expand Up @@ -586,18 +499,29 @@ function lu_instance(A::Matrix{T}) where {T}
return LU{luT}(similar(A, 0, 0), ipiv, info)
end
function lu_instance(jac_prototype::SparseMatrixCSC)
@static if VERSION < v"1.9.0-DEV.1622"
SuiteSparse.UMFPACK.UmfpackLU(Ptr{Cvoid}(),
Ptr{Cvoid}(),
1,
1,
jac_prototype.colptr[1:1],
jac_prototype.rowval[1:1],
jac_prototype.nzval[1:1],
0)
else
SuiteSparse.UMFPACK.UmfpackLU(similar(jac_prototype, 1, 1))
end
SuiteSparse.UMFPACK.UmfpackLU(similar(jac_prototype, 1, 1))
end

function lu_instance(A::Symmetric{T}) where {T}
noUnitT = typeof(zero(T))
luT = LinearAlgebra.lutype(noUnitT)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
info = zero(LinearAlgebra.BlasInt)
return LU{luT}(similar(A, 0, 0), ipiv, info)
end

noalloc_diag(A::Diagonal) = A.diag
noalloc_diag(A::Tridiagonal) = A.d
noalloc_diag(A::SymTridiagonal) = A.dv

function lu_instance(A::Union{Tridiagonal{T},Diagonal{T},SymTridiagonal{T}}) where {T}
noUnitT = typeof(zero(T))
luT = LinearAlgebra.lutype(noUnitT)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
info = zero(LinearAlgebra.BlasInt)
vectype = similar(noalloc_diag(A), 0)
newA = Tridiagonal(vectype, vectype, vectype)
return LU{luT}(newA, ipiv, info)
end

"""
Expand All @@ -607,23 +531,13 @@ Returns the number.
"""
lu_instance(a::Number) = a

@static if VERSION > v"1.9-"
"""
lu_instance(a::Any) -> lu(a, check=false)
Slow fallback which gets the instance via factorization. Should get
specialized for new matrix types.
"""
lu_instance(a::Any) = lu(a, check = false)
else
"""
"""
lu_instance(a::Any) -> lu(a, check=false)
Slow fallback which gets the instance via factorization. Should get
specialized for new matrix types.
"""
lu_instance(a::Any) = lu(a)
end
Slow fallback which gets the instance via factorization. Should get
specialized for new matrix types.
"""
lu_instance(a::Any) = lu(a, check = false)

"""
qr_instance(A, pivot = NoPivot()) -> qr_factorization_instance
Expand Down
5 changes: 5 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ end
A = sprand(50, 50, 0.5)
@test lu_instance(A) isa typeof(lu(A))
@test lu_instance(1) === 1

@test lu_instance(Symmetric(rand(3,3))) isa typeof(lu(Symmetric(rand(4,4))))
@test lu_instance(Tridiagonal(rand(3),rand(4),rand(3))) isa typeof(lu(Tridiagonal(rand(3),rand(4),rand(3))))
@test lu_instance(SymTridiagonal(rand(4),rand(3))) isa typeof(lu(SymTridiagonal(rand(4),rand(3))))
@test lu_instance(Diagonal(rand(4))) isa typeof(lu(Diagonal(rand(4))))
end

@testset "ismutable" begin
Expand Down

0 comments on commit c40655d

Please sign in to comment.