-
Notifications
You must be signed in to change notification settings - Fork 635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT][Blocked] Mem efficient attention - FW pass #162
Conversation
…er case -whole line in kernel-
fc02cc8
to
18744d8
Compare
v_ptrs = V + rn[:, None] * L + rl_i[None, :] # (BLOCK_N, BLOCK_L) | ||
v = tl.load(v_ptrs, mask=((rn[:, None] < N) & (rl_i[None, :] < L)), other=0.0) | ||
|
||
qkv = tl.dot(exp_acc, v).to(tl.float32) # (BLOCK_M, BLOCK_L) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ptillet not having .to(tl.float32) means that this crashes for instance, it's not really obvious to me why
141b484
to
3f8b95f
Compare
3f8b95f
to
c91f376
Compare
a47ef5b
to
6fe0903
Compare
quick update: working around a bug in the Triton compiler, PoC is there and runs, not shippable as is. Lots of perf potential, the FW could actually be faster than a vanilla take while using a lot less memory. The BW will always be a little slower but the end result could well be worth it |
Yep there's definitely a bug in the compiler. The holiday break seems like the right time to rewrite how Triton handles data layout. I think there's a lot of potential for this kind of fused attention function to be very performant and memory-efficient. |
keeping the branch up but closing the PR, I cannot do much on this topic at the moment, dependent on upstream fixes on Triton |
Small Fix, order in which tensors passed to attention
BLOCK_N = min(triton.next_power_of_2(N), 1024) # increase the ceiling to save more memory | ||
BLOCK_L = 8 | ||
|
||
tiles_n = triton.cdiv(N, BLOCK_N) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @dianaml0
Hey! FYI mem-efficient attention is becoming a bigger priority for us. The bug is pretty deep inside of the compiler but I am seriously considering to dive in and take care of it. I have more time to address stability issues in Triton. |
I just updated this old branch, I'm getting an IndexError: map::at at compilation time with the latest dev package, let me know when there's something up to check out @ptillet ! |
What does this PR do?
First take for #161, only the forward pass with this PR (no training possible). The method is described here, the gist is that you compute the attention with you current best guess for the softmax renormalization, save your offset, and correct post-hoc
TODO:
cc @ptillet
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.