Fix comments
chenmoneygithub committed Feb 23, 2023

1 parent 8aa9113 commit cd37dec
Showing 5 changed files with 14 additions and 75 deletions.
2 changes: 1 addition & 1 deletion keras_nlp/layers/cached_multi_head_attention.py
@@ -63,7 +63,7 @@ def call(
value,
key=None,
attention_mask=None,
- use_causal_mask=True,
+ use_causal_mask=False,
cache=None,
cache_index=None,
):
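This hunk flips the `use_causal_mask` default from True to False, so code that relied on the layer masking future positions implicitly now has to ask for it. A minimal sketch, not part of this commit, assuming CachedMultiHeadAttention keeps keras.layers.MultiHeadAttention's constructor arguments and (query, value) call ordering:

import tensorflow as tf
from keras_nlp.layers.cached_multi_head_attention import CachedMultiHeadAttention

layer = CachedMultiHeadAttention(num_heads=2, key_dim=4)
x = tf.random.uniform((2, 8, 8))  # (batch, sequence length, hidden size)
# With the new default, an uncached self-attention call must request causal
# masking explicitly; depending on the version, the layer may also return an
# updated cache alongside the attention output.
outputs = layer(x, x, use_causal_mask=True)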
6 changes: 3 additions & 3 deletions keras_nlp/layers/transformer_decoder.py
@@ -239,8 +239,8 @@ def call(
being processed. If `cache_index=None` while `cache` is set, it
means it's the first pass to build the cache.
Returns:
- x: A Tensor of the same shape as the `decoder_sequence`.
- cache: [Optional], a dense float Tensor. The updated `cache`.
+ Either a tuple of (outputs, cache) if a cache was passed, or a
+ single value outputs if a cache was not passed.
"""

has_encoder_sequence = encoder_sequence is not None
@@ -329,7 +329,7 @@ def call(

if cache is None:
return x
- return x, cache
+ return (x, cache)

def get_config(self):
config = super().get_config()
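The return contract is now conditional: a plain tensor when no cache is passed, an (outputs, cache) tuple when one is. A caller-side sketch, not from this commit; the per-layer cache shape is inferred from the GPT-2 changes later in this diff and the head size is an assumption:

import tensorflow as tf
from keras_nlp.layers import TransformerDecoder

decoder = TransformerDecoder(intermediate_dim=32, num_heads=2)
x = tf.random.uniform((2, 8, 16))  # (batch, sequence length, hidden size)

# No cache passed: the layer returns just the decoded sequence.
outputs = decoder(x)

# Cache passed with cache_index=None: per the docstring above, this is the
# first pass that fills the cache, and the layer returns (outputs, cache).
# Assumed per-layer layout: (batch, 2, max_length, num_heads, head_dim).
cache = tf.zeros((2, 2, 8, 2, 8))
outputs, cache = decoder(x, cache=cache)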
44 changes: 0 additions & 44 deletions keras_nlp/models/gpt2/gpt2_backbone_test.py
@@ -18,7 +18,6 @@
import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone

@@ -65,49 +64,6 @@ def test_variable_sequence_length_call_gpt2(self):
}
self.model(input_data)

def test_cache_call_is_correct(self):
initial_seq_len = 64
max_seq_len = 80
initial_inputs = {
"token_ids": tf.ones(
(self.batch_size, initial_seq_len), dtype="int32"
),
"padding_mask": tf.ones(
(self.batch_size, initial_seq_len), dtype="int32"
),
}
outputs, cache = self.model.build_initial_cache(
initial_inputs,
max_seq_len,
)

def call(i, cache, outputs):
def loop_body(i, cache, outputs):
# Compute the rest tokens.
output, cache = self.model.call_with_cache(
inputs=self.input_batch,
cache=cache,
cache_index=i,
)
outputs = dynamic_update_slice(outputs, output, [0, i, 0])
return i + 1, cache, outputs

i, cache, cached_outputs = tf.while_loop(
cond=lambda i, cache, outputs: i < max_seq_len,
body=loop_body,
loop_vars=[i, cache, outputs],
)
return cached_outputs

cached_outputs = call(initial_seq_len, cache, outputs)
graph_call = tf.function(call)
graph_cached_outputs = graph_call(initial_seq_len, cache, outputs)
normal_outputs = self.model(self.input_batch)
normal_outputs = normal_outputs[:, :max_seq_len, :]

self.assertAllClose(cached_outputs, normal_outputs)
self.assertAllClose(graph_cached_outputs, normal_outputs)

@parameterized.named_parameters(
("jit_compile_false", False), ("jit_compile_true", True)
)
21 changes: 10 additions & 11 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -217,8 +217,9 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index=None):
represents the index of current inputs in the whole sequence.
Returns:
- x: a dense float Tensor, the next token logits of `inputs`.
- cache: a dense float Tensor, the updated cache.
+ A (logits, cache) tuple. Where the first output is the language
+ model logits for the input token_ids and the second output is the
+ cache.
"""
transformer_layers = []
for i in range(self.backbone.num_layers):
@@ -241,7 +242,7 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index=None):
x = embeddings_add((token_embedding, position_embedding))
x = embeddings_dropout(x)
# Each `TransformerDecoder` layer has a cache, we update separately.
- caches = tf.unstack(cache, axis=0)
+ caches = tf.unstack(cache, axis=1)
for i, transformer_layer in enumerate(transformer_layers):
current_cache = caches[i]
x, current_cache = transformer_layer(
@@ -251,7 +252,7 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index=None):
cache_index=cache_index,
)
caches[i] = current_cache
- cache = tf.stack(caches, axis=0)
+ cache = tf.stack(caches, axis=1)
x = layer_norm(x)
x = tf.matmul(
x,
@@ -274,8 +275,9 @@ def build_initial_cache(self, initial_inputs, max_length):
max_length: int, the max length of the generated sequence.
Returns:
- cache: a dense float Tensor, the cache of key and value.
- max_length: int, the max length of generated sequence.
+ A (logits, cache) tuple. Where the first output is the language
+ model logits for the input token_ids and the second output is the
+ cache.
"""
token_ids = initial_inputs["token_ids"]
padding_mask = initial_inputs["padding_mask"]
@@ -290,8 +292,8 @@ def build_initial_cache(self, initial_inputs, max_length):
)
cache = tf.zeros(
[
- self.backbone.num_layers,
batch_size,
+ self.backbone.num_layers,
2,
max_length,
self.backbone.num_heads,
@@ -336,7 +338,6 @@ def generate(
self,
prompt,
max_length,
- use_cache=True,
sampler="top_k",
):
"""Generate text.
@@ -350,14 +351,12 @@ def generate(
prompt: a string, string Tensor or string RaggedTensor. The prompt
text for generation.
max_length: int. The max length of generated sequence.
- use_cache: bool, defaults to True. If True, cache will be used
- during decoding, which increases the decoding speed.
sampler: a string or `keras_nlp.samplers.Sampler` instance. The
sampler to be used for text generation.
"""
if self.preprocessor is None:
raise ValueError(
"Missing `self.preprocessor`, please make sure `preprocessor` "
"`self.preprocessor` is None, please make sure `preprocessor` "
"is set before calling `generate`."
)
end_token_id = self.preprocessor.tokenizer.end_token_id
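Two things change together in this file: the combined cache becomes batch-first, [batch_size, num_layers, 2, max_length, num_heads, head_dim] (the trailing head_dim is inferred; the hunk above is truncated), and the per-layer caches are therefore split and re-packed along axis=1 instead of axis=0. A standalone sketch of that bookkeeping, using made-up sizes:

import tensorflow as tf

batch_size, num_layers, max_length, num_heads, head_dim = 2, 3, 16, 4, 8
# Combined cache in the new, batch-first layout; the axis of size 2 holds
# the key and value tensors for each layer.
cache = tf.zeros((batch_size, num_layers, 2, max_length, num_heads, head_dim))

# Per-layer caches are now peeled off along axis=1 (previously axis=0)...
caches = tf.unstack(cache, axis=1)
for i in range(num_layers):
    current_cache = caches[i]  # (batch, 2, max_length, num_heads, head_dim)
    # ...each TransformerDecoder layer would read and update its slice here...
    caches[i] = current_cache
# ...and the slices are re-packed along the same axis afterwards.
cache = tf.stack(caches, axis=1)

Keeping the batch dimension first presumably lines the cache up with the batch axis of every other tensor the generation loop slices.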
16 changes: 0 additions & 16 deletions keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -142,22 +142,6 @@ def test_gpt2_causal_lm_generate(self, jit_compile, use_cache):
generated = generated.numpy().decode("utf-8")
self.assertTrue(prompt in generated)

def test_generate_same_with_cache(self):
self.causal_lm.compile(jit_compile=False)
cached_result = self.causal_lm.generate(
self.raw_batch,
max_length=10,
sampler="greedy",
use_cache=True,
)
non_cached_result = self.causal_lm.generate(
self.raw_batch,
max_length=10,
sampler="greedy",
use_cache=False,
)
self.assertAllEqual(cached_result, non_cached_result)

@parameterized.named_parameters(
("tf_format", "tf", "model"),
("keras_format", "keras_v3", "model.keras"),

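The deleted test compared cached and non-cached generation; with the use_cache flag removed, generate() always runs the cached path, so there is nothing left to compare. A hedged usage sketch of the API that remains (the preset name and from_preset loading are assumptions, not shown in this commit):

import keras_nlp

# Hypothetical setup: generate() requires the task's preprocessor to be set,
# which a preset-loaded model has by default.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# After this commit the only knobs are the sampler and the maximum length;
# caching is no longer optional.
output = causal_lm.generate("The quick brown fox", max_length=30, sampler="top_k")
print(output)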