Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Dropout partial bw fusion (second take) #164

Merged
merged 7 commits into from
Jan 3, 2022
Merged

Conversation

blefaudeux
Copy link
Contributor

@blefaudeux blefaudeux commented Dec 23, 2021

What does this PR do?

TL:DR:
Dropout(activation(x+bias)) layer, now with partially fused BW and much better FW perf (bandwidth bound mostly)
Previously broken in #144 since the end accuracy was not as good as pytorch's. Now fixed

The longer version:
Re-opening #144, I finally got the time to look into that (being blocked on the mem efficient attention by a triton bug :) ). So the issue was in the random number generation, which calls tl.rand4intx from Triton, starting from a random seed that Pytorch generates. This seems to break if the number of generated numbers is too big. This PR introduces a processing per tile, each of the tiles being covered by one seed initially + these random numbers. The solution was just to limit the size of the tiles / use more seeds, and adjust the other hyperparameters accordingly (number of warps), the perf is not really changed.

  • Training on minGPT
    (orange is the previous take, red is pure pytorch MLP, green is MLP with this layer, scale is log)
    Screenshot from 2021-12-22 20-51-50

  • one epoch of minGPT:

Friends of my soul and flatterers, your forfeit,
Go, give me leave, and with my comfort than
'Tis mine own greatness.

SICINIUS:
I know you not.

MENENIUS:
Consider that which not.
Being once the tackless rock about his soldier,
Is all the fiery ports of this case,
But only in your resolve I have sold to Ravenspurgh,
And not omity my father's death with me?
Wife, and there I will take such on him
That you shall steal me on the business
And then was cruel for it; for in his pleasure,
Beyond the whole of her faults, and full of farewells;
For welcome but thy father return'd could not,
Call their defence and them accuse their wifes,
Suppose two and to be glads.

First Citizen:
He says let Turkey, will not be altered; since led
you on the foundations?

MISTRESS OVERDONE:
Look, what I have done this holy word to fine
He would contemplate; and there my father
Hath the assistantied by as the chair
Where is he was fulfilled: and his way is rather
Hath some his bonnet state, and in his face
  • some perf curves (if bias is False then the perf gap grows even more)
    -- training
    Dropout_Bias_True_FW+BW_torch float16_Act:_gelu

Dropout_Bias_True_FW+BW_torch float16_Act:_squared_relu

-- forward pass only (still with dropout)
Dropout_Bias_True_FW_torch float16_Act:_gelu

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

using less seeds

tiling + vertical seeds

Computing the FW and BW per tile over M

better scheduling defaults, improves across the board

good enough perfs

catching the slow case and diverting to pytorch in that case
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 23, 2021
@blefaudeux blefaudeux changed the title Dropout bw fusion [feat] Dropout partial bw fusion (second take) Dec 23, 2021
@blefaudeux blefaudeux requested a review from dianaml0 December 23, 2021 05:06
@blefaudeux
Copy link
Contributor Author

cc @ptillet @suchenzang, should finally be ready and perfs are still good. Turns out randint4x from Triton works fine as long as you don´t request too many numbers

@codecov-commenter
Copy link

codecov-commenter commented Dec 23, 2021

Codecov Report

Merging #164 (3e74a7e) into main (5148844) will decrease coverage by 0.05%.
The diff coverage is 68.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #164      +/-   ##
==========================================
- Coverage   90.61%   90.56%   -0.06%     
==========================================
  Files          56       56              
  Lines        2824     2829       +5     
==========================================
+ Hits         2559     2562       +3     
- Misses        265      267       +2     
Flag Coverage Δ
Python 90.56% <68.00%> (-0.06%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
xformers/triton/dropout.py 76.00% <68.00%> (-1.15%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5148844...3e74a7e. Read the comment docs.

@blefaudeux
Copy link
Contributor Author

blefaudeux commented Dec 26, 2021

hmm testing on a desktop 3080, the perf for < 2048 buffers is not very good, and it's not just hyper params. The perf above that is a lot better than before, but would be nice to have the best of both worlds...

Edit: fixed with bigger tiles, the accuracy error before was not due to tile size it turns out, but not creating the exact same mask in between FW and BW (in this layer we don't save the mask but only the seeds, the mask is re-created on the fly for each pass. Choosing different scheduling params led to a different use of the seeds, meaning different masks and that's why the training was not as good in the previous PR -and in the layer currently checked in actually-. TLDR is that this is fixed, and the perf is even above the curves displayed in previous comments, while I checked with a full training that the accuracy is maintained

@blefaudeux blefaudeux marked this pull request as draft December 26, 2021 22:35
@blefaudeux blefaudeux marked this pull request as ready for review December 27, 2021 00:44
Copy link
Contributor

@dianaml0 dianaml0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! :)

xformers/triton/dropout.py Show resolved Hide resolved
xformers/triton/k_dropout.py Outdated Show resolved Hide resolved
@blefaudeux blefaudeux merged commit 9be9167 into main Jan 3, 2022
@blefaudeux blefaudeux deleted the dropout_bw_fusion branch January 3, 2022 19:47
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
* Orthoformer  attention
Author: Mandela Patrick et al.

* only select landmarks amid the queries
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants