Add cache support to decoding journey #745

Merged: 14 commits, Feb 24, 2023
3 changes: 3 additions & 0 deletions keras_nlp/layers/__init__.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.layers.cached_multi_head_attention import (
CachedMultiHeadAttention,
)
from keras_nlp.layers.f_net_encoder import FNetEncoder
from keras_nlp.layers.masked_lm_head import MaskedLMHead
from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
115 changes: 115 additions & 0 deletions keras_nlp/layers/cached_multi_head_attention.py
@@ -0,0 +1,115 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cached MHA layer based on `keras.layers.MultiHeadAttention`."""

import tensorflow as tf
from tensorflow import keras
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice


class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
"""MutliHeadAttention layer with cache support.
In autoregressive decoding, it's a common practice to cache the key/value in
multi-head attention of previously seen tokens in order to make the
computation faster. With cached K and V, we can compute the attention output
of the last token without recomputing the forward pass for previously seen
tokens. This caching method is only useful during decoding, and should not
be used during training.
Call arguments:
query: Query `Tensor` of shape `(B, T, dim)` if `cache=None`,
otherwise `(B, 1, dim)`.
value: Value `Tensor` of shape `(B, S, dim)` if `cache=None`,
otherwise `(B, 1, dim)`.
key: Optional key `Tensor` of shape `(B, S, dim)` if `cache=None`,
otherwise `(B, 1, dim)`. If not given, will use `value` for both
`key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `(B, T, S)` if `cache=None`,
otherwise `(B, 1, S)`. `attention_mask` prevents
attention to certain positions. The boolean mask specifies which
query elements can attend to which key elements, 1 indicates
attention and 0 indicates no attention. Broadcasting can happen for
the missing batch dimensions and the head dimension.
        cache: a dense float Tensor. The cache of key/value states of
            previously seen tokens, of shape
            `[B, 2, max_seq_len, num_heads, key_dim]`.
        cache_index: an int or int Tensor, the index of the current token
            being processed. If `cache_index=None` while `cache` is set, it
            means this is the first pass, which builds the cache.
Returns:
An (attention_output, cache) tuple. `attention_output` is the result of
the computation, of shape `(B, T, E)`, where `T` is for target sequence
shapes and `E` is the query input last dimension if `output_shape` is
`None`. Otherwise, the multi-head outputs are projected to the shape
specified by `output_shape`. `cache` is the updated cache.
"""

def call(
self,
query,
value,
key=None,
attention_mask=None,
use_causal_mask=False,
cache=None,
cache_index=None,
):
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value

attention_mask = self._compute_attention_mask(
query,
value,
key=key,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

query = self._query_dense(query)
key = self._key_dense(key)
value = self._value_dense(value)

if cache is not None:
if cache_index is None:
seq_len = tf.shape(query)[1]
k_update_indices = [0, 0, 0, 0, 0]
v_update_indices = [0, 1, 0, 0, 0]
cache_index = seq_len - 1
else:
k_update_indices = [0, 0, cache_index, 0, 0]
v_update_indices = [0, 1, cache_index, 0, 0]
cache = dynamic_update_slice(
cache, key[:, tf.newaxis, ...], k_update_indices
)
cache = dynamic_update_slice(
cache, value[:, tf.newaxis, ...], v_update_indices
)
key = cache[:, 0, : cache_index + 1, :, :]
value = cache[:, 1, : cache_index + 1, :, :]

query = tf.multiply(query, 1.0 / tf.math.sqrt(float(self._key_dim)))
attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = self._dropout_layer(attention_scores)

attention_output = tf.einsum(
self._combine_equation, attention_scores, value
)
attention_output = self._output_dense(attention_output)
return attention_output, cache
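
For illustration, a minimal usage sketch of the cache contract described in the docstring above (not part of this diff; the shapes and inputs are made up, only the layer itself comes from this PR):

```python
import tensorflow as tf

from keras_nlp.layers import CachedMultiHeadAttention

batch_size, max_seq_len, num_heads, key_dim = 2, 8, 2, 4
layer = CachedMultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

# Seed the cache with a 3-token prompt. Along its second axis the cache stores
# K at index 0 and V at index 1: [B, 2, max_seq_len, num_heads, key_dim].
prompt = tf.random.uniform([batch_size, 3, num_heads * key_dim])
cache = tf.zeros([batch_size, 2, max_seq_len, num_heads, key_dim])
output, cache = layer(
    query=prompt, value=prompt, use_causal_mask=True, cache=cache
)

# Decode one more token: pass only the new token and its position index.
next_token = tf.random.uniform([batch_size, 1, num_heads * key_dim])
output, cache = layer(
    query=next_token, value=next_token, cache=cache, cache_index=3
)
print(output.shape)  # (2, 1, 8)
```
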
79 changes: 79 additions & 0 deletions keras_nlp/layers/cached_multi_head_attention_test.py
@@ -0,0 +1,79 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for CachedMultiHeadAttention."""

import tensorflow as tf
from absl.testing import parameterized
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras_nlp.layers.cached_multi_head_attention import (
CachedMultiHeadAttention,
)


class CachedMultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase):
def test_valid_call(self):
layer = CachedMultiHeadAttention(num_heads=2, key_dim=4)
x = tf.random.uniform(shape=[2, 2, 8])
layer(query=x, value=x)

def test_cache_call_is_correct(self):
batch_size = 2
seq_len = 5
num_heads = 2
key_dim = 4

layer = CachedMultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
x = tf.random.uniform(shape=[batch_size, seq_len, num_heads * key_dim])
cache = tf.zeros([batch_size, 2, seq_len, num_heads, key_dim])
        # Build the initial cache.
        initial_seq_len = 2
        initial_inputs = x[:, :initial_seq_len, :]
outputs = tf.zeros_like(x)
output, cache = layer(
query=initial_inputs,
value=initial_inputs,
use_causal_mask=True,
cache=cache,
)
# Update the outputs in place.
outputs = dynamic_update_slice(outputs, output, [0, 0, 0])

def call(i, cache, outputs):
def loop_body(i, cache, outputs):
                # Compute the remaining tokens.
current_input = x[:, i : i + 1, :]
output, cache = layer(
query=current_input,
value=current_input,
use_causal_mask=True,
cache=cache,
cache_index=i,
)
outputs = dynamic_update_slice(outputs, output, [0, i, 0])
return i + 1, cache, outputs

i, cache, cached_outputs = tf.while_loop(
cond=lambda i, cache, outputs: i < seq_len,
body=loop_body,
loop_vars=[i, cache, outputs],
)
return cached_outputs

        cached_outputs = call(initial_seq_len, cache, outputs)
        graph_call = tf.function(call)
        graph_cached_outputs = graph_call(initial_seq_len, cache, outputs)
normal_outputs, _ = layer(query=x, value=x, use_causal_mask=True)
self.assertAllClose(cached_outputs, normal_outputs)
self.assertAllClose(graph_cached_outputs, normal_outputs)
24 changes: 20 additions & 4 deletions keras_nlp/layers/transformer_decoder.py
@@ -17,6 +17,9 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.cached_multi_head_attention import (
CachedMultiHeadAttention,
)
from keras_nlp.utils.keras_utils import clone_initializer

from keras_nlp.layers.transformer_layer_utils import ( # isort:skip
@@ -141,7 +144,7 @@ def _build(self, input_shape, has_cross_attention):
head_dim = int(hidden_dim // self.num_heads)

# Self attention layers.
self._self_attention_layer = keras.layers.MultiHeadAttention(
self._self_attention_layer = CachedMultiHeadAttention(
num_heads=self.num_heads,
key_dim=head_dim,
dropout=self.dropout,
@@ -208,6 +211,8 @@ def call(
decoder_attention_mask=None,
encoder_padding_mask=None,
encoder_attention_mask=None,
cache=None,
cache_index=None,
):
"""Forward pass of the TransformerDecoder.
@@ -227,8 +232,15 @@ def call(
encoder_attention_mask: a boolean Tensor. Customized encoder
                sequence mask, must be of shape
[batch_size, encoder_sequence_length, encoder_sequence_length].
            cache: a dense float Tensor. The cache of key/value states of
                previously seen tokens, of shape
                [B, 2, max_seq_len, num_heads, key_dim].
            cache_index: an int or int Tensor, the index of the current token
                being processed. If `cache_index=None` while `cache` is set, it
                means this is the first pass, which builds the cache.
        Returns:
            Either an (outputs, cache) tuple if a cache was passed, or a single
            output Tensor of the same shape as `decoder_sequence` if no cache
            was passed.
"""

has_encoder_sequence = encoder_sequence is not None
@@ -271,9 +283,11 @@ def call(
residual = x
if self.normalize_first:
x = self._self_attention_layernorm(x)
x = self._self_attention_layer(
x, cache = self._self_attention_layer(
query=x,
value=x,
cache=cache,
cache_index=cache_index,
attention_mask=self_attention_mask,
)
x = self._self_attention_dropout(x)
@@ -313,7 +327,9 @@ def call(
if not self.normalize_first:
x = self._feedforward_layernorm(x)

return x
if cache is None:
return x
return (x, cache)

def get_config(self):
config = super().get_config()
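
The new decoder test below exercises cached decoding inside a `tf.while_loop`; as a quicker orientation, here is a minimal eager sketch of the two call modes (illustrative only, not part of this diff):

```python
import tensorflow as tf

from keras_nlp.layers import TransformerDecoder

batch_size, max_seq_len, num_heads, hidden_dim = 2, 6, 2, 8
decoder = TransformerDecoder(intermediate_dim=4, num_heads=num_heads)

# Without a cache the layer behaves as before and returns a single tensor.
x = tf.random.uniform([batch_size, max_seq_len, hidden_dim])
y = decoder(x)

# With a cache, the call returns an (outputs, cache) tuple.
cache = tf.zeros(
    [batch_size, 2, max_seq_len, num_heads, hidden_dim // num_heads]
)
# Seed the cache from a 2-token prompt (cache_index=None on the first pass).
y, cache = decoder(decoder_sequence=x[:, :2, :], cache=cache)
# Then feed one new token at a time, together with its index.
y, cache = decoder(decoder_sequence=x[:, 2:3, :], cache=cache, cache_index=2)
```
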
52 changes: 52 additions & 0 deletions keras_nlp/layers/transformer_decoder_test.py
@@ -18,6 +18,7 @@
import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras_nlp.layers import transformer_decoder

@@ -265,6 +266,57 @@ def test_mask_propagation_without_cross_attention(self):
outputs = decoder(decoder_sequence)
self.assertAllEqual(outputs._keras_mask, mask)

def test_cached_decoding_is_correct(self):
batch_size = 2
seq_len = 5
num_heads = 2
hidden_dim = 8

layer = transformer_decoder.TransformerDecoder(
intermediate_dim=4,
num_heads=num_heads,
)
x = tf.random.uniform(shape=[batch_size, seq_len, hidden_dim])
cache = tf.zeros(
[batch_size, 2, seq_len, num_heads, hidden_dim // num_heads]
)
        # Build the initial cache.
initial_seq_len = 2
initial_inputs = x[:, :initial_seq_len, :]
outputs = tf.zeros_like(x)
output, cache = layer(
decoder_sequence=initial_inputs,
cache=cache,
)
# Update the outputs in place.
outputs = dynamic_update_slice(outputs, output, [0, 0, 0])

def call(i, cache, outputs):
def loop_body(i, cache, outputs):
                # Compute the remaining tokens.
current_input = x[:, i : i + 1, :]
output, cache = layer(
decoder_sequence=current_input,
cache=cache,
cache_index=i,
)
outputs = dynamic_update_slice(outputs, output, [0, i, 0])
return i + 1, cache, outputs

i, cache, cached_outputs = tf.while_loop(
cond=lambda i, cache, outputs: i < seq_len,
body=loop_body,
loop_vars=[i, cache, outputs],
)
return cached_outputs

cached_outputs = call(initial_seq_len, cache, outputs)
graph_call = tf.function(call)
graph_cached_outputs = graph_call(initial_seq_len, cache, outputs)
normal_outputs = layer(decoder_sequence=x)
self.assertAllClose(cached_outputs, normal_outputs)
self.assertAllClose(graph_cached_outputs, normal_outputs)

@parameterized.named_parameters(
("tf_format", "tf", "model"),
("keras_format", "keras_v3", "model.keras"),
1 change: 0 additions & 1 deletion keras_nlp/models/backbone.py
@@ -102,7 +102,6 @@ def from_preset(
cache_subdir=os.path.join("models", preset),
file_hash=metadata["weights_hash"],
)

model.load_weights(weights)
return model

4 changes: 3 additions & 1 deletion keras_nlp/models/gpt2/gpt2_backbone.py
@@ -125,7 +125,9 @@ def __init__(
)(token_embedding)

# Sum and apply dropout to embeddings.
x = keras.layers.Add()((token_embedding, position_embedding))
x = keras.layers.Add(name="embeddings_add")(
(token_embedding, position_embedding)
)
x = keras.layers.Dropout(
dropout,
name="embeddings_dropout",
145 changes: 142 additions & 3 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -17,6 +17,7 @@

import tensorflow as tf
from tensorflow import keras
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
@@ -200,12 +201,138 @@ def backbone_cls(cls):
def preprocessor_cls(cls):
return GPT2CausalLMPreprocessor

def _get_token_probability(self, prompt, mask):
def call_with_cache(self, token_ids, padding_mask, cache, cache_index=None):
"""Forward pass of `GPT2CausalLM` with cache.
        `call_with_cache` adds an extra forward pass of the model for
        autoregressive inference. Unlike calling the model directly, this
        method caches the key/value Tensors in each multi-head attention
        layer, and avoids recomputing the outputs of previously seen tokens.

        Args:
            token_ids: a dense int Tensor, input token ids.
            padding_mask: a dense bool Tensor, input padding mask.
            cache: a dense float Tensor, the cache of key and value states.
            cache_index: int, or int Tensor, defaults to None. If set, it
                represents the index of the current inputs in the whole
                sequence.

        Returns:
            A (logits, cache) tuple, where the first output is the language
            model logits for the input token_ids and the second output is the
            updated cache.
        """
transformer_layers = []
for i in range(self.backbone.num_layers):
transformer_layers.append(
self.backbone.get_layer(f"transformer_layer_{i}")
)
layer_norm = self.backbone.get_layer("layer_norm")
token_embeddings = self.backbone.get_layer("token_embedding")
position_embeddings = self.backbone.get_layer("position_embedding")
embeddings_add = self.backbone.get_layer("embeddings_add")
embeddings_dropout = self.backbone.get_layer("embeddings_dropout")

token_embedding = token_embeddings(token_ids)
if cache_index is None:
position_embedding = position_embeddings(token_embedding)
else:
position_embedding = position_embeddings.position_embeddings[
cache_index, :
]
x = embeddings_add((token_embedding, position_embedding))
x = embeddings_dropout(x)
        # Each `TransformerDecoder` layer has its own cache; we update each one separately.
caches = tf.unstack(cache, axis=1)
for i, transformer_layer in enumerate(transformer_layers):
current_cache = caches[i]
x, current_cache = transformer_layer(
x,
decoder_padding_mask=padding_mask,
cache=current_cache,
cache_index=cache_index,
)
caches[i] = current_cache
cache = tf.stack(caches, axis=1)
x = layer_norm(x)
x = tf.matmul(
x,
self.backbone.token_embedding.embeddings,
transpose_b=True,
)
return x, cache

def build_initial_cache(self, initial_inputs, max_length):
"""Build initial cache based on the prompt.
        This method should be called before the decoding loop to build the
        initial cache. The cache is of shape [batch_size, `self.num_layers`,
        2, max_length, `self.num_heads`, `self.hidden_dim // self.num_heads`].
        The dim of size 2 indicates whether an entry caches a key or a value
        in multi-head attention.
Args:
initial_inputs: a dense Tensor, the initial inputs to the decoding
loop.
max_length: int, the max length of the generated sequence.
Returns:
A (logits, cache) tuple. The first output is the language
model logits for the input token_ids and the second output is the
cache.
"""
token_ids = initial_inputs["token_ids"]
padding_mask = initial_inputs["padding_mask"]

if max_length < self.backbone.max_sequence_length:
token_ids = token_ids[:, :max_length]
padding_mask = padding_mask[:, :max_length]

batch_size = tf.shape(token_ids)[0]
outputs = tf.zeros(
[batch_size, max_length, self.backbone.vocabulary_size]
)
cache = tf.zeros(
[
batch_size,
self.backbone.num_layers,
2,
max_length,
self.backbone.num_heads,
self.backbone.hidden_dim // self.backbone.num_heads,
],
)

output, cache = self.call_with_cache(
token_ids,
padding_mask,
cache=cache,
)
outputs = dynamic_update_slice(outputs, output, [0, 0, 0])
return outputs, cache

def _get_token_probability(
self,
prompt,
mask,
cache=None,
cache_index=None,
):
model_inputs = {
"token_ids": prompt,
"padding_mask": mask,
}
return self(model_inputs)
if cache_index is None and cache is None:
return self(model_inputs)
if cache_index is not None:
batch_size = tf.shape(prompt)[0]
prompt = tf.slice(prompt, [0, cache_index], [batch_size, 1])
mask = mask[:, : cache_index + 1]
output, cache = self.call_with_cache(
prompt,
mask,
cache,
cache_index,
)
return output, cache

def generate(
self,
@@ -227,6 +354,11 @@ def generate(
sampler: a string or `keras_nlp.samplers.Sampler` instance. The
sampler to be used for text generation.
"""
if self.preprocessor is None:
raise ValueError(
"`self.preprocessor` is None, please make sure `preprocessor` "
"is set before calling `generate`."
)
end_token_id = self.preprocessor.tokenizer.end_token_id

sampler = keras_nlp.samplers.get(sampler)
@@ -235,10 +367,17 @@ def generate(
# backward compat.
sampler.jit_compile = self.jit_compile
sampler.run_eagerly = self.run_eagerly
x, _, _ = self.preprocessor(prompt)
Member:

Another thing I just noticed: this is assuming that the preprocessor has a wider sequence length than max_length. I'm not sure we should make that assumption, as the preprocessor is passed in by the user.

Contributor Author:

I think it's okay when the preprocessor has a shorter sequence length. The cache's length is max(max_length, gpt2_seq_len); if the preprocessor has seq_len < max_length < gpt2_seq_len, then inside the sampler call the prompt will be padded to max_length. If seq_len > gpt2_seq_len, the forward pass is broken anyway.

Member:

What if, say...

  • Initial prompt length is 20.
  • preprocessor.sequence_length is 10.
  • max_length is 40.

The prompt will get truncated, right? And then the cache won't be correctly seeded, right? It just seems prudent to avoid coupling with the preprocessor sequence length entirely.

if len(x["token_ids"].shape) == 1:
x["token_ids"] = x["token_ids"][tf.newaxis, :]
x["padding_mask"] = x["padding_mask"][tf.newaxis, :]
_, cache = self.build_initial_cache(x, max_length)
next_token_probability = self._get_token_probability
generated = sampler(
self.preprocessor.tokenizer(prompt),
self._get_token_probability,
next_token_probability,
max_length=max_length,
end_token_id=end_token_id,
cache=cache,
)
return self.preprocessor.tokenizer.detokenize(generated)
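
End to end, the cached path is transparent to users of `generate`: the prompt seeds the cache once via `build_initial_cache`, and each subsequent step runs `call_with_cache` for a single token. A hedged usage sketch (the preset name and sampler settings are illustrative, not taken from this diff):

```python
import keras_nlp

# Downloads pretrained weights; "gpt2_base_en" is assumed to be an available
# preset name.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# `BeamSampler` ignores the cache (see the warning added below), so pick a
# token-by-token sampler to benefit from cached decoding.
output = causal_lm.generate(
    "My trip to Yosemite was",
    max_length=40,
    sampler=keras_nlp.samplers.TopPSampler(p=0.9),
)
print(output)
```
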
8 changes: 5 additions & 3 deletions keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -122,10 +122,12 @@ def test_gpt2_causal_lm_fit_no_preprocessing(self, jit_compile):
self.causal_lm_no_preprocessing.fit(self.preprocessed_dataset)

@parameterized.named_parameters(
("jit_compile_false", False), ("jit_compile_true", True)
("non_jit_compile_cache", False, True),
("non_jit_compile_non_cache", False, False),
("jit_compile_non_cache", True, False),
)
def test_gpt2_causal_lm_generate(self, jit_compile):
self.causal_lm_no_preprocessing.compile(jit_compile=jit_compile)
def test_gpt2_causal_lm_generate(self, jit_compile, use_cache):
self.causal_lm.compile(jit_compile=jit_compile)
self.causal_lm.generate(
self.raw_batch,
max_length=10,
16 changes: 15 additions & 1 deletion keras_nlp/samplers/beam_sampler.py
@@ -14,6 +14,7 @@
"""Beam Sampler."""

import tensorflow as tf
from absl import logging
from tensorflow import keras

from keras_nlp.samplers.sampler import Sampler
@@ -85,13 +86,26 @@ def get_next_token(self, next_token_probs):
pass

def sample(
self, prompt, token_probability_fn, mask, num_steps, from_logits=True
self,
prompt,
token_probability_fn,
mask,
num_steps,
from_logits=True,
cache=None,
token_probs=None,
):
"""Sampling logic implementation.
Because beam search uses a different loop body, we have to override the
whole `sample` method instead of just the `get_next_token` method.
"""
if cache is not None:
            logging.warning(
                "`BeamSampler` does not support cached decoding yet; the "
                "cache will be ignored. To use cached decoding, please use a "
                "different sampler, e.g., `TopPSampler`."
            )
batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
max_length = tf.cast(max_length, num_steps.dtype)
length = max_length - num_steps
113 changes: 67 additions & 46 deletions keras_nlp/samplers/sampler.py
@@ -15,6 +15,7 @@

import tensorflow as tf
from tensorflow import keras
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras_nlp.utils.python_utils import format_docstring

@@ -39,6 +40,9 @@
from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
returns logits. If False, `token_probability_fn` returns probability
distributions.
    cache: a dense float Tensor, a cache of the intermediate key and value
        tensors computed by each decoder self-attention layer. These values
        will only be computed once for each new token in the generated
        sequence.
"""


@@ -219,6 +223,7 @@ def __call__(
mask=None,
end_token_id=None,
from_logits=True,
cache=None,
):
prompt, mask = self._validate_prompt_and_mask(prompt, mask)
input_is_1d = prompt.shape.rank == 1
@@ -244,9 +249,9 @@ def __call__(
token_probability_fn,
mask,
max_length - shortest_prompt_len,
from_logits,
cache=cache,
from_logits=from_logits,
)

# Mask out tokens after `end_token_id`.
if end_token_id is not None:
prompt = self._mask_tokens_after_end_token(
@@ -271,7 +276,13 @@ def get_next_token(self, next_token_probs):
raise NotImplementedError

def sample(
self, prompt, token_probability_fn, mask, num_steps, from_logits=True
self,
prompt,
token_probability_fn,
mask,
num_steps,
from_logits=True,
cache=None,
):
"""Sampling logic implementation.
@@ -286,73 +297,83 @@ def sample(
from_logits: bool, defaults to True. Indicate if the
`token_probability_fn` returns logits. If False,
`token_probability_fn` returns probability distributions.
            cache: a dense float Tensor, the cache used in decoding. The cache
                stores the key and value tensors of each
                `keras_nlp.layers.CachedMultiHeadAttention` layer to speed up
                decoding by avoiding duplicated computation.
Returns:
A dense int Tensor, representing the generated text in token id
space.
"""
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        num_steps = tf.cast(num_steps, tf.int32)
        max_length = tf.cast(max_length, tf.int32)
        # The index of the last non-padding token in prompt. Since all sequences
        # are aligned to the right side, the index is the same for all.
        current_index = max_length - num_steps

        def one_step(
            current_index,
            prompt,
            mask,
            cache=None,
        ):
            last_index = current_index - 1
            if cache is not None:
                probs, cache = token_probability_fn(
                    prompt,
                    mask,
                    cache=cache,
                    cache_index=last_index,
                )
                next_token_probs = tf.squeeze(probs, axis=1)
            else:
                probs = token_probability_fn(prompt, mask)
                next_token_probs = tf.gather(
                    probs,
                    tf.repeat(current_index - 1, batch_size),
                    axis=1,
                    batch_dims=1,
                )

            if from_logits:
                next_token_probs = keras.activations.softmax(
                    next_token_probs, axis=-1
                )
            next_token = self.get_next_token(next_token_probs)
            next_token = tf.cast(next_token, prompt.dtype)
            next_token = tf.where(
                mask[:, current_index],
                prompt[:, current_index],
                next_token,
            )

            # Append the next token to current sequence.
            next_token = next_token[:, tf.newaxis]
            next_mask = tf.fill([batch_size, 1], True)
            slice_start = [0, current_index]
            mask = dynamic_update_slice(mask, next_mask, slice_start)
            prompt = dynamic_update_slice(prompt, next_token, slice_start)
            current_index = tf.add(current_index, 1)
            if cache is None:
                return current_index, prompt, mask
            return [current_index, prompt, mask, cache]

        # Run a while loop till `max_length` of tokens has been generated.
        if cache is None:
            current_index, prompt, mask = tf.while_loop(
                cond=lambda current_index, prompt, mask: tf.less(
                    current_index, max_length
                ),
                body=one_step,
                loop_vars=[current_index, prompt, mask],
            )
            return prompt

        current_index, prompt, mask, cache = tf.while_loop(
            cond=lambda current_index, prompt, mask, cache: tf.less(
                current_index, max_length
            ),
            body=one_step,
            loop_vars=[current_index, prompt, mask, cache],
        )
        return prompt
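
To make the new contract concrete: with a cache, `token_probability_fn` is called with `cache` and `cache_index` and must return a `(probs, cache)` pair where `probs` covers only the current token; without a cache it returns per-position probabilities as before. A toy sketch of that contract (the real implementation is `GPT2CausalLM._get_token_probability` above; everything below is illustrative):

```python
import tensorflow as tf

vocab_size = 11


def token_probability_fn(prompt, mask, cache=None, cache_index=None):
    batch_size = tf.shape(prompt)[0]
    if cache is None:
        # Uncached path: logits for every position, shape [batch, seq, vocab].
        return tf.zeros([batch_size, tf.shape(prompt)[1], vocab_size])
    # Cached path: logits for the current token only, shape [batch, 1, vocab],
    # returned together with the (here unchanged) cache.
    return tf.zeros([batch_size, 1, vocab_size]), cache


prompt = tf.zeros([2, 8], dtype=tf.int32)
mask = tf.ones([2, 8], dtype=tf.bool)

probs = token_probability_fn(prompt, mask)
print(probs.shape)  # (2, 8, 11)

probs, cache = token_probability_fn(prompt, mask, cache=tf.zeros([1]), cache_index=3)
print(probs.shape)  # (2, 1, 11)
```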