-
Notifications
You must be signed in to change notification settings - Fork 634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Memory efficient attention - backward pass #281
Conversation
It's super slow!
But we still have a long way to go
Merge two loops together and use local buffers for accumulation and grad_q. The use of local buffers as is currently introduces limitations on the sizes of dimension K
Brings an extra 2x improvement
Makes it another 20% faster, and doesn't use extra memory
Use all threads to compute grad_q
This brings 50% speedup compared to the previous approach, despite redundant computation. The benefit comes from the fact that we are using better block sizes for the matmul computation of grad_q, which doesnt involve the transpose of the attention matrix
Brings an additional 12% speedup despite duplicate computation
Potentially due to avoiding bank conflicts?
Brings 10% improvement, being better than my previous best version
This is now significantly faster than what we had before, and is even faster than the vanilla implementation
This is now 18% faster than the vanilla implementation
Remove previous implementation
it's great, thanks @fmassa ! Open question: how would it make the more sense to integrate it with the rest, for people who build from the registers ? I can add a follow up PR to add that to the existing attention mechanisms, or it could just be a flag for "scaled dot product" (or something else ?) |
ref = ref_attention(query, key, value) | ||
ref.backward(torch.ones_like(query)) | ||
|
||
# there is some extra precision loss in the CPU implementation due to an |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the comment, helpful !
q = torch.rand(shape, device=device) | ||
sub_label = f"B={B}, M={M}, K={K}" | ||
|
||
if True: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
debug ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it's a debug flag that is sometimes helpful: sometimes I "break" the kernel by removing some parts of the computation and see what speedup I would get. But doing so means that the computation won't be correct anymore, so it was useful to just disable correctness checks.
I can remove this in if you want, but as I expect to still do some more performance tuning, I'd like to keep this around for a bit longer if it's ok with you?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's totally ok, flagging it just in case but understood, no worries
pprint.pprint(mem_use) | ||
|
||
|
||
benchmark_forward() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] maybe possible to factorize the two, but not super important, good tool to have already !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, it's totally possible. I've also added in a separate branch a benchmark_forward_and_backward
case, and it started to have quite a bit of duplication. I can look into refactoring this up in a follow-up PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not urgent and not blocking, same as above more of a mental note, sounds good
at::TensorAccessor<scalar_t, 3> query, | ||
at::TensorAccessor<scalar_t, 3> key, | ||
at::TensorAccessor<scalar_t, 3> value, | ||
at::TensorAccessor<scalar_t, 3> buffer //, | ||
at::TensorAccessor<scalar_t, 3> buffer, | ||
bool compute_logsumexp | ||
// at::TensorAccessor<int64_t, 2> mask |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahah, next step is to make it sparse ? :D (not that much of a joke..)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, integrating sparsity is in the plans :-) But before, I'll look into the K > 32
case
@@ -90,15 +92,18 @@ void attention_kernel( | |||
for (int64_t k = 0; k < K; k++) { | |||
oo[k] = buf[k] / s_prime; | |||
} | |||
if (compute_logsumexp) | |||
logsumexp[i][j] = m_prime + std::log(s_prime); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, I didn't think of that, this is why your backward is so fast, and it does not weight that much actually (one per line, not the whole attention map)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, keeping this temporary around made the backward kernel faster and also easier to implement, so I decided that it was worth the extra memory.
Also, in a follow-up PR I'll avoid allocating the logsumexp
buffer in the function if compute_logsumexp
is false
, so that this is only a memory price to pay during training
@@ -58,6 +59,25 @@ __device__ __forceinline__ void iDiv(scalar_t x1, float* out) { | |||
out[0] /= x1; | |||
} | |||
|
|||
template <typename scalar_t> | |||
__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float4 val) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah interesting, I thought that this was a little costly and not too much of a good idea in practice, looks like it's not a good intuition.. We're doing this (accumulate across threads) on the triton side for layernorm, but could be a good idea to extend this for linear layer / bias gradient, I'll have a look when I get the time
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was also a bit afraid of the atomicAdds in the beginning, but it turned out to not be too slow. Plus it made it easier to parallelize the kernel, so why not.
@@ -393,6 +416,7 @@ __global__ void attention_kernel( | |||
output_block[q_item_idx] = | |||
reinterpret_cast<vec_t*>(output[batch_idx][index].data()); | |||
m_prime[q_item_idx] = -std::numeric_limits<scalar_t>::infinity(); | |||
logsumexp_block[q_item_idx] = &logsumexp[batch_idx][index]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know it was already there in the FW PR, but I missed this, could you elaborate ? You're collecting all the pointers as a first step ? Necessary for the pragma unroll down the line ? Trying to understand this cuda trick :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main benefit of this preamble is to handle sequence lengths which are not multiple of 32 (or the block size which I'm using). Instead of handling the out-of-bonds directly in the hotpaths of the kernel, I handle it beforehand and repeat the last element if needed, so I don't index out of the bounds of the kernel.
There are probably other / better ways of handling generic sequence lengths, but this was the easiest to implement so I decided to go for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh I didn't see that at first sight, makes sense. I don't know of a better way myself (except that maybe that CUDA has similar concepts as all the graphics interfaces, where you can clamp memory/textures fetches automatically so that it does not out of bound and repeat or pad). It's been a long time since I wrote in Cuda so by now I just don't know..
@@ -473,56 +497,30 @@ __global__ void attention_kernel( | |||
output_block[q_item_idx][k] = tmp; | |||
} | |||
} | |||
|
|||
if (compute_logsumexp) { | |||
#pragma unroll |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
guess is that the pointer array is because of this ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it might also help as well, yes, although it was not directly the main reason why I did it
at::Tensor logsumexp = at::empty({B, M}, query.options()); | ||
|
||
// have to pass compute_logsumexp as a template parameter | ||
// otherwise there is a slowdown in the kernel... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
register spilling ? (like, it inlines compute logsumexp + a branch, and it takes too much space ?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be. Maybe I should try this out as a non-inlined function call to see if it make things better. Lots of improvements to be done in the future! :-)
vec_t tt = __ldg(vb[k_item_idx] + k); | ||
#pragma unroll | ||
for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { | ||
sputnik::VectorCompute<vec_t>::Dot( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I guess that computeDot
is the most interesting place for that, but I cannot comment there.) It could be interesting to pull in tensor cores for the .dot() for fp16, and for fp32 on newer hardware (A100 +), I don't know if there are existing primitives to preferably use there (from Cuda/cudnn/cublas, like in this example or from torch).
Maybe that it's not needed actually, sorry for bringing these tensor cores up all the time but just trying to think ahead for the possible next steps (and possibly hidden caveats if you developed on a P100)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TensorCores are going to be a very important next step indeed, and I'm already looking into how to use them. But before that I think it might be better to support the K > 32
case first, wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree on bigger K first, at this point the perf impact is not super clear to me (for fp32 at least, on the 3080 which has tensor cores this was very competitive vs. pytorch and tensor cores) while the limitation of a small K is very well defined
r = xformers.ops.memory_efficient_attention(q, q, q) | ||
|
||
rr = ref_attention(q, q, q) | ||
assert (r - rr).abs().max() < 1e-5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this does not pass on a 3080 / cuda 11.6, could be interesting to test with the T4s on CircleCi, could well be because of TF32 (you would need to switch the torch flag forcing fp32 computations). Implicitly this probably means that the torch implementation switched to tensor cores I think, which changes the time difference in between the two implementations (but not a fundamental issue)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I should probably change those defaults, or just disable TF32
in the benchmarks (but that makes it for slower baselines), or just disable this correctness check by default. Which one would you prefer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would switch tf32 off here, I think that it's the best correctness check : you assume fp32 in the kernel, let's check correctness against fp32 ? (torch.backends.cuda.matmul.allow_tf32 = False
)
Good to keep in mind in the benchmarks that the comparison is not iso-accuracy by the way, your implementation is actually more precise :)
|
||
rr = ref_attention(q, q, q) | ||
rr.backward(torch.ones_like(q)) | ||
assert (grad - q.grad).abs().max() < 1e-5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above, this does not pass on a 3080, guess is because of tf32 vs. float32 (would be the same with a A100, not sure about tf32 on a V100)
int TILE_SIZEQ, | ||
int TILE_SIZEK, | ||
bool check_bounds> | ||
__global__ void attention_backward_grad_v_kernel( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a bit harder for me to follow since it's not something that I dived into (vs. the fw pass) so I'm forced to skim a liltte.. looks clean as always
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good point, I should probably write some comments in the top of the functions with the implementation that this is actually doing, as a PyTorch code for ease of read. If you don't mind, I'd like to get this PR merged now, but I can add more comments with the K > 32
PR that I'm working on
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done > wild plans, sound good to me !
#pragma unroll | ||
for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) { | ||
fact[kBlockSizeQ * threadIdx.x + q_item_idx] | ||
[kBlockSizeK * threadIdx.y + k_item_idx] = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no faster call than that to zero the buffer ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean with memset
, or with default initializer (like = {};
) ?
It turns out that shared memory doesn't allow for default initializers, so I let all threads in the block to collaboratively zero the buffer. From some basic profiling, this part of the code doesn't seem to take any noticeable time, so I decided to leave it like that.
Also worth noting that I need to zero this only in the cases where the input sequence length is not a multiple of 32, otherwise there would be uninitialized values in there. Now that we have a template parameter in the kernel for this case, I could probably also put this under an if (check_bounds)
, but I wouldn't expect it to bring noticeable improvements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes I was thinking about memset or similar, it's super weird that cuda does not offer a primitive to that ? Ok for the timing, could be that the compiler optimizes this away actually..
scalar_t normalizer[kBlockSizeQ]; | ||
scalar_t tmp_sum[kBlockSizeQ] = {0}; | ||
|
||
vec_t *qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as the other one, collect all pointers first then proceed, ok. It's typical I guess but I'm not familiar enough probably
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it was to simplify a bit the handling of the code in other parts of the kernel. Maybe it saves on a couple of instructions, so it might be slightly faster to do this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks massively great to me, thanks @fmassa ! Couple of questions, mostly to understand better on my side and ask for possible follow ups down the line
if that helps, speed report on a 3080
|
Another follow up question is about the supported K, I think that 256 is typical in NLP / big models for instance, or 64 / 128 for the GPT series. Certainly not blocking, but it's a request bound to come up :) |
Thanks for the timely review @blefaudeux ! Yes, the Using TensorCores is also next in my optimization list, which I hope to get to after I'm back from holidays. |
BTW, I forgot to answer one of your comments:
My thinking was to enable this directly in |
sounds good to me, I would have done that also ! |
What does this PR do?
This PR implements the memory-efficient attention mechanism from https://arxiv.org/pdf/2112.05682v2.pdf, with both CPU and CUDA kernels, targetting the backward implementation. For now, only fp32 is supported.
The CPU implementation is naive and not meant to be fast, and is there only as a reference.
The CUDA implementation has competitive runtimes compared to a vanilla PyTorch implementation, while using 10x+ less memory.
Contrary to the forward implementation, the backwards supports inputs of arbitrary number of embedding sizes. I'll probably update the forward implementation to use a similar approach in the future.
In order to keep the backwards somewhat efficient, we need to return the
logsumexp
during forward as well. For some reason, I needed to template it in CUDA otherwise I would face performance slowdowns.In the same vein, in order to support arbitrary sequence lengths, I use a masking approach. But the masking (and in particular the
min(index, M)
operations) are very slow, so I templated this as well so that we don't need to run this if the sequence length is a multiple of 32.Full details for the benchmark I run can be found in here
Fixes #161.