Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT][Blocked] Mem efficient attention - FW pass #162

Closed
wants to merge 6 commits into from

Conversation

blefaudeux
Copy link
Contributor

@blefaudeux blefaudeux commented Dec 22, 2021

What does this PR do?

First take for #161, only the forward pass with this PR (no training possible). The method is described here, the gist is that you compute the attention with you current best guess for the softmax renormalization, save your offset, and correct post-hoc

TODO:

  • check parity all along (buggy for now)
  • possibly change the scheduling to improve L2 rate (may not be a limiting factor really)
  • handle bias
  • change the way we handle batch dimension ?
  • normalize in the kernel if there's only ever one tile over N
  • tiling case is broken by a factor of kN
    cc @ptillet

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 22, 2021
v_ptrs = V + rn[:, None] * L + rl_i[None, :] # (BLOCK_N, BLOCK_L)
v = tl.load(v_ptrs, mask=((rn[:, None] < N) & (rl_i[None, :] < L)), other=0.0)

qkv = tl.dot(exp_acc, v).to(tl.float32) # (BLOCK_M, BLOCK_L)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ptillet not having .to(tl.float32) means that this crashes for instance, it's not really obvious to me why

@blefaudeux
Copy link
Contributor Author

quick update: working around a bug in the Triton compiler, PoC is there and runs, not shippable as is. Lots of perf potential, the FW could actually be faster than a vanilla take while using a lot less memory. The BW will always be a little slower but the end result could well be worth it

@ptillet
Copy link

ptillet commented Dec 22, 2021

Yep there's definitely a bug in the compiler. The holiday break seems like the right time to rewrite how Triton handles data layout. I think there's a lot of potential for this kind of fused attention function to be very performant and memory-efficient.

@blefaudeux blefaudeux changed the title [DRAFT] Mem efficient attention - FW pass [DRAFT][Blocked] Mem efficient attention - FW pass Jan 10, 2022
@blefaudeux
Copy link
Contributor Author

keeping the branch up but closing the PR, I cannot do much on this topic at the moment, dependent on upstream fixes on Triton

@blefaudeux blefaudeux closed this Jan 10, 2022
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
Small Fix, order in which tensors passed to attention
BLOCK_N = min(triton.next_power_of_2(N), 1024) # increase the ceiling to save more memory
BLOCK_L = 8

tiles_n = triton.cdiv(N, BLOCK_N)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ptillet
Copy link

ptillet commented Mar 14, 2022

Hey! FYI mem-efficient attention is becoming a bigger priority for us. The bug is pretty deep inside of the compiler but I am seriously considering to dive in and take care of it. I have more time to address stability issues in Triton.

@blefaudeux
Copy link
Contributor Author

Hey! FYI mem-efficient attention is becoming a bigger priority for us. The bug is pretty deep inside of the compiler but I am seriously considering to dive in and take care of it. I have more time to address stability issues in Triton.

I just updated this old branch, I'm getting an IndexError: map::at at compilation time with the latest dev package, let me know when there's something up to check out @ptillet !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[feat] Add a fast implementation of Rabe and Staats algorigthm (mem efficient attention) on GPU
3 participants