
Arbitrary dimension one-hot arrays #1448

Merged: 13 commits from darsnack/arbitrary-one-hot into FluxML:master on Jan 8, 2021

Conversation

@darsnack (Member) commented Jan 1, 2021

This supersedes #1447 and should address the same issues.

This PR introduces a new one-hot N-dimensional array type, OneHotArray. Like #1447, this approach avoids the pointer allocations associated with OneHotMatrix being an array of OneHotVectors. It also lifts the "height" into the type parameter to avoid unnecessary allocation. Unlike #1447, this approach does not introduce a new primitive type. Instead, a "one-hot vector" is represented with a single subtype of Integer that is configurable by the user. By default, the exposed API will use UInt32.

Fundamentally, #1447's primitive type is only necessary because wrapping a UInt32 in a OneHotVector struct incurs memory penalties when you create an Array{<:OneHotVector}. But if we begin by designing for N dimensions, then OneHotVector is just the specialized 1D case (similar to how Vector{T} = Array{T, 1}).
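To make that concrete, here is a minimal sketch of the design (simplified to a hardcoded UInt32 index type; the actual src/onehot.jl in this PR differs in its details):

struct OneHotArray{L, N, M, I <: Union{UInt32, AbstractArray{UInt32, N}}} <: AbstractArray{Bool, M}
    indices::I  # the hot index (or array of hot indices) along the first dimension
end

# The 1D case falls out as a specialization, analogous to Vector{T} = Array{T, 1}:
# a bare UInt32 viewed as a length-L Bool vector, so a batch of one-hot vectors
# is stored as a plain integer array with no per-element wrapper allocation.
const OneHotVector{L} = OneHotArray{L, 0, 1, UInt32}
OneHotVector(i, L) = OneHotVector{L}(UInt32(i))

Base.size(x::OneHotArray{L}) where L = (L, size(x.indices)...)
Base.getindex(x::OneHotArray, i::Integer, I::Integer...) = x.indices[I...] == i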

Performance

I compared against the same tests mentioned in #1447; please suggest more if you want to. (A sketch of why the dense-times-one-hot product is fast follows the benchmarks below.)

1. Huge performance difference between sparse and dense representation on GPU (#189)

# master
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  5.095 μs (13 allocations: 50.86 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  24.948 μs (86 allocations: 3.11 KiB)

# PR #1447
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  5.312 μs (3 allocations: 50.36 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  8.466 μs (61 allocations: 1.69 KiB)

# this PR
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  4.708 μs (3 allocations: 50.56 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  8.576 μs (63 allocations: 1.73 KiB)
2. onecold is very slow (#556)

# master
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  365.712 μs (1131 allocations: 38.16 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:43
  1.330 s (781248 allocations: 31.59 MiB)

# PR #1447
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  524.767 μs (8 allocations: 4.00 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
  27.563 μs (169 allocations: 5.56 KiB)

# this PR
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  493.017 μs (8 allocations: 4.53 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
  26.702 μs (171 allocations: 5.61 KiB)
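As promised above, here is a hedged sketch of why the dense-times-one-hot product is fast (building on the sketch type from the description; the methods actually merged may differ). A one-hot matmul never materializes a dense Bool matrix; it just gathers columns.

# Each one-hot column of x selects exactly one column of W, so W * x reduces
# to indexing: cheap on the CPU, and a gather kernel (no GEMM) on the GPU.
# Dimension checks are omitted for brevity.
Base.:*(W::AbstractMatrix, x::OneHotArray{L, 1}) where L = W[:, x.indices]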

Summary

This should basically be #1447, but simpler to maintain with fewer changes. Tests are passing, though I think we should add more tests for one-hot data (our current test set seems pretty sparse). Performance matches #1447 where I have tested, but please suggest more performance tests. In theory, any performance difference between #1447 and this PR should be recoverable.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • Final review from @DhairyaLGandhi (for API changes).

cc @CarloLucibello @chengchingwen

@darsnack changed the title from "Darsnack/arbitrary one hot" to "Arbitrary dimension one-hot arrays" on Jan 1, 2021
@DhairyaLGandhi (Member) left a comment

I find generalising the type nicer in general, but I have comments around defining some of the methods, and the type seems to hold more information than necessary. That would induce dynamic dispatch, which would be nice to avoid.

Comment on src/onehot.jl:

Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:OneHotIndex}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
@DhairyaLGandhi (Member) commented

We might want to simply return the type of the underlying array, IIUC.

@darsnack (Member Author) commented

The underlying array is an integer array. This is just an internal convenience function I use when I want to convert the OneHotArray to a Bool array. It decides whether to convert to an Array{Bool} or a CuArray{Bool}, depending on the underlying storage location.
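For illustration, a hypothetical helper built on top of it (densify is not a function in this PR, just a sketch of the described use):

# Materialize a OneHotArray as a dense Bool array on the same device as its
# index storage: Array{Bool} for CPU indices, CuArray{Bool} for CuArray indices.
densify(x::OneHotArray) = _onehot_bool_type(x)(x)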

@DhairyaLGandhi (Member) commented

But why do we need to have this? The implementation for CuArray should fall straight out of assuming a regular array.

@darsnack (Member Author) commented Jan 5, 2021

The last remaining issue that needs to be addressed is that Base's reshape logic returns a lazy ReshapedArray for cases like reshape(x, 10, :). This means that if you collect the lazy iterator, you'll currently get an Array{Bool} when we want to return a OneHotArray (see the failing test cases).
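A sketch of the symptom described above (shapes chosen for illustration):

julia> x = Flux.onehotbatch(rand(1:10, 4), 1:10);  # 10×4 one-hot matrix

julia> y = reshape(x, 10, :);  # takes Base's lazy colon path, not our method

julia> collect(y);  # currently an Array{Bool}, but should stay a OneHotArray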

@CarloLucibello (Member) commented
What Base does with reshape is just crazy. I think you have to intercept the following line:
https://github.com/JuliaLang/julia/blob/788b2c77c10c2160f4794a4d4b6b81a95a90940c/base/reshapedarray.jl#L118

So maybe overload

Base._reshape(x::OneHotArray, dims::Dims)

?

@DhairyaLGandhi (Member) commented

Please don't touch internal Julia functions.

@darsnack (Member Author) commented Jan 6, 2021

Probably we just want to extend

reshape(x::OneHotArray, dims::Tuple{Vararg{Union{Int,Colon}}})

The only downside is that we'd still rely on the internal Base._reshape_uncolon. Ideally, Base wouldn't have this singular odd dispatch path. I opened JuliaLang/julia#39123, so I think we should just define

reshape(x::OneHotArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = reshape(x, Base._reshape_uncolon(x, dims))

Once JuliaLang/julia#39123 is addressed, we can remove this line (so temporary, minimal use of an internal function).

@darsnack (Member Author) commented Jan 6, 2021

Never mind, that leads to a method ambiguity that I think can only be resolved in Base. We can either overload _reshape like @CarloLucibello suggested, or we ship with these tests broken until this is addressed in Base.

@CarloLucibello (Member) commented

> Never mind, that leads to a method ambiguity that I think can only be resolved in Base. We can either overload _reshape like @CarloLucibello suggested, or we ship with these tests broken until this is addressed in Base.

I wouldn't hold off on adding very basic functionality just because the situation is a bit weird in Base. I say we add _reshape; if things change in future Julia versions, we will adapt accordingly. The important thing is to test cases like reshape(x, 10, :).

Most importantly though, we have to decide whether we go with this or #1447

@DhairyaLGandhi (Member) commented

Agreed on testing for the right cases, but not on overloading internal functions. We can catch the failing case as a separate method instead.

@darsnack (Member Author) commented Jan 7, 2021

> I wouldn't hold off on adding very basic functionality just because the situation is a bit weird in Base.

I agree.

> We can catch the failing case as a separate method instead.

That's what I tried. Due to the situation in Base, we are forced to catch the failing case with Base._reshape. There's no way to catch all the cases with Base.reshape alone.

> Most importantly though, we have to decide whether we go with this or #1447.

If this is for the v0.12 milestone, then I suggest we go with this. Based on this comment, I think there is only a slim set of models where the performance of this PR and #1447 can't be matched. We can always adopt #1447 in the future if we need to; going back from a new primitive type would be harder.

#1447 and this PR are composable changes. I think all that's required to move to #1447 in the future is to replace the indices field with the primitive OneHot type.

@CarloLucibello previously approved these changes Jan 8, 2021
@CarloLucibello (Member) commented

@DhairyaLGandhi I asked a few times already: could you lift the GitHub restrictions (I get "The base branch restricts merging to authorized users. Learn more about protected branches.") and clarify whether we have to use bors r+ now that we have buildkite?

@darsnack we have a failing onecold test on gpu.

To me the plan sounds good; the simplicity of this approach is very compelling. We can merge this and revisit later if performance issues are reported for those corner cases, unless @chengchingwen has strong objections.

@CarloLucibello (Member) commented

> @DhairyaLGandhi I asked a few times already: could you lift the GitHub restrictions (I get "The base branch restricts merging to authorized users. Learn more about protected branches.") and clarify whether we have to use bors r+ now that we have buildkite?

@DhairyaLGandhi bump

@CarloLucibello previously approved these changes Jan 8, 2021
@DhairyaLGandhi (Member) commented

Thanks, I've added a thought around making vcat an error, with a message pointing users to collect the OneHotArray if the result is no longer one-hot.

I'm inclined to go with this over #1447, since they address the same concerns but this does so in a neater fashion.

We'll continue with bors for the time being.

@darsnack (Member Author) commented Jan 8, 2021

Okay, the constructors should be backwards compatible, and vcat will throw an error.
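For illustration, a sketch of the vcat behaviour just described (the exact signature and error message in the PR may differ):

# Concatenating along the first (one-hot) dimension cannot produce a valid
# one-hot array, so fail loudly and point users at collect.
Base.vcat(xs::OneHotArray...) =
    throw(ArgumentError("vcat of OneHotArrays is not one-hot; `collect` the arrays first"))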

@DhairyaLGandhi (Member) commented

bors r+

@bors bot (Contributor) commented Jan 8, 2021

Build succeeded:

@bors bot merged commit ebd37d6 into FluxML:master on Jan 8, 2021
@darsnack deleted the darsnack/arbitrary-one-hot branch on Jan 8, 2021
Comment on lines +45 to +48
Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
(first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
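For example, under this definition (a sketch, assuming x is a 10×4 OneHotArray):

reshape(x, 10, 2, 2)  # returns a OneHotArray; the one-hot dimension is preserved
reshape(x, 2, 20)     # throws the ArgumentError, per the check above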
@simeonschaub (Member) commented

What was the reason for adding this? It seems overly restrictive to require that first(dims) == L; in fact, this broke some of my code.

@simeonschaub (Member) commented

The fallback worked fine for me before.

@darsnack (Member Author) commented

The initial implementation converted to a Bool array in the else case. @CarloLucibello, it seems like we should add that back in?

@DhairyaLGandhi (Member) commented

I think Simeon is referring to before we merged this? Did we convert to a Bool array then? I think it would be difficult to guarantee the return type of the function then. I agree the check seems pretty restrictive.

@darsnack (Member Author) commented

Not sure what fallback @simeonschaub is referring to, but the original implementation of this PR did not throw an error; it converted to a Bool array.

@simeonschaub (Member) commented

Sorry, I should have clarified: I meant falling back to the default definition in Base for AbstractArray, which produces a ReshapedArray. I think if we overload reshape here, we shouldn't make it error in cases where the fallback would work, since that makes it hard to use reshape in generic code.

@CarloLucibello (Member) commented

I thought that reshaping the first dimension was something never done in practice, but since we broke @simeonschaub's code, maybe it is not so rare. It can be handled by reshape(collect(oh), ...), but I would be fine relaxing the dims check if people feel the need, although, as @DhairyaLGandhi said, this would make reshape type-unstable.

@darsnack (Member Author) commented

I'll submit a quick fix PR.

@simeonschaub (Member) commented Jan 9, 2021

Just to explain, in my use case I was adding a singleton dimension in front of the rest for the purpose of broadcasting, i.e. something like:

reshape(Flux.onehotbatch(Flux.onecold(ŷ, classes), classes), 1, 4, :) .===  reshape(outputs_onehot, 4, 1, :);

@cossio (Contributor) commented Feb 25, 2021

It would be nice to add docs for this. #1519

@cossio (Contributor) commented Feb 25, 2021

I think this PR hasn't been released yet. Why?

@darsnack (Member Author) commented

Not 100% positive, since I'm still hazy on semver, but I think the struct/constructor changes mean this is breaking, so it will have to wait until v0.12, even though the highest-level APIs like onehotbatch are unchanged.

Successfully merging this pull request may close these issues:

Issues about OneHotVector/OneHotMatrix