[feat] Add a fast implementation of Rabe and Staats' algorithm (memory-efficient attention) on GPU #161
Comments
Another reference impl can be found here -- same caveats as outlined above.
I've started something. It feels like some of the logic would need to change a bit to make sense at the kernel level, at least in Triton. In particular, it's hard to sequence things from outside a kernel, and reproducing the exact logic from the paper would mean large intermediate buffers (if the computation is tiled), which diminishes the interest a lot. The best approach seems to be a kernel that owns a whole row of the attention matrix, processing a couple of rows at a time to help with data-fetch reuse.
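To make that concrete, here is a rough, untested sketch (in plain PyTorch rather than Triton, with made-up names and block sizes) of the per-row-block logic such a kernel would own: a block of query rows streams over the key/value chunks while keeping a running max and denominator (online softmax), so the full attention matrix is never materialized.

```python
import torch

def row_block_attention(q_block, k, v, key_chunk=128):
    # q_block: (rows, d) -- the handful of query rows one kernel instance owns
    # k, v:    (n, d)    -- full key / value tensors, visited chunk by chunk
    rows, d = q_block.shape
    scale = d ** -0.5
    acc = torch.zeros_like(q_block)                           # running weighted sum of V
    running_max = q_block.new_full((rows, 1), float("-inf"))  # running row-wise max
    denom = q_block.new_zeros(rows, 1)                        # running softmax denominator
    for start in range(0, k.shape[0], key_chunk):
        k_c = k[start:start + key_chunk]
        v_c = v[start:start + key_chunk]
        scores = (q_block @ k_c.T) * scale                    # (rows, chunk)
        chunk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, chunk_max)
        # rescale the previous accumulators to the new max before adding this chunk
        correction = torch.exp(running_max - new_max)
        p = torch.exp(scores - new_max)
        acc = acc * correction + p @ v_c
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        running_max = new_max
    return acc / denom
```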
Thanks!
🚀 Feature
Implement https://arxiv.org/pdf/2112.05682v2.pdf using Triton
Motivation
There are existing implementations in PyTorch, but they're bound to be a little slow. It's actually not that much work to write this in Triton, so let's give it a shot. Given the forward speed (should be similar to normal attention, without the memory cost) and the expected backward speed (about 60% of vanilla attention), it feels like a compromise that many would use.
Pitch
The required kernel is actually not far from some of the kernels we already have, at least for the forward pass. The chunk strategy proposed by the paper is fairly classic in that field, nothing out of the ordinary (see for instance), so it's bound to be pretty fast if implemented correctly.
Alternatives
At least support a pure PyTorch variant in xformers?
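For reference, a minimal pure-PyTorch sketch of the chunked scheme from the paper could look like the following. Names and chunk sizes are illustrative, and the gradient checkpointing over the inner loop that the paper uses to keep the backward pass memory-efficient is omitted for brevity.

```python
import torch

def memory_efficient_attention(q, k, v, q_chunk=1024, k_chunk=1024):
    # q, k, v: (n, d); computes softmax(q @ k.T / sqrt(d)) @ v
    # without ever materializing the full (n, n) attention matrix.
    n, d = q.shape
    scale = d ** -0.5
    out = []
    for qs in range(0, n, q_chunk):
        q_c = q[qs:qs + q_chunk] * scale
        chunk_vals, chunk_lse = [], []
        for ks in range(0, n, k_chunk):
            s = q_c @ k[ks:ks + k_chunk].T                    # (q_chunk, k_chunk) scores
            lse = torch.logsumexp(s, dim=-1, keepdim=True)    # per-row log-normalizer
            chunk_vals.append(torch.exp(s - lse) @ v[ks:ks + k_chunk])
            chunk_lse.append(lse)
        lse = torch.cat(chunk_lse, dim=-1)                    # (q_chunk, n_key_chunks)
        weights = torch.softmax(lse, dim=-1)                  # weight of each key chunk
        vals = torch.stack(chunk_vals, dim=0)                 # (n_key_chunks, q_chunk, d)
        out.append(torch.einsum("qc,cqd->qd", weights, vals))
    return torch.cat(out, dim=0)
```

Each key chunk contributes a partial value sum plus its log-normalizer, and the per-chunk results are recombined with a softmax over those log-normalizers, so peak memory scales with the chunk sizes rather than with the full sequence length.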