
Add cache support to decoding journey #745

Merged: 14 commits merged into keras-team:master from cached-decoding on Feb 24, 2023

Conversation

chenmoneygithub
Contributor

@chenmoneygithub commented Feb 11, 2023

There are a few things to note here:

  • The XLA path is not supported, as it returns a different result from the non-XLA graph; I am looking into the cause.
  • Beam search does not support the cache for now, because it needs to maintain num_beams caches. I will do this as a follow-up.
  • We are using a lot of slicing and update calls now, and I am unsure how to optimize them. I have been looking into this for too long, so I am a bit trapped in the rabbit hole.
  • To play with the API and compare the performance, you can use this colab (a minimal usage sketch also follows this list).
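
For reference, a minimal usage sketch along the lines of what the colab exercises. The preset name "gpt2_base_en" is an assumption here; the generate() call mirrors the one used later in this thread.

```python
import keras_nlp

# Load a GPT-2 causal LM task model; "gpt2_base_en" is an assumed preset name.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Greedy decoding up to 30 tokens; with this PR, forward passes after the
# first reuse cached key/value tensors instead of re-running the full prompt.
print(gpt2_lm.generate("That's weird", sampler="greedy", max_length=30))
```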

This is unexpectedly and insanely complex, so thanks for the review work.

Cheers.

@chenmoneygithub marked this pull request as draft February 11, 2023 05:37
@chenmoneygithub force-pushed the cached-decoding branch 3 times, most recently from 476867d to 38f0c5e on February 14, 2023 00:18
@chenmoneygithub marked this pull request as ready for review February 14, 2023 01:26
@mattdangerw
Member

Thanks for the PR! This is some of the most complex code we have added yet! I went deep into this and played around with the code for most of yesterday.

Short version: I have a commit based on your branch that we may want to just fold into the PR -> 11181a1. It should be a little more efficient, and it cuts down on the code a lot.

Some bigger notes:

  • Let's move all of the new forward pass to GptCausalLM. There is really no use case for this cache machinery besides auto-regressive decoding, so we can limit the "blast radius" of the change significantly by leaving call_with_cache only on the task model and not the backbone.
  • Rather than slicing the input short after computing token and position embeddings, we should keep the input short coming into the model. Basically, "token_ids" during generation should have shape (bs, 1). There are a few reasons for this. First, it is more efficient: in the version you have here, you recompute the same token and position embeddings over and over. Second, it allows for much cleaner state passing in the sampling code. This requires some code changes for the position embedding and mask computation (see the sketch after this list).
  • Let's only support a "cached" decoding path right now. Supporting both cached and non-cached paths makes the code really hard to follow, and I don't think there is a good case for it.
  • Let's avoid passing unnecessary state between the sampler and masked lm classes, e.g. initial_output and existing_outputs.
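
To make the (bs, 1) idea concrete, here is a minimal, self-contained sketch of a single-token cached attention step, written with plain TF ops. The shapes, the single-head simplification, and the random tensors are illustrative assumptions, not the PR's actual implementation.

```python
import tensorflow as tf

batch, max_length, hidden_dim = 2, 16, 64
cache_index = 5  # number of prompt/generated tokens already processed

# Key/value caches with a static max_length, seeded by an initial forward pass.
key_cache = tf.random.normal((batch, max_length, hidden_dim))
value_cache = tf.random.normal((batch, max_length, hidden_dim))

# During generation the model only sees the newest token: shape (batch, 1),
# so token/position embeddings are never recomputed for earlier positions.
query = tf.random.normal((batch, 1, hidden_dim))
new_key = tf.random.normal((batch, 1, hidden_dim))
new_value = tf.random.normal((batch, 1, hidden_dim))

# Write the new key/value at position `cache_index`.
indices = tf.stack([tf.range(batch), tf.fill((batch,), cache_index)], axis=-1)
key_cache = tf.tensor_scatter_nd_update(key_cache, indices, tf.squeeze(new_key, 1))
value_cache = tf.tensor_scatter_nd_update(value_cache, indices, tf.squeeze(new_value, 1))

# Attend only over the positions filled so far; nothing from earlier steps
# is recomputed.
keys = key_cache[:, : cache_index + 1]
values = value_cache[:, : cache_index + 1]
scores = tf.einsum("bqd,bkd->bqk", query, keys) / tf.sqrt(float(hidden_dim))
attended = tf.einsum("bqk,bkd->bqd", tf.nn.softmax(scores, axis=-1), values)
```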

Some more minor notes:

  • Let's avoid dynamic return types where we don't have to. I think TransformerDecoder has to support returning with a cache or not (for compatibility). But the cached attention layer can just always return the cache (None or not).
  • Let's remove what options we can from CachedMultiHeadAttention and keep it simple; we don't need 1-1 parity with MultiHeadAttention (that class is complex).
  • We are using dynamic_update_slice to do simple stacking and unstacking of the cache tensor. Let's just use tf.stack and tf.unstack; it's much more readable (see the sketch after this list).
  • Similarly, in the sampler we use a scatter_nd op to do the same thing as dynamic_update_slice. Let's just use dynamic_update_slice for readability.
  • I thought the names cache and cache_index were nicely parallel as inputs, the reader will know they are linked!
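
A minimal sketch of the stack/unstack idea referenced above. The combined cache layout (batch, num_layers, 2, max_length, num_heads, head_dim) is an illustrative assumption, not necessarily the layout used in the PR.

```python
import tensorflow as tf

batch, num_layers, max_length, num_heads, head_dim = 2, 4, 16, 8, 32
cache = tf.zeros((batch, num_layers, 2, max_length, num_heads, head_dim))

# Unstack once into a Python list of per-layer caches...
layer_caches = tf.unstack(cache, axis=1)

# ...each decoder layer would read and update its own cache here (shown as a
# no-op), and then everything is stacked back into a single tensor so the
# sampler can carry one cache value through its loop.
updated = [layer_cache for layer_cache in layer_caches]
cache = tf.stack(updated, axis=1)
```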

Open questions:

  • Let's try to keep figuring out XLA compilation failures here. We know it is doable!
  • With the cache-only approach, beam search in its current form is pretty useless. How much work would it be to add support for beam search?

@chenmoneygithub
Contributor Author

chenmoneygithub commented Feb 14, 2023

@mattdangerw Thanks Matt!

Let's move all of the new forward pass to GptCausalLM.

Putting things into GptCausalLM sounds good to me.

Rather than slicing the input short after computing token and position embeddings...

I think I am already doing that; the only exception is the first pass that builds the cache. I checked your commit, and combining the first pass as well looks good to me.

Let's only support a "cached" decoding path right now.

I believe the "cached" path is broken on TPU due to the XLA issue, so I guess we should keep the non-cached one until we solve it. BTW, caching is not guaranteed to be faster; it is actually slower when max_length is short.

Let's avoid passing unnecessary state between the sampler and masked lm classes, e.g. initial_output and existing_outputs.

They are actually required by some decoding algorithms, such as contrastive search. Basically, these algorithms need to know the previous token's representation.

Let's avoid dynamic return types where we don't have to.

Fine with me!

Let's remove what options we can from CachedMultiHeadAttention and keep it simple..

I am not sure we want to diverge from MultiHeadAttention, because technically this layer is an add-on to MHA.

We are using dynamic_update_slice to do simple stacking and unstacking of the cache tensor.

I think we need to keep the cache at a static shape? dynamic_update_slice is actually not that bad IMO; it's much more straightforward than tf.tensor_scatter_nd_update.
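
A small sketch of the comparison being made here, writing one step's keys into a static-shape cache both ways. The import path shown for dynamic_update_slice is the one commonly used with TF/XLA and is an assumption; shapes are illustrative.

```python
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

batch, max_length, num_heads, head_dim = 2, 16, 8, 32
key_cache = tf.zeros((batch, max_length, num_heads, head_dim))
new_key = tf.random.normal((batch, 1, num_heads, head_dim))
cache_index = 5

# dynamic_update_slice: just give the start index along each axis.
updated = dynamic_update_slice(key_cache, new_key, [0, cache_index, 0, 0])

# tensor_scatter_nd_update: needs explicit (batch, position) indices and the
# update reshaped to match the indexed slices.
indices = tf.stack([tf.range(batch), tf.fill((batch,), cache_index)], axis=-1)
updated_scatter = tf.tensor_scatter_nd_update(
    key_cache, indices, tf.squeeze(new_key, axis=1)
)
```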

Similarly, in the sampler we use a scatter_nd op to do the same thing as dynamic_update_slice.

SGTM!

I thought the names cache and cache_index were nicely parallel as inputs, the reader will know they are linked!

Good point, SGTM!

Let's try to keep figuring out XLA compilation failures here. We know it is doable!

Sure, but I don't want this to be blocking. HF proves it's doable their way, but we may not want to mirror their coding logic.

With the cache-only approach, beam search in its current form is pretty useless. How much work would it be to add support for beam search?

Holy, a lot T_T...

@mattdangerw
Member

BTW, caching is not guaranteed to be faster; it is actually slower when max_length is short.

How short, practically? Is it short enough that real-world generative use cases will want the non-cached version in practice? If we can scope down (at least temporarily) here, that would make this much easier to land.

We are using dynamic_update_slice to do simple stacking and unstacking of the cache tensor.

I think we need to keep the cache at a static shape? dynamic_update_slice is actually not that bad IMO; it's much more straightforward than tf.tensor_scatter_nd_update.

A static shape is actually the only thing an unstack/stack combo will handle. Take a look at the commit I linked; it's purely a readability and code-reduction change. The logic is untouched.

Let's try to keep figuring out XLA compilation failures here. We know it is doable!

Sure, but I don't want this to be blocking. HF proves it's doable their way, but we may not want to mirror their coding logic.

The issue is that our whole setup is built for XLA right now (all of our padding to static shapes). If we don't support XLA, we actually have an overly complex and under-efficient version compared to what a non-XLA-compilable implementation could be. Take a look at the original GPT implementation for a non-XLA version that uses concat and dynamic shapes and is both simpler and faster than ours.
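
For contrast, the concat-style cache growth being referenced looks roughly like this (shapes are illustrative assumptions): the cache simply grows by one position per step, which is simple and fast eagerly but not friendly to XLA's static shapes.

```python
import tensorflow as tf

batch, hidden_dim = 2, 64
key_cache = tf.random.normal((batch, 5, hidden_dim))  # 5 tokens seen so far
new_key = tf.random.normal((batch, 1, hidden_dim))    # newest token's key

# Dynamic-shape version: just append along the sequence axis each step.
key_cache = tf.concat([key_cache, new_key], axis=1)   # (batch, 6, hidden_dim)
```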

Let's keep digging into XLA failures!

@mattdangerw
Member

Let's avoid passing unnecessary state between the sampler and masked lm classes, e.g. initial_output and existing_outputs.

They are actually required by some decoding algorithms, such as contrastive search. Basically, these algorithms need to know the previous token's representation.

This is an interesting point we should discuss. Even if the decoding algorithm needs it, it seems like that could still avoid leaking into the GPT2MaskedLM implementation. It would be very doable to write a sampler that uses both the previous token and the next token's logits, without that logic needing to spread across to our language model classes as well.
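
A hypothetical illustration of that separation: the sampler threads whatever extra state it needs (e.g. the previous step's representation) through its own loop as an opaque value, so the language-model class never has to know about it. The next_fn callback and its signature are assumptions for illustration.

```python
import tensorflow as tf

def greedy_sample(next_fn, first_token, num_steps):
    """next_fn(token, state) -> (logits, state); `state` is opaque to the LM."""
    tokens, state = [first_token], None
    for _ in range(num_steps):
        # Any algorithm-specific state (previous representations, etc.) lives
        # only inside the sampler loop.
        logits, state = next_fn(tokens[-1], state)
        tokens.append(tf.argmax(logits, axis=-1, output_type=tf.int32))
    return tf.stack(tokens, axis=-1)
```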

@mattdangerw (Member) left a comment

Just some small comments while we do the more major surgery discussed in the comment thread.

Review comments on: keras_nlp/layers/cached_multi_head_attention.py, keras_nlp/layers/cached_multi_head_attention_test.py, keras_nlp/layers/transformer_decoder.py, keras_nlp/models/gpt2/gpt2_backbone.py, keras_nlp/models/gpt2/gpt2_causal_lm.py, keras_nlp/samplers/beam_sampler.py, keras_nlp/samplers/sampler.py
@mattdangerw (Member) left a comment

Thanks! This is shaping up!

There are a few unaddressed comments from an earlier round of review; I've added another pass now.

I still want to play around with our underlying graph a bit more, so I may add a few more comments as I go.

Review comments on: keras_nlp/layers/transformer_decoder.py, keras_nlp/layers/cached_multi_head_attention.py, keras_nlp/models/gpt2/gpt2_causal_lm.py, keras_nlp/samplers/sampler.py
@@ -235,10 +371,26 @@ def generate(
# backward compat.
sampler.jit_compile = self.jit_compile
sampler.run_eagerly = self.run_eagerly
initial_output, cache = None, None
x, _, _ = self.preprocessor(prompt)
Member

Another thing I just noticed: this assumes the preprocessor has a wider sequence length than max_length. I'm not sure we should make that assumption, as the preprocessor is passed in by the user.

Contributor Author

I think it's okay when the preprocessor has a shorter sequence length. The cache's length is max(max_length, gpt2_seq_len); if the preprocessor has seq_len < max_length < gpt2_seq_len, then inside the sampler call the prompt will be padded to max_length. If seq_len > gpt2_seq_len, the forward pass is broken anyway.

Member

What if, say...

  • Initial prompt length is 20.
  • preprocessor.sequence_length is 10.
  • max_length is 40.

The prompt will get truncated, right? And then the cache won't be correctly seeded, right? It just seems prudent to avoid coupling with the preprocessor sequence length entirely.

@mattdangerw
Member

@chenmoneygithub also a bug!

If you pass a variable-length batch, e.g. print(gpt2_lm.generate(["that's weird", "that's even weirder"], sampler="greedy", max_length=30)), the longer sentence won't be respected. The word "weirder" will get deleted in the output.

@fchollet (Collaborator) left a comment

This mechanism is fine -- easy to read and maintain 👍

Review comments on: keras_nlp/layers/cached_multi_head_attention.py, keras_nlp/models/gpt2/gpt2_backbone.py, keras_nlp/models/gpt2/gpt2_causal_lm.py
The difference between `call_with_cache` and normal `__call__` is in
this method, a `cache` arg is set, and the inputs is of
`sequence_length=1`. By cachine the previous key/value in multi-head
attention, we avoid recomputing the outputs of seen tokens.
Collaborator

The docstrings might need a round of copyedits in general

Member

Suggestion:

call_with_cache adds an additional forward pass through the model for autoregressive inference. Unlike calling the model directly, this method allows caching previous key/value Tensors in the multi-head attention layers, and avoids recomputing the outputs of seen tokens.

Contributor Author

Applied.

@mattdangerw (Member) left a comment

Thanks for the changes!

Review comments on: keras_nlp/samplers/sampler.py, keras_nlp/layers/cached_multi_head_attention.py, keras_nlp/layers/transformer_decoder.py, keras_nlp/models/gpt2/gpt2_causal_lm.py
@mattdangerw (Member) left a comment

A couple more comments.

Review comments on: keras_nlp/models/gpt2/gpt2_causal_lm.py
@mattdangerw (Member) left a comment

Thanks! Approving!

Mainly some copy edits. Though I do think we might want to figure out the one unresolved comment re preprocessor.sequence_length misbehaving in some edge cases.

Review comments on: keras_nlp/layers/cached_multi_head_attention.py, keras_nlp/models/gpt2/gpt2_causal_lm.py, keras_nlp/layers/cached_multi_head_attention_test.py
@chenmoneygithub
Contributor Author

@mattdangerw Thanks!

Re the preprocessor's seq_len: to me it's the user's responsibility to choose the right preprocessor sequence length; otherwise they will also hit performance issues if they just want to use the backbone itself. Most of our users probably won't hit this issue. I am also fine with swapping the sequence_length to gpt2.sequence_length before calling the preprocessor.

Actually as we talked offline, we might want to rethink our logic on calling preprocessor and tokenizer inside the generate method.

@mattdangerw
Member

Actually as we talked offline, we might want to rethink our logic on calling preprocessor and tokenizer inside the generate method.

SGTM to revisit this after looking at simplifying the preprocessing for generate() more generally. Thanks!

@fchollet (Collaborator) left a comment

Everything looks good to me! 👍

@chenmoneygithub
Contributor Author

Thanks all for reviewing this looooong PR! Cheers.

@chenmoneygithub merged commit f988582 into keras-team:master Feb 24, 2023