Skip to content

Commit

Permalink
make setindex! not remove zeros from sparsity pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC committed Mar 20, 2016
1 parent 05bdbb0 commit adc7508
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 113 deletions.
142 changes: 35 additions & 107 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2240,19 +2240,7 @@ function setindex!{T,Ti}(A::SparseMatrixCSC{T,Ti}, v, i0::Integer, i1::Integer)
v = convert(T, v)
r1 = Int(A.colptr[i1])
r2 = Int(A.colptr[i1+1]-1)
if v == 0 #either do nothing or delete entry if it exists
if r1 <= r2
r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward)
if (r1 <= r2) && (A.rowval[r1] == i0)
deleteat!(A.rowval, r1)
deleteat!(A.nzval, r1)
@simd for j = (i1+1):(A.n+1)
@inbounds A.colptr[j] -= 1
end
end
end
return A
end

i = (r1 > r2) ? r1 : searchsortedfirst(A.rowval, i0, r1, r2, Forward)

if (i <= r2) && (A.rowval[i] == i0)
Expand All @@ -2279,8 +2267,7 @@ setindex!(A::SparseMatrixCSC, x, ::Colon, ::Colon) = setindex!(A, x, 1:size(A, 1
setindex!(A::SparseMatrixCSC, x, ::Colon, j::Union{Integer, AbstractVector}) = setindex!(A, x, 1:size(A, 1), j)
setindex!(A::SparseMatrixCSC, x, i::Union{Integer, AbstractVector}, ::Colon) = setindex!(A, x, i, 1:size(A, 2))

setindex!{Tv,T<:Integer}(A::SparseMatrixCSC{Tv}, x::Number, I::AbstractVector{T}, J::AbstractVector{T}) =
(0 == x) ? spdelete!(A, I, J) : spset!(A, convert(Tv,x), I, J)
setindex!{Tv,T<:Integer}(A::SparseMatrixCSC{Tv}, x::Number, I::AbstractVector{T}, J::AbstractVector{T}) = spset!(A, convert(Tv,x), I, J)

function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector{Ti}, J::AbstractVector{Ti})
!issorted(I) && (@inbounds I = I[sortperm(I)])
Expand Down Expand Up @@ -2390,63 +2377,6 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
return A
end

function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}, J::AbstractVector{Ti})
m, n = size(A)
nnzA = nnz(A)
(nnzA == 0) && (return A)

!issorted(I) && (@inbounds I = I[sortperm(I)])
!issorted(J) && (@inbounds J = J[sortperm(J)])

((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))

colptr = colptrA = A.colptr
rowval = rowvalA = A.rowval
nzval = nzvalA = A.nzval
rowidx = 1
ndel = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(ndel > 0) && (colptrA[col] = colptr[col] - ndel)
if isempty(rrange) || !(col in J)
nincl = length(rrange)
if(ndel > 0) && !isempty(rrange)
copy!(rowvalA, rowidx, rowval, rrange[1], nincl)
copy!(nzvalA, rowidx, nzval, rrange[1], nincl)
end
rowidx += nincl
else
for ridx in rrange
if rowval[ridx] in I
if ndel == 0
colptrA = copy(colptr)
rowvalA = copy(rowval)
nzvalA = copy(nzval)
end
ndel += 1
else
if ndel > 0
rowvalA[rowidx] = rowval[ridx]
nzvalA[rowidx] = nzval[ridx]
end
rowidx += 1
end
end
end
end

if ndel > 0
colptrA[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.colptr = colptrA
A.rowval = rowvalA
A.nzval = nzvalA
end
return A
end

setindex!{Tv,Ti,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, S::Matrix, I::AbstractVector{T}, J::AbstractVector{T}) =
setindex!(A, convert(SparseMatrixCSC{Tv,Ti}, S), I, J)

Expand Down Expand Up @@ -2596,7 +2526,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA
nadd = ndel = 0
nadd = 0
bidx = xidx = 1
r1 = r2 = 0

Expand All @@ -2612,7 +2542,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
if r1 <= r2
copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1
if (copylen > 0)
if (nadd > 0) || (ndel > 0)
if (nadd > 0)
copy!(rowvalB, bidx, rowvalA, r1, copylen)
copy!(nzvalB, bidx, nzvalA, r1, copylen)
end
Expand All @@ -2621,25 +2551,25 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
end
end

# 0: no change, 1: update, 2: delete, 3: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3)
# 0: update, 1: add new
if r1 <= r2 && rowvalA[r1] == row
mode = 0
else
mode = 1
end

if (mode > 1) && (nadd == 0) && (ndel == 0)
if (mode == 1) && (nadd == 0)
# copy storage to take changes
colptrB = copy(colptrA)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalB = Array(Ti, length(rowvalA)+n); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+n); copy!(nzvalB, 1, nzvalA, 1, r1-1)
end
if mode == 1
if mode == 0
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
r1 += 1
elseif mode == 2
r1 += 1
ndel += 1
elseif mode == 3
elseif mode == 1
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
Expand All @@ -2649,7 +2579,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
end # if I[row, col]
end # for row in 1:A.m

if ((nadd != 0) || (ndel != 0))
if (nadd != 0)
l = r2-r1+1
if l > 0
copy!(rowvalB, bidx, rowvalA, r1, l)
Expand All @@ -2659,8 +2589,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
colptrB[col+1] = bidx

if (xidx > n) && (length(colptrB) > (col+1))
diff = nadd - ndel
colptrB[(col+2):end] = colptrA[(col+2):end] .+ diff
colptrB[(col+2):end] = colptrA[(col+2):end] .+ nadd
r1 = colptrA[col+1]
r2 = colptrA[end]-1
l = r2-r1+1
Expand All @@ -2676,7 +2605,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
(xidx > n) && break
end # for col in 1:A.n

if (nadd != 0) || (ndel != 0)
if (nadd != 0)
n = length(nzvalB)
if n > (bidx-1)
deleteat!(nzvalB, bidx:n)
Expand All @@ -2694,7 +2623,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval; szA = size(A)
colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA
nadd = ndel = 0
nadd = 0
bidx = aidx = 1

S = issorted(I) ? (1:n) : sortperm(I)
Expand All @@ -2715,8 +2644,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
r2 = Int(colptrA[col+1] - 1)

# copy from last position till current column
if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ (nadd - ndel)
if (nadd > 0)
colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ nadd
copylen = r1 - aidx
if copylen > 0
copy!(rowvalB, bidx, rowvalA, aidx, copylen)
Expand All @@ -2733,7 +2662,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
if r1 <= r2
copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1
if (copylen > 0)
if (nadd > 0) || (ndel > 0)
if (nadd > 0)
copy!(rowvalB, bidx, rowvalA, r1, copylen)
copy!(nzvalB, bidx, nzvalA, r1, copylen)
end
Expand All @@ -2743,27 +2672,26 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end
end

# 0: no change, 1: update, 2: delete, 3: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3)
# 0: update, 1: add new
if r1 <= r2 && rowvalA[r1] == row
mode = 0
else
mode = 1
end

if (mode > 1) && (nadd == 0) && (ndel == 0)
if (mode == 1) && (nadd == 0)
# copy storage to take changes
colptrB = copy(colptrA)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalB = Array(Ti, length(rowvalA)+n); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+n); copy!(nzvalB, 1, nzvalA, 1, r1-1)
end
if mode == 1
if mode == 0
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
aidx += 1
r1 += 1
elseif mode == 2
r1 += 1
aidx += 1
ndel += 1
elseif mode == 3
elseif mode == 1
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
Expand All @@ -2772,8 +2700,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end

# copy the rest
@inbounds if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ (nadd - ndel)
@inbounds if (nadd > 0)
colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ (nadd)
r1 = colptrA[end]-1
copylen = r1 - aidx + 1
if copylen > 0
Expand Down
28 changes: 22 additions & 6 deletions test/sparsedir/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,19 +521,24 @@ for (aa116, ss116) in [(a116, s116), (ad116, sd116)]
end
end

# workaround issue #7197: comment out let-block
#let S = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])
S1290 = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])

let S1290 = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])
S1290[1,1] = 1
S1290[5] = 2
S1290[end] = 3
@test S1290[end] == (S1290[1] + S1290[2,2])
@test 6 == sum(diag(S1290))
@test full(S1290)[[3,1],1] == full(S1290[[3,1],1])
# end
end


# setindex tests
let a = spzeros(5, 5)
a[3,2] = 0.0
@test countnz(a) == 0
@test nnz(a) == 1
end

let a = spzeros(Int, 10, 10)
@test countnz(a) == 0
a[1,:] = 1
Expand All @@ -547,6 +552,9 @@ let a = spzeros(Int, 10, 10)
@test a[1,:] == sparse([1:10;])
a[:,2] = 1:10
@test a[:,2] == sparse([1:10;])
a[:,2] = 0
@test countnz(a) == 9
@test nnz(a) == 19
end

let A = spzeros(Int, 10, 20)
Expand All @@ -559,8 +567,11 @@ let A = spzeros(Int, 10, 20)
A[6:10,11:20] = 20
@test countnz(A) == 100
@test A[6:10,11:20] == 20 * ones(Int, 5, 10)
# Storing zeros in structural nonzeros doesn't modify sparsity pattern
A[6:10,11:20] = 0
@test nnz(A) == 100
A[4:8,8:16] = 15
@test countnz(A) == 121
@test nnz(A) == 121
@test A[4:8,8:16] == 15 * ones(Int, 5, 9)
end

Expand All @@ -587,6 +598,8 @@ let A = speye(Int, 5), I=1:10, X=reshape([trues(10); falses(15)],5,5)
@test A[I] == A[X] == [1,0,0,0,0,0,1,0,0,0]
A[I] = [1:10;]
@test A[I] == A[X] == collect(1:10)
A[I] = zeros(Int, 10)
@test A[I] == A[X] == zeros(Int, 10)
end

let S = sprand(50, 30, 0.5, x->round(Int,rand(x)*100)), I = sprandbool(50, 30, 0.2)
Expand All @@ -603,10 +616,12 @@ let S = sprand(50, 30, 0.5, x->round(Int,rand(x)*100)), I = sprandbool(50, 30, 0
@test (sum(S) + sumFI) == sumS1

S[FI] = 10
nnz_S1 = nnz(S)
@test sum(S) == sumS2 + 10*sum(FI)
S[FI] = 0
nnz_S2 = nnz(S)
@test sum(S) == sumS2

@test nnz_S1 == nnz_S2
S[FI] = [1:sum(FI);]
@test sum(S) == sumS2 + sum(1:sum(FI))
end
Expand Down Expand Up @@ -1291,6 +1306,7 @@ let
x = UpperTriangular(A)*ones(n)
@test UpperTriangular(A)\x ones(n)
A[2,2] = 0
Base.SparseArrays.dropzeros!(A)
@test_throws LinAlg.SingularException LowerTriangular(A)\ones(n)
@test_throws LinAlg.SingularException UpperTriangular(A)\ones(n)
end
Expand Down

0 comments on commit adc7508

Please sign in to comment.