DRAFT: Redesign underlying storage #39

Open · wants to merge 6 commits into base: main
src/array.jl (220 changes: 95 additions & 125 deletions)
@@ -1,158 +1,128 @@
"""
OneHotArray{T, N, M, I} <: AbstractArray{Bool, M}
OneHotArray(indices, L)
OneHotVector{T,L}
OneHotVector(index, L)

A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`)
stored as a compact `N == M-1`-dimensional array of indices.

Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
Parameter `I` is the type of the underlying storage, and `T` its eltype.
A one-hot vector with `L` labels (i.e. `length(A) == L` and `count(A) == 1`).
"""
struct OneHotArray{T<:Integer, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
indices::I
nlabels::Int
struct OneHotVector{T,L} # <: AbstractVector{Bool} # this does everything but is too limiting
index::Integer
function OneHotVector{T,L}(index) where {T,L}
@assert 1 <= index <= L "OneHotVector index $(index) out of range [1,$(L)]"
return new(index)
end
end
OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L)
OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L)
OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L)

_indices(x::OneHotArray) = x.indices
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
reshape(parent(x).indices, x.dims[2:end])

"""
OneHotVector{T} = OneHotArray{T, 0, 1, T}
OneHotVector(indices, L)

A one-hot vector with `L` labels (i.e. `length(A) == L` and `count(A) == 1`) typically constructed by [`onehot`](@ref).
Stored efficiently as a single index of type `T`, usually `UInt32`.
"""
const OneHotVector{T} = OneHotArray{T, 0, 1, T}
OneHotVector(idx, L) = OneHotArray(idx, L)

"""
OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I}
OneHotMatrix(indices, L)

A one-hot matrix (with `L` labels) typically constructed using [`onehotbatch`](@ref).
Stored efficiently as a vector of indices with type `I` and eltype `T`.
"""
const OneHotMatrix{T, I} = OneHotArray{T, 1, 2, I}
OneHotMatrix(indices, L) = OneHotArray(indices, L)

# use this type so reshaped arrays hit fast paths
# e.g. argmax
const OneHotLike{T, N, var"N+1", I} =
Union{OneHotArray{T, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, <:Any, <:Any, I}}}

_isonehot(x::OneHotArray) = true
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) = (size(x, 1) == parent(x).nlabels)

_check_nlabels(L, xs::OneHotLike...) = all(size.(xs, 1) .== L)

_nlabels(x::OneHotArray) = size(x, 1)
function _nlabels(x::OneHotLike, xs::OneHotLike...)
L = size(x, 1)
_check_nlabels(L, xs...) ||
throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays."))

    return L
end

OneHotVector(t::Type, index, nlabels) = OneHotVector{t,nlabels}(index)
OneHotVector(index, nlabels) = OneHotVector{Float32,nlabels}(index)
Base.size(x::OneHotVector{T,L}) where {T,L} = (L,)
function Base.getindex(x::OneHotVector{T,L}, i::Integer) where {T,L}
    @boundscheck (1 <= i <= L) || throw(BoundsError(x, i))
    return convert(T, i == x.index)  # out-of-range i is an error, but a mismatch is just zero
end
Base.show(io::IO, x::OneHotVector{T,L}) where {T,L} = Base.show(io, setindex!(zeros(T, L), convert(T, 1), x.index))
Base.argmax(x::OneHotVector; dims = Colon()) = x.index
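# Usage sketch (hypothetical values):
#
#   v = OneHotVector(Float32, 3, 5)    # eltype Float32, hot at index 3 of 5 labels
#   v[3] == 1.0f0 && v[2] == 0.0f0     # getindex compares i against the stored index
#   argmax(v) == 3                     # recovers the index without a search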

struct OneHotArray{T,N,M,L,A,S<:AbstractArray{OneHotVector{T,L},M}} <: AbstractArray{T,N}
    # S is the concrete storage type; keeping it as a parameter avoids an abstractly
    # typed field and lets the GPU methods below dispatch on the storage
    onehotvectors::S
    function OneHotArray(onehotaxis, onehotvectors::S) where {T,L,M,S<:AbstractArray{OneHotVector{T,L},M}}
        N = M+1
        @assert onehotaxis isa Integer "onehot axis must be an integer"
        @assert 1 <= onehotaxis <= N "onehot axis out of range [1,$N]"
        new{T,N,M,L,onehotaxis,S}(onehotvectors)
    end
end
onehotaxis(x::OneHotArray{T,N,M,L,A}) where {T,N,M,L,A} = A
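# Construction sketch (hypothetical values): a 3-label one-hot matrix stored as a
# vector of OneHotVectors, with the hot axis first:
#
#   vs = [OneHotVector(Float32, i, 3) for i in (1, 3, 2)]
#   X = OneHotArray(1, vs)     # size(X) == (3, 3), onehotaxis(X) == 1
#   X[3, 2] == 1.0f0           # second stored vector is hot at label 3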
function size_selector(i, shape, onehotaxis, L)
if i < onehotaxis
shape[i]
elseif i > onehotaxis
shape[i-1]
else
L
end
end
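# e.g. with storage shape (4,), hot axis 1 and L = 3 (hypothetical values):
#   size_selector(1, (4,), 1, 3) == 3    # the hot axis reports L labels
#   size_selector(2, (4,), 1, 3) == 4    # later axes read the storage, shifted by one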

Base.size(x::OneHotArray) = (x.nlabels, size(x.indices)...)

function Base.getindex(x::OneHotArray{<:Any, N}, i::Int, I::Vararg{Int, N}) where N
    @boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
    return x.indices[I...] .== i
end

Base.size(x::OneHotArray{T,N,M,L,A}) where {T,N,M,L,A} = ntuple(i -> size_selector(i, size(x.onehotvectors), onehotaxis(x), L), N)

# NOTE: this linear-index method assumes the hot axis is 1
function Base.getindex(x::OneHotArray{T,N,M,L}, i::Integer) where {T,N,M,L}
    flat_onehotvectors = x.onehotvectors[:]
    h_ind, L_ind = fldmod1(i, L)  # 1-based split; plain fldmod is off by one at i == L
    @boundscheck 1 <= h_ind <= length(flat_onehotvectors) || throw(BoundsError(x, i))
    return convert(T, flat_onehotvectors[h_ind].index == L_ind)
end
# the method above is faster on the CPU but will scalar index on the GPU
# so we define the method below to pass the extra indices directly to GPU array
function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray},
i::Int,
I::Vararg{Any, N}) where N
@boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
return x.indices[I...] .== i
function Base.getindex(x::OneHotArray{T,N,M,L}, i::Vararg{Integer,N}) where {T,N,M,L}
    @boundscheck all(1 .<= i .<= size(x)) || throw(BoundsError(x, i))
index_pre = i[1:onehotaxis(x)-1]
index_post = i[onehotaxis(x)+1:end]
intern_ind = x.onehotvectors[index_pre..., index_post...].index
return convert(T, intern_ind == i[onehotaxis(x)])
end
function Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, I::Vararg{Any, N}) where N
    return OneHotArray(x.indices[I...], x.nlabels)
end

function Base.show(io::IO, x::OneHotArray{T,N,M,L}) where {T,N,M,L}
z = zeros(Int32, size(x))
# loop efficiently over only the ones
for ext_ind in eachindex(IndexCartesian(), x.onehotvectors)
        index_pre = Tuple(ext_ind)[1:onehotaxis(x)-1]   # storage dims before the hot axis
        index_post = Tuple(ext_ind)[onehotaxis(x):end]  # storage dims after it
        intern_ind = x.onehotvectors[ext_ind].index
        ind = CartesianIndex(index_pre..., intern_ind, index_post...)
setindex!(z, convert(Int32, 1), ind)
end
Base.show(io, z)
end
Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :))
Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x

function Base.showarg(io::IO, x::OneHotArray, toplevel)
    print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
    Base.showarg(io, x.indices, false)
    print(io, ')')
    toplevel && print(io, " with eltype Bool")
end
function Base.showarg(io::IO, x::OneHotArray{T,N,M,L,A}, toplevel) where {T,N,M,L,A}
print(io, "$(size(x)) OneHotArray")
toplevel && print(io, " with one hot axis $A and eltype $T")
return nothing
end

# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
    x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
end
function Base.replace_in_print_matrix(x::OneHotArray, i::Integer, j::Integer, s::AbstractString)
x[i,j] > 0 ? s : Base.replace_with_centered_mark(s)
end

# copy CuArray versions back before trying to print them:
for fun in (:show, :print_array) # print_array is used by 3-arg show
@eval begin
        Base.$fun(io::IO, X::OneHotLike{T, N, var"N+1", <:AbstractGPUArray}) where {T, N, var"N+1"} =
            Base.$fun(io, adapt(Array, X))
        Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}}) where {T, N} =
            Base.$fun(io, adapt(Array, X))

        # new versions dispatch on the storage parameter S (the sixth parameter);
        # the fifth parameter A is now the one-hot axis, not the storage type
        Base.$fun(io::IO, X::OneHotArray{T,N,M,L,A,<:AbstractGPUArray}) where {T,N,M,L,A} =
            Base.$fun(io, adapt(Array, X))
        Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{<:Any, <:OneHotArray{T,N,M,L,A,<:AbstractGPUArray}}) where {T,N,M,L,A} =
            Base.$fun(io, adapt(Array, X))
end
end

_onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"}
# Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels) # TODO: edit this

_notall_onehot(x::OneHotArray, xs::OneHotArray...) = false
_notall_onehot(x::OneHotLike, xs::OneHotLike...) = any(x -> !_isonehot(x), (x, xs...))
# function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray} # TODO: edit this
# # We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
# S = Base.BroadcastStyle(T)
# # S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
# # isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
# (typeof(S).name.wrapper){var"N+1"}()
# end

function Base.cat(x::OneHotLike{<:Any, <:Any, N}, xs::OneHotLike...; dims::Int) where N
    if isone(dims) || _notall_onehot(x, xs...)
        return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
    else
        L = _nlabels(x, xs...)
        return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
    end
end

# Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

function Base.cat(xs::OneHotArray{T,N,M,L,A}...; dims) where {T,N,M,L,A}
    if dims != A
        # concatenating off the hot axis: drop the hot axis from `dims` when indexing storage
        onehotvectors_dim = dims < A ? dims : dims - 1
        OneHotArray(A, cat((x.onehotvectors for x in xs)...; dims=onehotvectors_dim))
    else
        # concatenating along the hot axis breaks one-hotness; fall back to dense
        invoke(cat, Tuple{Vararg{AbstractArray}}, xs...; dims=dims)
    end
end
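# Sketch (hypothetical values): concatenating two axis-1 one-hot matrices off the
# hot axis concatenates their storage vectors instead of densifying:
#
#   X = OneHotMatrix([1, 2], 3); Y = OneHotMatrix([3], 3)
#   Z = cat(X, Y; dims = 2)    # storage cat along dims = 1
#   size(Z) == (3, 3)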

Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) =
vcat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...)

# optimized concatenation for matrices and vectors of same parameters
Base.hcat(x::OneHotMatrix, xs::OneHotMatrix...) =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))
Base.hcat(x::OneHotVector, xs::OneHotVector...) =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))

if isdefined(Base, :stack)
import Base: _stack
else
import Compat: _stack
end
function _stack(::Colon, xs::AbstractArray{<:OneHotArray})
n = _nlabels(first(xs))
all(x -> _nlabels(x)==n, xs) || throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays."))
OneHotArray(Compat.stack(_indices, xs), n)
end

Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels)

function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
S = Base.BroadcastStyle(T)
# S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
# isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
(typeof(S).name.wrapper){var"N+1"}()
end
Base.argmax(x::OneHotArray; dims = Colon()) =
dims == onehotaxis(x) ?
argmax.(x.onehotvectors) :
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)

Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
"""
OneHotMatrix{T, L, A} = OneHotArray{T,2,1,L,1}
OneHotMatrix(indices, L)

Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
A one-hot matrix (with `L` labels) typically constructed using [`onehotbatch`](@ref).
"""
const OneHotMatrix{T, L} = OneHotArray{T, 2,1,L,1}
OneHotMatrix(indices, L) = OneHotArray(1, [OneHotVector(index,L) for index in indices])
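# Sketch (hypothetical values): OneHotMatrix([2, 1, 3], 3) builds a 3×3 one-hot
# matrix whose columns are hot at labels 2, 1 and 3; column j is stored as the
# single index onehotvectors[j].index rather than a dense column.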
src/linalg.jl (32 changes: 15 additions & 17 deletions)
@@ -1,35 +1,33 @@
function Base.:(*)(A::AbstractMatrix, B::OneHotLike)
    _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
    return A[:, onecold(B)]
end

function Base.:(*)(A::AbstractMatrix, B::OneHotArray)
    onehotaxis(B) == 1 || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
    return A[:, argmax(B; dims=1)]
end

function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1})
    _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
    return NNlib.gather(A, _indices(B))
end

function Base.:(*)(A::AbstractMatrix, B::OneHotArray{T,2,1,L,Ax}) where {T,L,Ax}
    onehotaxis(B) == 1 || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
    return NNlib.gather(A, [v.index for v in B.onehotvectors])
end
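# Sketch (hypothetical values): with W = [1 2 3; 4 5 6] and B = OneHotMatrix([2, 3], 3),
# W * B gathers columns 2 and 3 of W, giving [2 3; 5 6], without materializing B.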

function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
    B_dim = length(_indices(parent(B)))
    size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
    return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
end

function Base.:(*)(A::AbstractMatrix, B::Adjoint{<:Any, <:OneHotArray})
    indices = [v.index for v in parent(B).onehotvectors]  # _indices was removed in this redesign
    B_dim = length(indices)
    size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
    return NNlib.scatter(+, A, indices, dstsize=(size(A,1), size(B,2)))
end

for wrapper in [:Adjoint, :Transpose]
    @eval begin
        function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T
            size(A, 2) == length(b) ||
                throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(length(b))"))
            return A[:, onecold(b)]
        end

        # note: b's eltype parameter is kept independent of A's eltype T
        function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any,L}) where {T,L}
            size(A, 2) == L ||
                throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(L)"))
            return A[:, argmax(b)]
        end

        function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector) where T
            size(A, 2) == length(b) ||
                throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(length(b))"))
            return A[onecold(b)]
        end

        function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any,L}) where {T,L}
            size(A, 2) == L ||
                throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(L)"))
            return A[argmax(b)]
        end
    end
end
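# e.g. transpose(w) * b with b hot at index k simply reads w[k] (or column k when
# A is an adjoint/transposed matrix); no arithmetic is performed.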
src/onehot.jl (25 changes: 11 additions & 14 deletions)
@@ -31,13 +31,13 @@ julia> hcat(αβγ...) # preserves sparsity
function onehot(x, labels)
i = _findval(x, labels)
isnothing(i) && error("Value $x is not in labels")
OneHotVector{UInt32}(i, length(labels))
OneHotVector(i, length(labels))
end

function onehot(x, labels, default)
i = _findval(x, labels)
isnothing(i) && return onehot(default, labels)
OneHotVector{UInt32}(i, length(labels))
OneHotVector(i, length(labels))
end

_findval(val, labels) = findfirst(isequal(val), labels)
@@ -90,14 +90,16 @@ function _onehotbatch(data, labels)
isnothing(_findval(x, labels)) && error("Value $x not found in labels")
end
end
return OneHotArray(indices, length(labels))
L = length(labels)
return OneHotArray(1, (index -> OneHotVector(Int32, index, L)).(indices))
end
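# Sketch (hypothetical call): _onehotbatch([:b, :a], (:a, :b)) produces an axis-1
# OneHotArray stored as [OneHotVector(Int32, 2, 2), OneHotVector(Int32, 1, 2)].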

function _onehotbatch(data, labels, default)
default_index = _findval(default, labels)
isnothing(default_index) && error("Default value $default is not in labels")
indices = UInt32[something(_findval(i, labels), default_index) for i in data]
return OneHotArray(indices, length(labels))
L = length(labels)
return OneHotArray(1, (index -> OneHotVector(index, L)).(indices))
end

function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
@@ -106,7 +108,8 @@ function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<
hi > last(labels) && error("Value $hi not found in labels")
offset = 1 - first(labels)
indices = UInt32.(data .+ offset)
return OneHotArray(indices, length(labels))
L = length(labels)
return OneHotArray(1, (index -> OneHotVector(index, L)).(indices))
end
# That bounds check with extrema synchronises on GPU, much slower than rest of the function,
# hence add a special method, with a less helpful error message:
@@ -117,7 +120,8 @@ function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRang
checkbounds(labels, i)
i
end
return OneHotArray(indices, length(labels))
L = length(labels)
return OneHotArray(1, (index -> OneHotVector(index, L)).(indices))
end

"""
@@ -165,14 +169,7 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
end

_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
_fast_argmax(x::OneHotArray) = _indices(x)
function _fast_argmax(x::OneHotLike)
if _isonehot(x)
return _indices(x)
else
return _fast_argmax(convert(_onehot_bool_type(x), x))
end
end
_fast_argmax(x::OneHotVector) = argmax(x)
_fast_argmax(x::OneHotArray) = argmax(x; dims = onehotaxis(x))
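# For an axis-1 one-hot array, argmax along the hot axis returns the stored index
# of each OneHotVector, matching the old `_indices(x)` fast path used by onecold.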

ChainRulesCore.@non_differentiable onehot(::Any...)
ChainRulesCore.@non_differentiable onehotbatch(::Any...)