[feat] Dropout partial bw fusion (second take) #164
* using fewer seeds
* tiling + vertical seeds
* computing the FW and BW per tile over M
* better scheduling defaults, improves across the board
* good enough perfs
* catching the slow case and diverting to PyTorch in that case
cc @ptillet @suchenzang, this should finally be ready, and perfs are still good. Turns out randint4x from Triton works fine as long as you don't request too many numbers.
Codecov Report
@@ Coverage Diff @@
## main #164 +/- ##
==========================================
- Coverage 90.61% 90.56% -0.06%
==========================================
Files 56 56
Lines 2824 2829 +5
==========================================
+ Hits 2559 2562 +3
- Misses 265 267 +2
Hmm, testing on a desktop 3080, the perf for buffers < 2048 is not very good, and it's not just hyperparameters. The perf above that is a lot better than before, but it would be nice to have the best of both worlds...
Edit: fixed with bigger tiles. It turns out the accuracy error before was not due to tile size, but to not creating the exact same mask in the FW and BW passes. In this layer we don't save the mask, only the seeds, and the mask is re-created on the fly for each pass. Choosing different scheduling params led to a different use of the seeds, meaning different masks, and that's why the training was not as good in the previous PR (and in the layer currently checked in, actually). TL;DR: this is fixed, the perf is even above the curves displayed in previous comments, and I checked with a full training that the accuracy is maintained.
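To make the failure mode concrete, here is a minimal pure-Python sketch (not the Triton kernel). `dropout_mask`, the seed values, and the one-seed-per-tile layout are illustrative assumptions, and `random.Random` stands in for the Triton RNG. Rebuilding the mask with the same tiling reproduces it exactly, while a different tiling consumes the seeds differently and yields another mask.

```python
import random


def dropout_mask(seeds, n, tile, p):
    """Rebuild a keep-mask of length n, using one seed per tile of `tile` elements.

    Hypothetical helper: the mask is never stored, only re-derived from the seeds.
    """
    mask = []
    for t, start in enumerate(range(0, n, tile)):
        rng = random.Random(seeds[t])  # stand-in for the per-tile Triton RNG
        size = min(tile, n - start)
        # An element is kept when its random draw is at or above the drop probability
        mask.extend(rng.random() >= p for _ in range(size))
    return mask


seeds = list(range(100, 100 + 64))  # stand-in for the seeds PyTorch would generate
N, P = 1024, 0.1

fw = dropout_mask(seeds, N, tile=32, p=P)        # forward pass
bw_same = dropout_mask(seeds, N, tile=32, p=P)   # backward, same tiling
bw_other = dropout_mask(seeds, N, tile=16, p=P)  # backward, different tiling

print(fw == bw_same)   # same tiling: the mask is reproduced
print(fw == bw_other)  # different tiling: the seeds are used differently
```

The point is only that both passes have to agree on whatever parameters decide how the seeds are consumed.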
Looks great! :)
…Triton jit can be a little picky
* Orthoformer attention
  Author: Mandela Patrick et al.
* only select landmarks amid the queries
What does this PR do?
TL;DR:
Dropout(activation(x+bias)) layer, now with partially fused BW and much better FW perf (mostly bandwidth bound).
Previously broken in #144, since the end accuracy was not as good as PyTorch's. Now fixed.
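For reference, a plain-Python sketch of what Dropout(activation(x+bias)) computes in the forward and backward passes. This is an unfused reference and not the kernel from this PR: ReLU is picked as an example activation, the mask is passed in explicitly, and the function names are hypothetical. Which parts of the backward are actually fused is not shown here.

```python
def forward(x, bias, mask, p):
    """y = dropout(relu(x + bias)), with kept values rescaled by 1 / (1 - p)."""
    scale = 1.0 / (1.0 - p)
    return [
        [max(x[i][j] + bias[j], 0.0) * scale if mask[i][j] else 0.0
         for j in range(len(bias))]
        for i in range(len(x))
    ]


def backward(grad_out, x, bias, mask, p):
    """Gradients with respect to x and bias, reusing the forward mask."""
    scale = 1.0 / (1.0 - p)
    grad_in = [
        [grad_out[i][j] * scale if mask[i][j] and x[i][j] + bias[j] > 0.0 else 0.0
         for j in range(len(bias))]
        for i in range(len(x))
    ]
    # The bias is broadcast over rows, so its gradient is a column-wise sum
    grad_bias = [sum(row[j] for row in grad_in) for j in range(len(bias))]
    return grad_in, grad_bias


x = [[1.0, -2.0], [0.5, 3.0]]
bias = [0.5, 1.0]
mask = [[True, True], [False, True]]
ones = [[1.0, 1.0], [1.0, 1.0]]

y = forward(x, bias, mask, p=0.5)
grad_in, grad_bias = backward(ones, x, bias, mask, p=0.5)
print(y)          # [[3.0, 0.0], [0.0, 8.0]]
print(grad_in)    # [[2.0, 0.0], [0.0, 2.0]]
print(grad_bias)  # [2.0, 2.0]
```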
The longer version:
Re-opening #144, I finally got the time to look into that (being blocked on the mem efficient attention by a Triton bug :) ). The issue was in the random number generation, which calls tl.randint4x from Triton, starting from a random seed that PyTorch generates. This seems to break if too many numbers are requested. This PR introduces per-tile processing, each tile being covered by one initial seed plus these random numbers. The solution was just to limit the size of the tiles / use more seeds, and adjust the other hyperparameters (number of warps) accordingly; the perf is not really changed.
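As a rough illustration of the tile size vs. number of seeds trade-off, assuming one seed per 2D tile (the helper names, shapes, and tile sizes below are hypothetical, and the actual seed layout in the kernel may differ):

```python
import math


def seeds_needed(m, n, tile_m, tile_n):
    """Number of tiles, hence seeds, needed to cover an (m, n) buffer."""
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)


def numbers_per_seed(tile_m, tile_n):
    """Random numbers that a single seed has to provide for its tile."""
    return tile_m * tile_n


M, N = 4096, 2048  # hypothetical buffer shape
for tile_m, tile_n in [(64, 64), (32, 32)]:
    print(
        f"tile {tile_m}x{tile_n}: "
        f"{seeds_needed(M, N, tile_m, tile_n)} seeds, "
        f"{numbers_per_seed(tile_m, tile_n)} numbers per seed"
    )
```

Smaller tiles mean fewer numbers requested per seed, at the cost of more seeds to generate.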
Training on minGPT
(orange is the previous take, red is pure pytorch MLP, green is MLP with this layer, scale is log)
one epoch of minGPT:
-- training
-- forward pass only (still with dropout)