[feat] Dropout(Activation(x+bias)), now with partial BW fusion #144
Conversation
Force-pushed from a18bda0 to 147e2c9
Force-pushed from c7ab5b0 to 8dea61d
[Benchmark plots: Dropout_Bias_False FW+BW torch.float16 Act: gelu / Dropout_Bias_True FW+BW torch.float16 Act: squared_relu / Dropout_Bias_True FW torch.float16 Act: squared_relu]
Interested @suchenzang? This took a while to get right. Some speed for small tensors should be recoverable, I didn't play with the settings too much. edit: old plots, see below for up to date numbers
Oh man, coming to xFormers to shop for parts is great. @blefaudeux these numbers are on V100s or A100s?
8dea61d
to
190bbaf
Compare
ahah, I missed you @suchenzang! This is on an ampere laptop, working with what I have around. 400GB/s is the max bandwidth, so basically there's not much to win on the inference side. The reported training number is not exact GB-wise (there are operations in the middle that are not counted), but it could well be that the scheduling brings back another 10-20%; these are very raw numbers. It should be at parity with pytorch accuracy-wise, no compromises here, just fusing the kernels. I'll also pull in @Chillee since he has an NVFuser solution for the same part that you could test out (we were wondering whether xformers could host that as well); he was actually the incentive for me to revisit this (he's got very good numbers!)
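As a rough way to sanity-check bandwidth numbers like these, the effective GB/s of a memory-bound op can be estimated as bytes read plus bytes written divided by the measured runtime. The sketch below is illustrative only (stand-in op, arbitrary sizes, assumes a CUDA device), not the benchmark harness behind the plots.

```python
import torch

def effective_bandwidth_gbps(fn, x, n_iters=100):
    # Warm up so one-time costs do not pollute the timing
    for _ in range(10):
        fn(x)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        fn(x)
    stop.record()
    torch.cuda.synchronize()
    ms_per_iter = start.elapsed_time(stop) / n_iters

    # One read of x and one write of the output, ignoring anything in between
    bytes_moved = 2 * x.numel() * x.element_size()
    return bytes_moved / (ms_per_iter * 1e-3) / 1e9

x = torch.randn(2048, 4096, device="cuda", dtype=torch.float16)
gbps = effective_bandwidth_gbps(lambda t: torch.nn.functional.dropout(t, 0.1), x)
print(f"~{gbps:.0f} GB/s, to compare against the card's peak (e.g. ~440 GB/s GDDR6)")
```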
[Benchmark plot: Dropout_Bias_True FW torch.float32 Act: None]
edit: old plot, see below for up to date numbers
cc @min-xu-ai, in case it helps to have a look at a Triton kernel for Mevo
I'll clean up the PR, sorry for all the extra changes
190bbaf
to
73b7663
Compare
I just pushed a cleaned-up version, should be better
```diff
@@ -25,42 +25,48 @@
 class _dropout(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx, x, p, bias, activation, activation_grad):
+    def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
```
The trainable bias (or not) was not properly handled before: the bias was always assumed to be trainable, which is mostly true but not always.
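For readers following along, here is a minimal PyTorch-level sketch of what the `trainable_bias` flag changes, i.e. only reducing the gradient over the rows when the bias actually needs it. This is illustrative only, not the actual Triton-backed `_dropout` implementation.

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class _dropout_sketch(torch.autograd.Function):
    # Illustrative: dropout(activation(x + bias)) with an optional bias gradient

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
        z = x if bias is None else x + bias
        a = activation(z) if activation is not None else z
        keep = (torch.rand(a.shape, device=a.device) > p).to(a.dtype)
        ctx.save_for_backward(keep, z)
        ctx.p = p
        ctx.activation_grad = activation_grad
        ctx.trainable_bias = trainable_bias and bias is not None
        return a * keep / (1.0 - p)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        keep, z = ctx.saved_tensors
        grad = grad_out * keep / (1.0 - ctx.p)      # dropout backward
        if ctx.activation_grad is not None:
            grad = grad * ctx.activation_grad(z)    # activation backward
        # Only reduce over the leading dims when the bias is actually trainable
        grad_bias = grad.flatten(0, -2).sum(dim=0) if ctx.trainable_bias else None
        # One gradient per forward input: x, p, bias, activation, activation_grad, trainable_bias
        return grad, None, grad_bias, None, None, None
```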
Force-pushed from 157789b to 7f3718d
using less seeds
tiling + vertical seeds
Computing the FW and BW per tile over M
workaround atomics, reintroduce locks, should not be too bad
yet another take, partial sum
better scheduling defaults, improves across the board
giving atomic add a go
back to the locks, completely fuse the BW
Force-pushed from 7f3718d to 16808ea
Force-pushed from 89b92ba to 7221f7e
…oblem for small buffers and BW
Force-pushed from 7221f7e to 93668d5
…e and correct masks
Probably spent too long on that; it was a side project and I lost a lot of time getting back into the proper context each time (+ a couple of small hiccups with Triton along the way). Now with good enough perf I think: there's one case which is not faster than pytorch (really small buffers + gelu + fp16), everything else is significantly faster and the FW speed almost doubled with this PR. I updated all the graphs; keep in mind when comparing that the previous ones were with a V100 (HBM memory, something like 900GB/s max bandwidth) and the new ones are with a 3080 laptop (GDDR6 memory, something like 440GB/s max bandwidth). I think this could be revisited with a newer Triton, or with NVFuser/functorch (if it's not too much work a PR would be great @Chillee, if the speed is there OOB I would definitely take it!). Up for review, and next is #153, probably much more impactful if it is doable.
testing a training run with microGPT, looks like something is wrong, the loss plateaus..
fixed, no impact on perf. Checking with loss curves right now that everything is fine, but it should be the case. I've improved the unit test to catch that; basically p=0.5 was correct but p=0.1 was not.
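The kind of check that catches this is a statistical one over several drop probabilities, not just p=0.5. A self-contained sketch, with a plain PyTorch reference standing in for the fused kernel under test:

```python
import pytest
import torch


def reference_dropout(x, p):
    # Plain PyTorch reference: mask, then rescale by 1/(1-p)
    keep = (torch.rand(x.shape, device=x.device) > p).to(x.dtype)
    return x * keep / (1.0 - p)


# In the real test the fused kernel would be exercised here; the reference
# stands in so that this sketch runs on its own.
@pytest.mark.parametrize("p", [0.1, 0.3, 0.5, 0.7])
def test_drop_rate_matches_p(p):
    x = torch.ones(4096, 4096)
    y = reference_dropout(x, p)

    # The observed drop rate should track p across the whole range,
    # which is exactly what a p=0.5-only check would miss.
    drop_rate = (y == 0).float().mean().item()
    assert abs(drop_rate - p) < 0.01

    # Kept values must be rescaled by 1/(1-p) so the expectation is preserved
    kept = y[y != 0]
    assert torch.allclose(kept, torch.full_like(kept, 1.0 / (1.0 - p)), atol=1e-3)
```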
Force-pushed from d631179 to ee9ea6f
this is against the previous main (blue, still FusedMLP) and against MLP (red, pure pytorch). The scale is log, which emphasizes the small differences; without it you cannot distinguish one curve from the other. The end test (sampled text) looked good. In both fused cases there's a small but measurable difference with pure pytorch though; I'm not sure why, except that maybe the AMP execution is not the same (in the FusedMLP case inputs are cast to fp16 and remain so over the layer, in the pytorch non-fused case maybe not everything is fp16)
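A small illustration of the AMP difference being speculated about (assumes a CUDA device): an op wrapped with `custom_fwd(cast_inputs=torch.float16)` sees fp16 end to end, while under plain autocast each individual op picks its own dtype, so parts of the non-fused path can stay in fp32.

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class WholeBlockFp16(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x):
        print("inside the custom op:", x.dtype)  # float16, for the whole body
        return torch.nn.functional.gelu(x + 1.0)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        return grad_out


x = torch.randn(16, 16, device="cuda")  # fp32 input
with torch.autocast("cuda", dtype=torch.float16):
    y_fused_style = WholeBlockFp16.apply(x)
    # Non-fused path: autocast decides per op, and neither add nor gelu is on
    # the fp16 cast list, so this stays in the input dtype (fp32 here).
    y_eager = torch.nn.functional.gelu(x + 1.0)

print("custom op output:", y_fused_style.dtype)  # torch.float16
print("eager output:", y_eager.dtype)            # torch.float32
```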
So the only mildly worrisome thing here is seeing some divergence as training progresses. It seems like that delta shows up a bit vs the blue. If the two converge to different points, then we have a problem :(
totally, it's very strange because there's no shortcut in the implementation, it's supposed to give the same results. I've tried AMP/fp16, it's not that, same results. Now it could just be about the seed, I'll check that next
hmm, testing seeds: it does change the result a bit, but still. With this PR we're using fewer seeds within the random number generation and generating more numbers out of them, but off the top of my head that's the only measurable difference with the previous take; the rest is just a different kernel architecture
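To make the "fewer seeds, more numbers per seed" point concrete, here is a toy illustration in plain PyTorch (not the kernel's actual counter-based RNG): one seed drives the dropout decisions for an entire column.

```python
import torch

M, N, p = 1024, 8, 0.1

# One seed per column instead of reseeding per tile / per element
col_seeds = torch.randint(0, 2**31 - 1, (N,))
keep = torch.empty(M, N, dtype=torch.bool)
for j, seed in enumerate(col_seeds.tolist()):
    g = torch.Generator().manual_seed(seed)
    # A single seed produces the M random draws for the whole column
    keep[:, j] = torch.rand(M, generator=g) > p

# The concern raised above: if the per-seed stream is of poor quality,
# columns can end up correlated, which hurts training even if the average
# keep rate looks right.
print("per-column keep rate:", keep.float().mean(dim=0))
```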
Discussing with Phil: it could be that the RNG is not good enough, checking that next
could well be the reason; the fused dropout on main/ is fine. I checked whether there was a pattern in the dropout, but nothing obvious on a heatmap. Dropping that for now. It could be that this (generating fewer seeds) was the main reason for a very fast FW, but it cannot come at the expense of training accuracy
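The heatmap check is easy to reproduce. A sketch (placeholder dropout, illustrative sizes) that averages the keep-mask over many trials and plots it: a good RNG gives a flat map around 1-p, while row or column stripes would point at correlated seeds.

```python
import matplotlib.pyplot as plt
import torch


def placeholder_dropout(x, p):
    # Stand-in for the kernel being inspected
    keep = (torch.rand(x.shape, device=x.device) > p).to(x.dtype)
    return x * keep / (1.0 - p)


p = 0.1
x = torch.ones(256, 256)
keep_freq = torch.zeros_like(x)

n_trials = 500
for _ in range(n_trials):
    keep_freq += (placeholder_dropout(x, p) != 0).float()
keep_freq /= n_trials

plt.imshow(keep_freq.numpy(), vmin=0.8, vmax=1.0, cmap="viridis")
plt.colorbar(label="keep frequency (expected ~ 1 - p)")
plt.title("Dropout keep-frequency heatmap")
plt.savefig("dropout_heatmap.png")
```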
Use configurations from real models
What does this PR do?
This was a long time in the making: fusing the BW part of the activation/bias/dropout kernel. Not quite perfect, but in some places the speed goes really bananas (like 3x or 4x the naive calls).
Fusing this implied flipping the whole problem upside down: the seeds have to be per column, and the kernels (FW and BW) also work that way. This allows us to fuse the bias gradient computation, since it's a sum over that direction.
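For reference, this is the unfused computation that the kernel replaces, written as separate PyTorch ops; it also shows why the bias gradient is a reduction over the rows, which is the property the per-column seed layout exploits:

```python
import torch

def unfused_reference(x, bias, activation, p, training=True):
    z = x + bias                                                    # 1) bias add
    a = activation(z)                                               # 2) activation (gelu, squared relu, ...)
    return torch.nn.functional.dropout(a, p=p, training=training)   # 3) dropout

M, N = 1024, 4096
x = torch.randn(M, N, requires_grad=True)
bias = torch.zeros(N, requires_grad=True)

y = unfused_reference(x, bias, torch.nn.functional.gelu, p=0.1)
y.sum().backward()

# The bias gradient is a column-wise sum of the upstream gradient chain,
# i.e. a reduction over M, which the fused BW kernel can accumulate per tile.
assert bias.grad.shape == (N,)
```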
TODO:
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.