Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arbitrary dimension one-hot arrays #1448

Merged
merged 13 commits into from
Jan 8, 2021
6 changes: 3 additions & 3 deletions docs/src/data/onehot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ It's common to encode categorical variables (like `true`, `false` or `cat`, `dog
julia> using Flux: onehot, onecold

julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
0
1
0

julia> onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
0
0
1
Expand Down Expand Up @@ -44,7 +44,7 @@ Flux.onecold
julia> using Flux: onehotbatch

julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}:
0 1 0
1 0 1
0 0 0
Expand Down
109 changes: 62 additions & 47 deletions src/onehot.jl
Original file line number Diff line number Diff line change
@@ -1,52 +1,61 @@
import Base: *
import Adapt
import .CUDA

struct OneHotVector <: AbstractVector{Bool}
ix::UInt32
of::UInt32
struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
indices::I
end
OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
OneHotArray(L::Integer, indices::T) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices)
OneHotArray(L::Integer, indices::AbstractArray{T, N}) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)

Base.size(xs::OneHotVector) = (Int64(xs.of),)
_indices(x::OneHotArray) = x.indices

Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}

Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)
OneHotVector(L, idx) = OneHotArray(L, idx)
OneHotMatrix(L, indices) = OneHotArray(L, indices)

function Base.:*(A::AbstractMatrix, b::OneHotVector)
if size(A, 2) != b.of
throw(DimensionMismatch("Matrix column must correspond with OneHotVector size"))
end
return A[:, b.ix]
end
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)

struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
height::Int
data::A
end
_onehotindex(x, i) = (x == i)

Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x

Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(L, x.indices[I...])
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]

Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to simply return the type of the underlying array iiic

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The underlying array is an integer array. This is just an internal convenience function I use when I want to convert the OneHotArray to an Bool array. I use this to decide whether to convert to a Array{Bool} or CuArray{Bool} depending on the underlying storage location.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why do we need to have this? The implementation for CuArray should fall straight out of assuming regular array


function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
darsnack marked this conversation as resolved.
Show resolved Hide resolved
if isone(dims)
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = 1)
else
return OneHotArray(L, cat(_indices.(xs)...; dims = dims - 1))
end
end

# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed
A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))]
Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)

Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
(first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) :
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)

batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs))

import Adapt: adapt, adapt_structure
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(L, adapt(T, x.indices))

adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()

import .CUDA: CuArray, CuArrayStyle, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
Base.argmax(x::OneHotArray; dims = Colon()) =
(dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
argmax(convert(_onehot_bool_type(x), x); dims = dims)

"""
onehot(l, labels[, unk])
Expand All @@ -60,13 +69,13 @@ If `l` is not found in labels and `unk` is present, the function returns
# Examples
```jldoctest
julia> Flux.onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
0
1
0

julia> Flux.onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
0
0
1
Expand All @@ -75,13 +84,13 @@ julia> Flux.onehot(:c, [:a, :b, :c])
function onehot(l, labels)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
OneHotVector{UInt32, length(labels)}(i)
end

function onehot(l, labels, unk)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
OneHotVector{UInt32, length(labels)}(i)
end

"""
Expand All @@ -95,16 +104,13 @@ return [`onehot(unk, labels)`](@ref) ; otherwise the function will raise an erro
# Examples
```jldoctest
julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}:
0 1 0
1 0 1
0 0 0
```
"""
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

Base.argmax(xs::OneHotVector) = xs.ix
onehotbatch(ls, labels, unk...) = batch([onehot(l, labels, unk...) for l in ls])

"""
onecold(y[, labels = 1:length(y)])
Expand All @@ -120,11 +126,20 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
"""
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]
function onecold(y::AbstractArray, labels = 1:size(y, 1))
indices = _fast_argmax(y)
xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA
darsnack marked this conversation as resolved.
Show resolved Hide resolved

onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
return map(xi -> labels[xi[1]], xs)
end

onecold(y::OneHotMatrix, labels...) = map(x -> Flux.onecold(x, labels...), y.data)
_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
_fast_argmax(x::OneHotArray) = x.indices

@nograd onecold, onehot, onehotbatch
@nograd OneHotArray, onecold, onehot, onehotbatch

function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
return A[:, onecold(B)]
end
4 changes: 3 additions & 1 deletion test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using LinearAlgebra: I, cholesky, Cholesky

x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
@test cx isa Flux.OneHotMatrix && cx.indices isa CuArray
@test (cx .+ 1) isa CuArray

m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
Expand All @@ -40,8 +40,10 @@ end

@testset "onecold gpu" begin
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
@test Flux.onecold(y) isa CuArray
@test y[3,:] isa CuArray
@test Flux.onecold(y, l) isa CuArray
end

@testset "restructure gpu" begin
Expand Down
78 changes: 76 additions & 2 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,83 @@ end

@testset "abstractmatrix onehotvector multiplication" begin
A = [1 3 5; 2 4 6; 3 6 9]
b1 = Flux.OneHotVector(1,3)
b2 = Flux.OneHotVector(3,5)
b1 = Flux.OneHotVector{eltype(A), 3}(1)
b2 = Flux.OneHotVector{eltype(A), 5}(3)
darsnack marked this conversation as resolved.
Show resolved Hide resolved

@test A*b1 == A[:,1]
@test_throws DimensionMismatch A*b2
end

@testset "OneHotArray" begin
using Flux: OneHotArray, OneHotVector, OneHotMatrix

ov = OneHotVector(10, rand(1:10))
om = OneHotMatrix(10, rand(1:10, 5))
oa = OneHotArray(10, rand(1:10, 5, 5))

# sizes
@testset "Base.size" begin
@test size(ov) == (10,)
@test size(om) == (10, 5)
@test size(oa) == (10, 5, 5)
end

@testset "Indexing" begin
# vector indexing
@test ov[3] == (ov.indices == 3)
@test ov[:] == ov

# matrix indexing
@test om[3, 3] == (om.indices[3] == 3)
@test om[:, 3] == OneHotVector(10, om.indices[3])
@test om[3, :] == (om.indices .== 3)
@test om[:, :] == om

# array indexing
@test oa[3, 3, 3] == (oa.indices[3, 3] == 3)
@test oa[:, 3, 3] == OneHotVector(10, oa.indices[3, 3])
@test oa[3, :, 3] == (oa.indices[:, 3] .== 3)
@test oa[3, :, :] == (oa.indices .== 3)
@test oa[:, 3, :] == OneHotMatrix(10, oa.indices[3, :])
@test oa[:, :, :] == oa

# cartesian indexing
@test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3]
end

@testset "Concatenating" begin
# vector cat
@test hcat(ov, ov) == OneHotMatrix(10, vcat(ov.indices, ov.indices))
@test vcat(ov, ov) == vcat(convert(Array{Bool}, ov), convert(Array{Bool}, ov))
@test cat(ov, ov; dims = 3) == OneHotArray(10, cat(ov.indices, ov.indices; dims = 2))

# matrix cat
@test hcat(om, om) == OneHotMatrix(10, vcat(om.indices, om.indices))
@test vcat(om, om) == vcat(convert(Array{Bool}, om), convert(Array{Bool}, om))
@test cat(om, om; dims = 3) == OneHotArray(10, cat(om.indices, om.indices; dims = 2))

# array cat
@test cat(oa, oa; dims = 3) == OneHotArray(10, cat(oa.indices, oa.indices; dims = 2))
@test cat(oa, oa; dims = 1) == cat(convert(Array{Bool}, oa), convert(Array{Bool}, oa); dims = 1)
end

@testset "Base.reshape" begin
# reshape test
@test reshape(oa, 10, 25) isa OneHotArray
@test reshape(oa, 10, :) isa OneHotArray
@test reshape(oa, :, 25) isa OneHotArray
@test_throws ArgumentError reshape(oa, 50, :)
@test_throws ArgumentError reshape(oa, 5, 10, 5)
@test reshape(oa, (10, 25)) isa OneHotArray
end

@testset "Base.argmax" begin
# argmax test
@test argmax(ov) == argmax(convert(Array{Bool}, ov))
@test argmax(om) == argmax(convert(Array{Bool}, om))
@test argmax(om; dims = 1) == argmax(convert(Array{Bool}, om); dims = 1)
@test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2)
@test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1)
@test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3)
end
end