Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse vector/matrix: add fast implementation of find_next and find_prev (fixed) #23317

Merged
merged 12 commits into from
Jan 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions base/sparse/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,24 @@ function Base.reinterpret(::Type, A::AbstractSparseArray)
Try reinterpreting the value itself instead.
""")
end

# The following two methods should be overloaded by concrete types to avoid
# allocating the I = find(...)
_sparse_findnextnz(v::AbstractSparseArray, i::Integer) = (I = find(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : zero(indtype(v)))
_sparse_findprevnz(v::AbstractSparseArray, i::Integer) = (I = find(!iszero, v); n = searchsortedlast(I, i); !iszero(n) ? I[n] : zero(indtype(v)))

function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Integer)
j = _sparse_findnextnz(v, i)
while !iszero(j) && !f(v[j])
j = _sparse_findnextnz(v, j+1)
end
return j
end

function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Integer)
j = _sparse_findprevnz(v, i)
while !iszero(j) && !f(v[j])
j = _sparse_findprevnz(v, j-1)
end
return j
end
6 changes: 3 additions & 3 deletions base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import Base.LinAlg: mul!, ldiv!, rdiv!
import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
atan, atand, atanh, broadcast!, chol, conj!, cos, cosc, cosd, cosh, cospi, cot,
cotd, coth, count, csc, cscd, csch, adjoint!, diag, diff, done, dot, eig,
exp10, exp2, findn, floor, hash, indmin, inv, issymmetric, istril, istriu,
log10, log2, lu, next, sec, secd, sech, show, sin,
sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan,
exp10, exp2, findn, findprev, findnext, floor, hash, indmin, inv,
issymmetric, istril, istriu, log10, log2, lu, next, sec, secd, sech, show,
sin, sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan,
tand, tanh, trace, transpose!, tril!, triu!, trunc, vecnorm, abs, abs2,
broadcast, ceil, complex, cond, conj, convert, copy, copyto!, adjoint, diagm,
exp, expm1, factorize, find, findmax, findmin, findnz, float, getindex,
Expand Down
36 changes: 36 additions & 0 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,42 @@ function findnz(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
return (I, J, V)
end

function _sparse_findnextnz(m::SparseMatrixCSC, i::Integer)
if i > length(m)
return zero(indtype(m))
end
row, col = Tuple(CartesianIndices(m)[i])
lo, hi = m.colptr[col], m.colptr[col+1]
n = searchsortedfirst(m.rowval, row, lo, hi-1, Base.Order.Forward)
if lo <= n <= hi-1
return LinearIndices(m)[m.rowval[n], col]
end
nextcol = findnext(c->(c>hi), m.colptr, col+1)
if iszero(nextcol)
return zero(indtype(m))
end
nextlo = m.colptr[nextcol-1]
return LinearIndices(m)[m.rowval[nextlo], nextcol-1]
end

function _sparse_findprevnz(m::SparseMatrixCSC, i::Integer)
if iszero(i)
return zero(indtype(m))
end
row, col = Tuple(CartesianIndices(m)[i])
lo, hi = m.colptr[col], m.colptr[col+1]
n = searchsortedlast(m.rowval, row, lo, hi-1, Base.Order.Forward)
if lo <= n <= hi-1
return LinearIndices(m)[m.rowval[n], col]
end
prevcol = findprev(c->(c<lo), m.colptr, col-1)
if iszero(prevcol)
return zero(indtype(m))
end
prevhi = m.colptr[prevcol+1]
return LinearIndices(m)[m.rowval[prevhi-1], prevcol]
end

import Base.Random.GLOBAL_RNG
function sprand_IJ(r::AbstractRNG, m::Integer, n::Integer, density::AbstractFloat)
((m < 0) || (n < 0)) && throw(ArgumentError("invalid Array dimensions"))
Expand Down
18 changes: 18 additions & 0 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,24 @@ function findnz(x::SparseVector{Tv,Ti}) where {Tv,Ti}
return (I, V)
end

function _sparse_findnextnz(v::SparseVector, i::Integer)
n = searchsortedfirst(v.nzind, i)
if n > length(v.nzind)
return zero(indtype(v))
else
return v.nzind[n]
end
end

function _sparse_findprevnz(v::SparseVector, i::Integer)
n = searchsortedlast(v.nzind, i)
if iszero(n)
return zero(indtype(v))
else
return v.nzind[n]
end
end

### Generic functions operating on AbstractSparseVector

### getindex
Expand Down
31 changes: 31 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,37 @@ end
@test count(SparseMatrixCSC(2, 2, Int[1, 2, 3], Int[1, 2], Bool[true, true, true])) == 2
end

@testset "sparse findprev/findnext operations" begin

x = [0,0,0,0,1,0,1,0,1,1,0]
x_sp = sparse(x)

for i=1:length(x)
@test findnext(!iszero, x,i) == findnext(!iszero, x_sp,i)
@test findprev(!iszero, x,i) == findprev(!iszero, x_sp,i)
end

y = [0 0 0 0 0;
1 0 1 0 0;
1 0 0 0 1;
0 0 1 0 0;
1 0 1 1 0]
y_sp = sparse(y)

for i=1:length(y)
@test findnext(!iszero, y,i) == findnext(!iszero, y_sp,i)
@test findprev(!iszero, y,i) == findprev(!iszero, y_sp,i)
end

z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1))
z = collect(z_sp)

for i=1:length(z)
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
end
end

# #20711
@testset "vec returns a view" begin
local A = sparse(Matrix(1.0I, 3, 3))
Expand Down