DRAFT: Redesign underlying storage #39

Open — wants to merge 6 commits into main
Conversation

@nomadbl commented Jul 22, 2023

After working on #36 and #35 and getting bogged down in the type inference, I'm trying the other route. I think it might also help with:

  • Simplifying the underlying code (probably)
  • Supporting some other pull requests and future additions

The concept:
Instead of a datatype holding the indices directly, keep an Array of OneHotVectors (essentially indices, but with type metadata) plus metadata recording which axis is the one-hot one.
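To make the idea concrete, here is a rough sketch of such a layout. This is illustrative only — `AxisOneHot` and its fields are hypothetical names, not this PR's actual API:

```julia
# Hypothetical sketch: store plain integer indices plus which axis is
# the one-hot one, so any dimension can be the hot axis.
struct AxisOneHot{N,M} <: AbstractArray{Bool,N}
    indices::Array{Int,M}  # hot index for each slice along `axis`; M == N - 1
    axis::Int              # which dimension is one-hot
    nlabels::Int           # size of the one-hot dimension
end

function Base.size(A::AxisOneHot{N}) where {N}
    sz = size(A.indices)
    ntuple(d -> d < A.axis ? sz[d] : (d == A.axis ? A.nlabels : sz[d-1]), N)
end

function Base.getindex(A::AxisOneHot{N}, I::Vararg{Int,N}) where {N}
    rest = ntuple(d -> d < A.axis ? I[d] : I[d+1], N - 1)
    A.indices[rest...] == I[A.axis]
end

# A 3×4 one-hot matrix whose *first* axis is the hot one:
A = AxisOneHot{2,1}([2, 1, 3, 2], 1, 3)
size(A)   # (3, 4)
A[2, 1]   # true: column 1 is hot at row 2
```

Since the hot axis is a runtime field here, dispatching on it (as the PR's `onehotaxis` does via type information) would need the axis in the type parameters instead.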

This PR is a draft of the idea, but it is already functional for some things, and matmul was fast in the limited tests I ran.
Your input is appreciated :)

PR Checklist

  • Tests are added
  • Documentation, if applicable

@darsnack (Member) commented Jul 22, 2023

I would refer to FluxML/Flux.jl#1448 and the related issues and PRs. In particular, the storage change here is reverting back to Flux's original implementation. This PR is a bit more sophisticated about its indexing, so it may not have the same issues as the original implementation.

But this PR will need a bunch of performance tests when you're ready. The linked PR above has a good selection of the important ones; I would add hcat to the list too. In particular, all of this needs to work and be fast on the GPU.

(These are just preliminary comments; I will need more time to dig through the actual code changes.)

@mcabbott (Member) commented

Can you explain the motivation further? As in:

  • what is this used for?
  • why is any modification to OneHotArray better than wrapping it in existing Transpose / PermutedDimsArray wrappers?

Any such modification is going to add complexity, which is a cost borne not necessarily by users, but by anyone maintaining the package in the future.

@nomadbl (Author) commented Jul 23, 2023

My motivation is primarily the setting of arbitrary axes as "one hot".
I also tried going down the Transpose / PermutedDimsArray route (#36), but I haven't been able to get type inference/stability right.
I think an additional advantage would be enabling easy concatenation without losing type stability (for cases where the result remains one-hot).
At the moment this implementation seems simpler than the previous one, but I could be wrong.
I put this PR here to get feedback and suggestions while I'm taking a stab at it. If it turns out not to be fruitful, I'll close it.

@darsnack (Member) commented Jul 23, 2023

I have commented on #36 to follow up that work. Unfortunately, I think this PR is going to be a non-starter, since the primary motivation for the current storage is having a contiguous array of indices for GPU performance. Fundamentally, I don't think this PR will be able to meet that performance goal. I'm open to being shown that I'm wrong, though.

I think that if we are going to support an arbitrary axis in the package, then it is better to follow #36's approach or to add the axis to the struct as originally suggested in #35. A caveat being @mcabbott's point about the complexity added to the library. If the only downstream operation you care about is matrix multiplication, then it is easier to just use a permuted array in user code for that one function. Adding it to the package means it has to support not just matrix multiplication but every other present and future operation we define on OneHotArray.

@mcabbott (Member) commented

> I also tried going in the Transpose / PermutedDimsArray route (#36), but I haven't been able to get type inference/stability right.

Ah sorry, I misread, or forgot that this is what #36 does.

I think that in any approach to writing a new, more flexible array type, you will probably need to store which axis is one-hot in the type, both to make indexing fast and to use dispatch for other operations, as you do here (e.g. onehotaxis). And once you do this, you will have much the same type-stability issues as in constructing a PermutedDimsArray.
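For context on that type-stability point, here is a minimal sketch (not this PR's code): the permutation is part of `PermutedDimsArray`'s type parameters, so choosing the axis at runtime makes the result type non-inferrable.

```julia
# The permutation tuple is a type parameter of PermutedDimsArray, so a
# runtime-chosen axis makes the return type depend on a runtime value.
maybe_permute(x::AbstractMatrix, axis::Int) =
    PermutedDimsArray(x, axis == 1 ? (1, 2) : (2, 1))

rt = only(Base.return_types(maybe_permute, (Matrix{Float64}, Int)))
isconcretetype(rt)   # false: the permutation isn't known statically
```

Storing the one-hot axis as a type parameter of the new array type would hit the same wall: constructing it from a runtime axis value is equally non-inferrable.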
