Add cache support to decoding journey #745
Conversation
Force-pushed from 476867d to 38f0c5e.
Thanks for the PR! This is some of the most complex code we have added yet! Went deep into this and played around with the code for most of yesterday. Short version: I have a commit based on your branch that we may want to just fold into the PR -> 11181a1. It should be both a little more efficient, and it cuts down on code a lot. Some bigger notes:
Some more minor notes:
Open questions:
@mattdangerw Thanks Matt!
Re "Putting things into ...": I think I am already doing that; the only exception is the first pass to build the cache. I checked your commit, and it looks good to me to combine the first pass as well.
I believe the "cached" path is broken on TPU due to an XLA issue, so I guess we should keep the non-cached one until we solve the XLA issue. BTW, "cache" is not guaranteed to be faster, and is actually slower when the sequence is short.
They are actually required for some decoding algorithms, such as contrastive search. Basically these decoding algorithms need to know the previous tokens' representations.
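For context, a rough sketch of the contrastive search scoring rule, which is why those previous representations are needed. The function name and the `alpha` value are illustrative, not anything from this PR:

```python
import numpy as np

# Generic sketch of contrastive search scoring:
# score = (1 - alpha) * p(candidate) - alpha * max cosine similarity between the
# candidate's hidden state and every previously generated token's hidden state.
# The decoder therefore has to keep the previous tokens' representations around.
def contrastive_score(candidate_prob, candidate_state, previous_states, alpha=0.6):
    prev = previous_states / np.linalg.norm(previous_states, axis=-1, keepdims=True)
    cand = candidate_state / np.linalg.norm(candidate_state)
    degeneration_penalty = np.max(prev @ cand)
    return (1 - alpha) * candidate_prob - alpha * degeneration_penalty

# Toy usage: 5 previous token states, one candidate with probability 0.3.
rng = np.random.default_rng(0)
print(contrastive_score(0.3, rng.normal(size=8), rng.normal(size=(5, 8))))
```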
Fine with me!
I am not sure if we want to diverge from ...
I think we need to keep the cache at a static shape?
SGTM!
good point, SGTM!
Sure, but I don't want this to be blocking. HF proves it's doable in their way, but we may not want to mirror their coding logic.
Holy a lot T T....
How short practically? Is it enough that real world generative use cases will want the non-cached version in practice? If we can scope down (at least temporarily) here, that would make this much easier to land.
Static shape is actually the only thing an unstack/stack combo will handle. Take a look at the commit I linked; it's a pure readability and code reduction change. Logic untouched.
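To make the unstack/stack idea concrete, a minimal sketch; all names and shapes are made up for illustration, not the linked commit's actual layout:

```python
import tensorflow as tf

# The sampler carries one statically shaped cache tensor, and each transformer
# layer peels off its own slice inside the forward pass.
batch, max_length, num_layers, num_heads, head_dim = 2, 16, 3, 4, 8

per_layer_caches = [
    tf.zeros((batch, max_length, 2, num_heads, head_dim)) for _ in range(num_layers)
]

# Stack into a single tensor: (num_layers, batch, max_length, 2, heads, head_dim).
cache = tf.stack(per_layer_caches, axis=0)

# Unstack so each layer only sees its own key/value cache.
for layer_index, layer_cache in enumerate(tf.unstack(cache, axis=0)):
    print(layer_index, layer_cache.shape)
```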
The issue is our whole setup is for XLA right now (all of our padding to static shapes). If we don't support XLA, we actually have an overly complex and under-efficient version compared to what a "non XLA compilable" version could be. Take a look at the original GPT implementation for a non-XLA version that uses ... Let's keep digging into XLA failures!
This is an interesting point we should discuss. If the decoding algorithm needs it, it seems like that could still not leak into ...
Just some small comments as we do the more major surgery discussed in the comment thread.
Thanks! This is shaping up!
There are a few unaddressed comments from an earlier round of review; I added another pass now.
Still want to play around with our underlying graph a bit more, so may add a few more as I go.
```
@@ -235,10 +371,26 @@ def generate(
    # backward compat.
    sampler.jit_compile = self.jit_compile
    sampler.run_eagerly = self.run_eagerly
    initial_output, cache = None, None
    x, _, _ = self.preprocessor(prompt)
```
Another thing I just noticed: this assumes that the preprocessor has a wider sequence length than `max_length`. I'm not sure we should make that assumption, as the preprocessor is passed in by the user.
I think it's okay when the preprocessor has a shorter sequence length. The cache's length is `max(max_length, gpt2_seq_len)`; if the preprocessor has `seq_len < max_length < gpt2_seq_len`, then inside the sampler call the prompt will be padded to `max_length`. If `seq_len > gpt2_seq_len`, the forward pass is broken anyway.
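A rough illustration of the `seq_len < max_length < gpt2_seq_len` case described above; the lengths and `pad_token_id` here are made up, not the actual sampler code:

```python
import tensorflow as tf

# The preprocessed prompt gets padded up to max_length before decoding starts.
seq_len, max_length, pad_token_id = 10, 40, 0

prompt = tf.ones((1, seq_len), dtype=tf.int32)   # stand-in for preprocessor output
pad_width = max_length - prompt.shape[1]         # 30 positions of padding
padded_prompt = tf.pad(prompt, [[0, 0], [0, pad_width]], constant_values=pad_token_id)
print(padded_prompt.shape)  # (1, 40)
```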
What if, say...
- Initial prompt length is 20.
- preprocessor.sequence_length is 10.
- max_length is 40.
The prompt will get truncated, right? And the cache won't be correctly seeded, right? Just seems prudent to avoid coupling with the preprocessor sequence length entirely.
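Spelling out those numbers in plain Python (assuming the preprocessor truncates on the right, which is an assumption for this sketch):

```python
prompt_tokens = list(range(20))           # initial prompt length is 20
preprocessor_sequence_length = 10         # preprocessor.sequence_length
max_length = 40

truncated = prompt_tokens[:preprocessor_sequence_length]
# Only 10 of the 20 prompt tokens reach the model, so a cache seeded from this
# output can never contain the dropped tokens, no matter how large max_length is.
print(len(truncated), "of", len(prompt_tokens), "prompt tokens survive")
```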
@chenmoneygithub also a bug! If you pass a variable-length batch, e.g. ...
This mechanism is fine -- easy to read and maintain 👍
```
The difference between `call_with_cache` and normal `__call__` is in
this method, a `cache` arg is set, and the inputs is of
`sequence_length=1`. By cachine the previous key/value in multi-head
attention, we avoid recomputing the outputs of seen tokens.
```
The docstrings might need a round of copyedits in general
Suggestion: `call_with_cache` adds an additional forward pass for the model for autoregressive inference. Unlike calling the model directly, this method allows caching previous key/value Tensors in the multi-head attention layers, and avoids recomputing the outputs of seen tokens.
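For readers skimming this thread, a toy sketch of the key/value caching the docstring is describing (single head, NumPy, purely illustrative; not the `call_with_cache` API itself):

```python
import numpy as np

# Each decoding step writes the new token's key/value into the cache and attends
# over everything seen so far, instead of re-running attention on past tokens.
head_dim, max_length = 8, 16
rng = np.random.default_rng(0)

k_cache = np.zeros((max_length, head_dim))
v_cache = np.zeros((max_length, head_dim))

def cached_attention_step(q, k, v, index):
    """One sequence_length=1 attention step using the running cache."""
    k_cache[index] = k
    v_cache[index] = v
    seen_k = k_cache[: index + 1]   # keys for tokens 0..index, never recomputed
    seen_v = v_cache[: index + 1]
    scores = seen_k @ q / np.sqrt(head_dim)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ seen_v

for index in range(3):
    q, k, v = rng.normal(size=(3, head_dim))   # fake projections for one new token
    print(index, cached_attention_step(q, k, v, index).shape)
```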
Applied.
Thanks for the changes!
Couple more comments.
Force-pushed from 8323427 to cd37dec, then from cd37dec to 2991df4.
Thanks! Approving!
Mainly some copy edits. Though I do think we might want to figure out the one unresolved comment re `preprocessor.sequence_length` messing up in some edge cases.
@mattdangerw Thanks! Re the ... point: actually, as we talked offline, we might want to rethink our logic on calling ...
SGTM to revisit this after looking at simplifying the preprocessing generally for ...
Everything looks good to me! 👍
Thanks all for reviewing this looooong PR! Cheers.
There are a few things to note here: ... `num=num_beams` ... cache; I will do this as a followup (rough sketch of the idea below). This is oddly and insanely complex, so thanks for the review work.
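The surviving fragment suggests the followup is about repeating the cache per beam for beam search; purely as a hypothetical illustration of that idea, with shapes and the exact call being assumptions rather than this PR's code:

```python
import tensorflow as tf

# Tile the cache along the batch axis so each of num_beams beams gets its own copy.
num_beams = 4
cache = tf.zeros((2, 3, 16, 2, 4, 8))   # (batch, layers, length, 2, heads, head_dim), made up
beam_cache = tf.repeat(cache, repeats=num_beams, axis=0)
print(beam_cache.shape)   # (8, 3, 16, 2, 4, 8)
```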
Cheers.