Skip to content

Commit

Permalink
Merge pull request #115 from JuliaSymbolics/s/fix-sparse-set_array
Browse files Browse the repository at this point in the history
Fix set_array for sparse arrays
  • Loading branch information
ChrisRackauckas authored Mar 13, 2021
2 parents 37f2166 + ef52c0e commit b812742
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
19 changes: 4 additions & 15 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,6 @@ function _build_and_inject_function(mod::Module, ex)
RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex)
end

# Detect heterogeneous element types of "arrays of matrices/sparce matrices"
function is_array_matrix(F)
return isa(F, AbstractVector) && all(x->isa(x, AbstractArray), F)
end
function is_array_sparse_matrix(F)
return isa(F, AbstractVector) && all(x->isa(x, AbstractSparseMatrix), F)
end
# Detect heterogeneous element types of "arrays of arrays of matrices/sparce matrices"
function is_array_array_matrix(F)
return isa(F, AbstractVector) && all(x->isa(x, AbstractArray{<:AbstractMatrix}), F)
end
function is_array_array_sparse_matrix(F)
return isa(F, AbstractVector) && all(x->isa(x, AbstractArray{<:AbstractSparseMatrix}), F)
end

toexpr(n::Num, st) = toexpr(value(n), st)

function fill_array_with_zero!(x::AbstractArray)
Expand Down Expand Up @@ -313,6 +298,10 @@ function set_array(s::MultithreadedForm, closed_args, out, outputidxs, rhss, che
SpawnFetch{MultithreadedForm}(first.(arrays), last.(arrays), @inline noop(args...) = nothing)
end

function _set_array(out, outputidxs, rhss::AbstractSparseArray, checkbounds, skipzeros)
_set_array(LiteralExpr(:($out.nzval)), nothing, rhss.nzval, checkbounds, skipzeros)
end

function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)
if outputidxs === nothing
outputidxs = collect(eachindex(rhss))
Expand Down
12 changes: 12 additions & 0 deletions test/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,15 @@ f = eval(build_function(sparse([1],[1], [(x+y)/k], 10,10), [x,y,k])[1])
@test size(f([1.,1.,2])) == (10,10)
@test f([1.,1.,2])[1,1] == 1.0
@test sum(f([1.,1.,2])) == 1.0

let # ModelingToolkit.jl#800
@variables x
y = sparse(1:3,1:3,x)

f1,f2 = build_function(y,x)
sf1, sf2 = string(f1), string(f2)
@test !contains(sf1, "CartesianIndex")
@test !contains(sf2, "CartesianIndex")
@test contains(sf1, "SparseMatrixCSC(")
@test contains(sf2, ".nzval")
end

0 comments on commit b812742

Please sign in to comment.