Skip to content

Commit

Permalink
Ensure indexing is consistent between...
Browse files Browse the repository at this point in the history
one-column sparse matrices and sparse vectors, with special attention to stored zeros.
  • Loading branch information
mbauman committed Oct 7, 2015
1 parent 0373f55 commit 495b010
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
12 changes: 9 additions & 3 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ function _logical_index{Tv}(A::SparseMatrixCSC{Tv}, I::AbstractArray{Bool})
SparseVector(n, rowvalB, nzvalB)
end

# TODO: further optimizations are available for I::Range
# TODO: huge optimizations are available for I::Range and ::Colon
getindex(A::SparseMatrixCSC, ::Colon) = A[1:end]
function getindex{Tv}(A::SparseMatrixCSC{Tv}, I::AbstractVector)
szA = size(A)
nA = szA[1]*szA[2]
Expand Down Expand Up @@ -499,8 +500,13 @@ getindex{Tv,Ti}(x::AbstractSparseVector{Tv,Ti}, I::AbstractArray{Bool}) = x[find
S[I, 1]
end

# TODO: do this without reshaping
getindex{Tv,Ti}(x::AbstractSparseVector{Tv,Ti}, I::AbstractArray) = reshape(x[vec(I)], size(I))
function getindex{Tv,Ti}(x::AbstractSparseVector{Tv,Ti}, I::AbstractArray)
# punt to SparseMatrixCSC
S = SparseMatrixCSC(x.n, 1, [1,length(x.nzind)+1], x.nzind, x.nzval)
S[I]
end

getindex(x::AbstractSparseVector, ::Colon) = copy(x)

### show and friends

Expand Down
15 changes: 15 additions & 0 deletions test/sparsedir/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -668,3 +668,18 @@ sv[1] = 0

# Ensure that sparsevec with all-zero values returns an array of zeros
@test sparsevec([1,2,3],[0,0,0]) == [0,0,0]

# Compare stored zero semantics between SparseVector and SparseMatrixCSC
let S = SparseMatrixCSC(10,1,[1,6],[1,3,5,6,7],[0,1,2,0,3]), x = SparseVector(10,[1,3,5,6,7],[0,1,2,0,3])
@test nnz(S) == nnz(x) == 5
for I = (:, 1:10, collect(1:10))
@test S[I,1] == S[I] == x[I] == x
@test nnz(S[I,1]) == nnz(S[I]) == nnz(x[I]) == nnz(x)
end
for I = (2:9, 1:2, 9:10, [3,6,1], [10,9,8], [])
@test S[I,1] == S[I] == x[I]
@test nnz(S[I,1]) == nnz(S[I]) == nnz(x[I])
end
@test S[[1 3 5; 2 4 6]] == x[[1 3 5; 2 4 6]]
@test nnz(S[[1 3 5; 2 4 6]]) == nnz(x[[1 3 5; 2 4 6]])
end

0 comments on commit 495b010

Please sign in to comment.