Simplify getindex using the new eachindex
mbauman committed Apr 2, 2015
1 parent 307d21f commit d6802ad
Showing 3 changed files with 71 additions and 82 deletions.
4 changes: 4 additions & 0 deletions base/abstractarray.jl
@@ -4,6 +4,10 @@ typealias AbstractVector{T} AbstractArray{T,1}
 typealias AbstractMatrix{T} AbstractArray{T,2}
 typealias AbstractVecOrMat{T} Union(AbstractVector{T}, AbstractMatrix{T})
 
+## Definitions from operators.jl ##
+index_lengths_dim(A, dim, i::AbstractVector{Bool}, I...) = tuple(sum(i), index_lengths_dim(A, dim+1, I...)...)
+index_shape_dim(A, dim, i::AbstractVector{Bool}, I...) = tuple(sum(i), index_shape_dim(A, dim+1, I...)...)
+
 ## Basic functions ##
 
 vect() = Array(Any, 0)
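These two added methods make a Bool index vector contribute sum(i) — the number of true entries — to the computed index lengths and result shape, rather than its own length. A quick illustration with ordinary logical indexing (illustrative values, not from the diff):

    A = rand(3, 4)
    A[:, [true, false, true, false]]  # 3x2 result: the Bool vector selects 2 of the 4 columns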
147 changes: 66 additions & 81 deletions base/multidimensional.jl
@@ -177,22 +177,26 @@ stagedfunction _unsafe_getindex(l::LinearIndexing, A::AbstractArray, I::Union(Re
         @nexprs $N d->(I_d = to_index(I[d]))
         dest = similar(A, @ncall $N index_shape A I)
         @ncall $N checksize dest I
-        @ncall $N _unsafe_getindex! linearindexing(dest) dest l A I
+        @ncall $N _unsafe_getindex! dest l A I
     end
 end
 
 # logical indexing optimization - don't use find (within to_index)
-# Todo: use magic to speed up LinearSlow src
+# This is inherently a linear operation in the source, but we could potentially
+# use fast dividing integers to speed it up.
 function _unsafe_getindex(::LinearIndexing, src::AbstractArray, I::AbstractArray{Bool})
     # Both index_shape and checksize compute sum(I); manually hoist it out
     N = sum(I)
     dest = similar(src, (N,))
     size(dest) == (N,) || throw(DimensionMismatch())
-    c = 1
-    for i = 1:length(I)
-        if unsafe_getindex(I, i)
-            unsafe_setindex!(dest, src[i], c)
-            c += 1
+    D = eachindex(dest)
+    Ds = start(D)
+    s = 0
+    for b in I
+        s+=1
+        if b
+            d, Ds = next(D, Ds)
+            unsafe_setindex!(dest, src[s], d)
         end
     end
     dest
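The loop above leans on the 0.4-era iteration protocol — start, next, and done — that for loops lower to. Driving the state by hand lets the destination iterator advance only when a source element is actually selected. A minimal sketch of the protocol, assuming any iterable itr:

    state = start(itr)                   # initial iteration state
    while !done(itr, state)
        item, state = next(itr, state)   # current element plus advanced state
        # use item
    end

For a LinearFast destination, eachindex(dest) yields integer indices; for a LinearSlow one it yields CartesianIndexes, so the same loop body serves both layouts.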
@@ -231,67 +235,45 @@ stagedfunction _unsafe_getindex{T,AN,N}(l::LinearIndexing, A::AbstractArray{T,AN
     end
 end
 
-# Indexing with an array of indices is inherently linear in the source
-# With a fast destination, use linear indexing for dest and iterate over indices
-function _unsafe_getindex!(::LinearFast, dest::AbstractArray, ::LinearIndexing, src::AbstractArray, I::AbstractArray)
-    d = 0
+# Indexing with an array of indices is inherently linear in the source, but
+# might be able to be optimized with fast dividing integers
+function _unsafe_getindex!(dest::AbstractArray, ::LinearIndexing, src::AbstractArray, I::AbstractArray)
+    D = eachindex(dest)
+    Ds = start(D)
     for idx in I
-        d += 1
-        unsafe_setindex!(dest, unsafe_getindex(src, idx), d)
-    end
-    dest
-end
-# With a slow destination, cartesian indexing for dest and iterate over indices
-# It'd be nice to use `zip` here, but it's slower than manually writing the loop
-function _unsafe_getindex!(::LinearSlow, dest::AbstractArray, ::LinearIndexing, src::AbstractArray, I::AbstractArray)
-    itr = eachindex(dest)
-    s = start(itr)
-    for idx in I
-        (d, s) = next(itr, s)
+        d, Ds = next(D, Ds)
         unsafe_setindex!(dest, unsafe_getindex(src, idx), d)
     end
     dest
 end
 
-# Both fast
-stagedfunction _unsafe_getindex!(::LinearFast, dest::AbstractArray, ::LinearFast, src::AbstractArray, I::Union(Real, AbstractVector, Colon)...)
+# Fast source - compute the linear index
+stagedfunction _unsafe_getindex!(dest::AbstractArray, ::LinearFast, src::AbstractArray, I::Union(Real, AbstractVector, Colon)...)
     N = length(I)
     Isplat = Expr[:(I[$d]) for d = 1:N]
     quote
         stride_1 = 1
         @nexprs $N d->(stride_{d+1} = stride_d*size(src, d))
         $(symbol(:offset_, N)) = 1
-        k = 1
-        @nloops $N i dest d->(@inbounds offset_{d-1} = offset_d + (unsafe_getindex(I[d], i_d)-1)*stride_d) begin
-            unsafe_setindex!(dest, unsafe_getindex(src, offset_0), k)
-            k += 1
-        end
-        dest
-    end
-end
-# Fast destination, slow source
-stagedfunction _unsafe_getindex!(::LinearFast, dest::AbstractArray, ::LinearSlow, src::AbstractArray, I::Union(Real, AbstractVector, Colon)...)
-    N = length(I)
-    Isplat = Expr[:(I[$d]) for d = 1:N]
-    quote
-        k = 1
-        @nloops $N i dest d->(@inbounds j_d = unsafe_getindex(I[d], i_d)) begin
-            v = @ncall $N unsafe_getindex src j
-            unsafe_setindex!(dest, v, k)
-            k += 1
+        D = eachindex(dest)
+        Ds = start(D)
+        @nloops $N i dest d->(offset_{d-1} = offset_d + (unsafe_getindex(I[d], i_d)-1)*stride_d) begin
+            d, Ds = next(D, Ds)
+            unsafe_setindex!(dest, unsafe_getindex(src, offset_0), d)
         end
         dest
     end
 end
-# A slow destination. It's unlikely a fast array would give a slow similar array
-# so it's not worth specializing that case.
-stagedfunction _unsafe_getindex!(::LinearSlow, dest::AbstractArray, ::LinearIndexing, src::AbstractArray, I::Union(Real, AbstractVector, Colon)...)
+# Slow source - index with the indices provided.
+# TODO: this may not be the full dimensionality; that case could be optimized
+stagedfunction _unsafe_getindex!(dest::AbstractArray, ::LinearSlow, src::AbstractArray, I::Union(Real, AbstractVector, Colon)...)
     N = length(I)
     Isplat = Expr[:(I[$d]) for d = 1:N]
     quote
-        @nloops $N i dest d->(@inbounds j_d = unsafe_getindex(I[d], i_d)) begin
+        D = eachindex(dest)
+        Ds = start(D)
+        @nloops $N i dest d->(j_d = unsafe_getindex(I[d], i_d)) begin
+            d, Ds = next(D, Ds)
             v = @ncall $N unsafe_getindex src j
-            @ncall $N unsafe_setindex! dest v i
+            unsafe_setindex!(dest, v, d)
         end
         dest
     end
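Since eachindex(dest) already hands back whichever index type is cheapest for the destination, the old four-way dispatch on destination × source traits collapses into a two-way dispatch on the source alone. The shape all three methods now share, as a standalone sketch (take_from! is a hypothetical name, not Base code):

    function take_from!(dest, src, inds)
        D = eachindex(dest)      # integer range or CartesianRange, whichever is fast
        Ds = start(D)
        for idx in inds
            d, Ds = next(D, Ds)  # next destination slot
            dest[d] = src[idx]
        end
        dest
    end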
@@ -303,15 +285,13 @@ checksize(A::AbstractArray, I::AbstractArray{Bool}) = length(A) == sum(I) || thr
 stagedfunction checksize(A::AbstractArray, I...)
     N = length(I)
     quote
-        @nexprs $N d->(size(A, d) == index_length(A, d, I[d]) || throw(DimensionMismatch("index $d selects $(length(I[d])) elements, but size(A, $d) = $(size(A,d))")))
+        @nexprs $N d->(_checksize(A, d, I[d]) || throw(DimensionMismatch("index $d selects $(length(I[d])) elements, but size(A, $d) = $(size(A,d))")))
     end
 end
-index_length(A::AbstractArray, dim, I) = length(I)
-index_length(A::AbstractArray, dim, I::AbstractVector{Bool}) = sum(I)
-index_length(A::AbstractArray, dim, ::Colon) = size(A, dim)
-index_length(A::AbstractArray, dim, ::Real) = 1
-
-@inline unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)
+_checksize(A::AbstractArray, dim, I) = size(A, dim) == length(I)
+_checksize(A::AbstractArray, dim, I::AbstractVector{Bool}) = size(A, dim) == sum(I)
+_checksize(A::AbstractArray, dim, ::Colon) = true
+_checksize(A::AbstractArray, dim, ::Real) = size(A, dim) == 1
 
 @inline unsafe_setindex!{T}(v::Array{T}, x::T, ind::Int) = (@inbounds v[ind] = x; v)
 @inline unsafe_setindex!{T}(v::AbstractArray{T}, x::T, ind::Int) = (v[ind] = x; v)
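Replacing index_length with _checksize moves the size comparison into the helper, so each index kind states its own success condition instead of reporting a length for the caller to compare. Illustrative calls, assuming a dest of size (3, 2):

    _checksize(dest, 1, 1:3)                  # true: size(dest, 1) == length(1:3)
    _checksize(dest, 2, [true, false, true])  # true: size(dest, 2) == sum of trues == 2
    _checksize(dest, 1, :)                    # true: Colon matches any extent
    _checksize(dest, 2, 5)                    # false: a Real selects one element, but size(dest, 2) == 2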
@@ -591,6 +571,9 @@ end
 
 ## getindex
 
+@inline unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)
+@inline unsafe_setindex!(v::BitArray, x::Bool, ind::Int) = (Base.unsafe_bitsetindex!(v.chunks, x, ind); v)
+
 # contiguous multidimensional indexing: if the first dimension is a range,
 # we can get some performance from using copy_chunks!
 
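These two @inline helpers, now placed beside the other BitArray methods, expose the packed storage (chunks is a Vector{UInt64}) to the generic paths above. The usual 64-bit chunk addressing, sketched from first principles rather than copied from Base:

    chunk = (i - 1) >>> 6 + 1            # which UInt64 holds bit i (64 bits per chunk)
    mask  = UInt64(1) << ((i - 1) & 63)  # that bit's position within the chunk
    bit   = (chunks[chunk] & mask) != 0  # read; writing sets or clears via the same mask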
Expand All @@ -607,6 +590,7 @@ function getindex(B::BitArray, ::Colon)
return X
end

# Optimization where the inner dimension is contiguous
stagedfunction unsafe_getindex(B::BitArray, I0::Union(Colon,UnitRange{Int}), I::Union(Int,UnitRange{Int},Colon)...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
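The new comment is apt: a UnitRange in the first dimension selects one unbroken run of bits per outer index, which copy_chunks! can move word-at-a-time rather than bit-by-bit. The run arithmetic for an m×n BitArray B and B[r, j], column-major and 1-based (illustrative values):

    m, r, j = 8, 2:7, 3
    first_bit = (j - 1) * m + first(r)   # 18
    last_bit  = (j - 1) * m + last(r)    # 23: six contiguous bits, copyable as a block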
@@ -640,30 +624,31 @@ stagedfunction unsafe_getindex(B::BitArray, I0::Union(Colon,UnitRange{Int}), I::
     end
 end
 
-# general multidimensional non-scalar indexing
-
-stagedfunction unsafe_getindex(B::BitArray, I::Union(Int,AbstractVector{Int},Colon)...)
-    N = length(I)
-    Isplat = Expr[:(I[$d]) for d = 1:N]
-    quote
-        @nexprs $N d->(I_d = I[d])
-        shape = @ncall $N index_shape B I
-        X = BitArray(shape)
-        Xc = X.chunks
-
-        stride_1 = 1
-        @nexprs $N d->(stride_{d+1} = stride_d * size(B, d))
-        @nexprs 1 d->(offset_{$N} = 1)
-        ind = 1
-        @nloops($N, i, X, d->(@inbounds j_d = unsafe_getindex(I[d], i_d);
-                              offset_{d-1} = offset_d + (j_d-1)*stride_d), # PRE
-            begin
-                unsafe_bitsetindex!(Xc, B[offset_0], ind)
-                ind += 1
-            end)
-        return X
-    end
-end
+# # TODO? in the general multidimensional non-scalar case, can we do slightly
+# # better by manually hoisting the offset calculations? If this is re-enabled,
+# # dispatch must be altered to prevent scalar indexing from returning an array
+# stagedfunction unsafe_getindex(B::BitArray, I::Union(Int,AbstractVector{Int},Colon)...)
+#     N = length(I)
+#     Isplat = Expr[:(I[$d]) for d = 1:N]
+#     quote
+#         @nexprs $N d->(I_d = I[d])
+#         shape = @ncall $N index_shape B I
+#         X = BitArray(shape)
+#         Xc = X.chunks
+#
+#         stride_1 = 1
+#         @nexprs $N d->(stride_{d+1} = stride_d * size(B, d))
+#         @nexprs 1 d->(offset_{$N} = 1)
+#         ind = 1
+#         @nloops($N, i, X, d->(@inbounds j_d = unsafe_getindex(I[d], i_d);
+#                               offset_{d-1} = offset_d + (j_d-1)*stride_d), # PRE
+#             begin
+#                 unsafe_bitsetindex!(Xc, B[offset_0], ind)
+#                 ind += 1
+#             end)
+#         return X
+#     end
+# end
 
 
 ## setindex!
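The disabled version earned its TODO by manually hoisting offsets: strides are precomputed per dimension and the column-major linear offset accumulates from the outermost loop inward, so inner iterations reuse the partial sums. The unhoisted computation it optimizes, as a self-contained sketch (linear_offset is a hypothetical helper in 0.4 syntax):

    function linear_offset{N}(sz::NTuple{N,Int}, j::NTuple{N,Int})
        stride = 1
        offset = 1
        for d = 1:N
            offset += (j[d] - 1) * stride  # column-major: earlier dimensions vary fastest
            stride *= sz[d]
        end
        offset
    end

    linear_offset((3, 4), (2, 3))  # 2 + (3 - 1)*3 == 8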
2 changes: 1 addition & 1 deletion base/operators.jl
@@ -224,7 +224,7 @@ function promote_shape(a::Dims, b::Dims)
     return a
 end
 
-# The lengths of the given indices, lowering : to the appropriate size
+# Recursively compute the lengths of a list of indices
 index_lengths(A::AbstractArray, I...) = index_lengths_dim(A, 1, I...)
 index_lengths_dim(A, dim) = ()
 index_lengths_dim(A, dim, ::Colon) = dim == 1 ? (length(A),) : (trailingsize(A, dim),)
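index_lengths walks the index tuple one dimension at a time, prepending each computed length onto the recursive result; the Bool-vector methods this commit adds to abstractarray.jl slot into the same recursion. Illustrative values (Colon shown only in trailing position, the case the method above handles):

    A = rand(3, 4, 5)
    index_lengths(A, 1:2, [true, false, true, false], :)  # (2, 2, 5)
    # length(1:2) == 2; sum of the Bool vector == 2; trailing Colon -> trailingsize(A, 3) == 5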
