-
-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Arbitrary dimension one-hot arrays #1448
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find generalising the type to be nicer in general, but I've comments around defining some of methods and the type seems to hold more information than necessary. That would induce dynamic dispatch which would be nice to avoid
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) | ||
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) | ||
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:OneHotIndex}) where N = Array{Bool, N} | ||
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want to simply return the type of the underlying array iiic
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The underlying array is an integer array. This is just an internal convenience function I use when I want to convert the OneHotArray
to an Bool
array. I use this to decide whether to convert to a Array{Bool}
or CuArray{Bool}
depending on the underlying storage location.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But why do we need to have this? The implementation for CuArray should fall straight out of assuming regular array
The last remaining issue that needs to be addressed is that Base's |
What Base does with reshape is just crazy. So maybe overload Base._reshape(x::OneHotArray, dims:::Dims) ? |
Please don't touch internal Julia functions |
Probably we just want to extend reshape(x::OneHotArray, dims::Tuple{Vararg{Union{Int,Colon}}}) Only downside is we'd still rely on the internal reshape(x::OneHotArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = reshape(x, Base._reshape_uncolon(x, dims)) Once JuliaLang/julia#39123 is addressed, we can remove this line (so temporary, minimal use of an internal function). |
Never mind, that leads to a method ambiguity that I think can only be resolved in Base. We can either overload |
I wouldn't hold on adding a very basic functionality just because the situation is a bit weird in Base. I say we add Most importantly though, we have to decide whether we go with this or #1447 |
Agreed on testing for the right cases but not on overloading internal functions. We can catch the failing case as a separate method instead |
I agree.
That's what I tried. Due to the situation in Base, we are forced to catch the failing case with
If this is for the v0.12 milestone, then I suggest we go for this. Based on this comment, I think there is a slim set of models where the performance of this PR and #1447 can't be matched. We can always adopt #1447 if we need to in the future. Going back from a new primitive type will be harder. #1447 and this PR are composable changes. I think all that's required to move to #1447 in the future is to replace the |
@DhairyaLGandhi I asked a few times already, could you lift github restrictions (I get "The base branch restricts merging to authorized users. Learn more about protected branches.") and clarify if we have to use bors r+ or not now that we have buildkite? @darsnack we have a failing onecold test on gpu. To me the plan sounds good, the simplicity of this approach is very compelling, we can merge this and revisit later if performance issues are lamented on those corner cases. This unless @chengchingwen has strong objections against |
@DhairyaLGandhi bump |
Thanks, I've added a thought around saving the vcat to be an error with a message pointing users to collect the onehotarray if it's no longer one hot I'm inclined to go with this over #1447 since they address the same concerns but in a neater fashion. We'll continue with bors for the time being. |
Okay the constructors should be backwards compatible and |
bors r+ |
Build succeeded: |
Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L = | ||
(first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) : | ||
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)")) | ||
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What was the reason for adding this? It seems overly restrictive to require that first(dims) == L
, in fact, this broke some of my code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fallback worked fine for me before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The initial implementation converted to a Bool
array for the else case. @CarloLucibello seems like we should add that back in?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think Simeon is referring to before we merged this? Did we convert to a bool array then? I think it would be difficult to guarantee the return type of the function then. I agree the check seems pretty restrictive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what fallback @simeonschaub is referring to, but the original implementation of this PR did not throw an error. It converted to a Bool
array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, should have clarified: I meant falling back to the default defintion in Base for AbstractArray
, which produces a ReshapedArray
. I think if we overload reshape
here, we shouldn't make it error in cases where the fallback would work, since that makes it hard to use reshape
in generic code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought that reshaping the first dimension was something never done in practice, but since we broke @simeonschaub's code maybe is not so rare. It can be handled by reshape(collect(oh), ...)
, but i would be fine to relax the dims check if people feel that need, although, as @DhairyaLGandhi said, this would make reshape type unstable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll submit a quick fix PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to explain, in my use case I was adding a singleton dimension in front of the rest for the purpose of broadcasting, i.e. something like:
reshape(Flux.onehotbatch(Flux.onecold(ŷ, classes), classes), 1, 4, :) .=== reshape(outputs_onehot, 4, 1, :);
It would be nice to add docs for this. #1519 |
I think this PR hasn't been released yet. Why? |
Not 100% positive cause I’m still hazy on semver but I think the struct/constructor changes means it is breaking. So it will have to wait until v0.12. Even though the highest level APIs like |
This supersedes #1447. It should address the same issues:
This PR introduces a new one-hot N-dimensional array type,
OneHotArray
. Like #1447, this approach avoids the pointer allocations associated withOneHotMatrix
being an array ofOneHotVector
s. It also lifts the "height" into the type parameter to avoid unnecessary allocation. Unlike #1447, this approach does not introduce a new primitive type. Instead, a "one-hot vector" is represented with a single subtype ofInteger
that is configurable by the user. By default, the exposed API will useUInt32
.Fundamentally, the primitive type is necessary because wrapping a
UInt32
as aOneHotVector
will suffer memory penalties when you create anArray{<:OneHotVector}
. But if we begin by designing for N-dimensions, thenOneHotVector
is just the specialized 1D case (similar to howVector{T} = Array{T, 1}
).Performance
I compared against the same tests mentioned in #1447. Please suggest more if you want to.
Summary
This should basically be #1447 but simpler to maintain w/ fewer changes. Tests are passing, though I think we should add more tests for one-hot data (currently our test set seems pretty sparse). Performance matches #1447 where I have tested, but please suggest more performance tests. In theory, any performance difference between #1447 and this PR should be recoverable.
PR Checklist
cc @CarloLucibello @chengchingwen