Address comments: move token_probability_fn to the second place
chenmoneygithub committed Jan 9, 2023
1 parent 8323153 commit caa6754
Showing 4 changed files with 23 additions and 26 deletions.
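In short: `prompt` moves to the first position and `token_probability_fn` to the second in both `__call__` and `sample`. A minimal before/after sketch of a call site (the uniform `token_probability_fn` below is a stand-in for a real model head, not code from this commit):

```python
import tensorflow as tf
import keras_nlp

def token_probability_fn(inputs, mask):
    # Stand-in model head: a uniform next-token distribution of shape
    # [batch_size, sequence_length, vocab_size].
    batch_size, seq_len = tf.shape(inputs)[0], tf.shape(inputs)[1]
    return tf.ones([batch_size, seq_len, 10]) / 10.0

prompt = tf.constant([1])
sampler = keras_nlp.samplers.Greedy()

# Before this commit: sampler(token_probability_fn, prompt, max_length=5)
# After this commit:
outputs = sampler(prompt, token_probability_fn, max_length=5)
```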
2 changes: 1 addition & 1 deletion keras_nlp/samplers/__init__.py
@@ -46,7 +46,7 @@ def get(identifier):
dict containing `class_name` and `config` as an identifier. Also note that
the `class_name` must map to a `Sampler` class.
- >>> cfg = {'class_name': 'Greedy', 'config': {}}
+ >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}}
>>> sampler = keras_nlp.samplers.get(cfg)
In the case that the `identifier` is a class, this method will return a new
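Standalone, the corrected identifier usage from the docstring above reads:

```python
import keras_nlp

# The registered class name now carries the `keras_nlp>` package prefix.
cfg = {"class_name": "keras_nlp>Greedy", "config": {}}
sampler = keras_nlp.samplers.get(cfg)
```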
4 changes: 2 additions & 2 deletions keras_nlp/samplers/greedy.py
@@ -67,7 +67,7 @@ def token_probability_fn(inputs, mask):
sampler = keras_nlp.samplers.Greedy()
# Print the generated sequence (token ids).
- print(sampler(token_probability_fn, prompt, 10))
+ print(sampler(prompt, token_probability_fn, 10))
```
"""

@@ -79,7 +79,7 @@ def __init__(

@format_docstring(sample_args=sample_args_docstring)
def sample(
- self, token_probability_fn, prompt, mask, num_steps, from_logits=True
+ self, prompt, token_probability_fn, mask, num_steps, from_logits=True
):
"""Sampling logic implementation.
16 changes: 8 additions & 8 deletions keras_nlp/samplers/greedy_test.py
@@ -48,17 +48,17 @@ def token_probability_fn(inputs, mask):

def test_generate_with_1d_prompt(self):
inputs = tf.constant([1])
- outputs = self.sampler(self.token_probability_fn, inputs, max_length=5)
+ outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
self.assertEqual(outputs.shape, [5])

def test_generate_with_2d_prompt(self):
inputs = tf.constant([[1], [1]])
- outputs = self.sampler(self.token_probability_fn, inputs, max_length=5)
+ outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
self.assertEqual(outputs.shape, [2, 5])

def test_generate_with_list_prompt(self):
inputs = [[1], [1]]
- outputs = self.sampler(self.token_probability_fn, inputs, max_length=5)
+ outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
self.assertEqual(outputs.shape, [2, 5])

def test_generate_with_ragged_prompt(self):
@@ -71,7 +71,7 @@ def token_probability_fn(inputs, mask):
return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1)

inputs = tf.ragged.constant([[1], [2, 1, 2]])
- outputs = self.sampler(token_probability_fn, inputs, max_length)
+ outputs = self.sampler(inputs, token_probability_fn, max_length)
self.assertEqual(outputs.shape, [2, 5])

def test_assert_generation_is_correct(self):
@@ -86,7 +86,7 @@ def token_probability_fn(inputs, mask):

inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
outputs = self.sampler(
- token_probability_fn, inputs, max_length=max_length
+ inputs, token_probability_fn, max_length=max_length
)
self.assertAllEqual(
outputs, 3 * tf.ones(shape=[batch_size, max_length])
@@ -105,8 +105,8 @@ def token_probability_fn(inputs, mask):
sampler = Greedy()
inputs = tf.constant([[0, 1], [1, 2]])
outputs = sampler(
- token_probability_fn,
inputs,
+ token_probability_fn,
max_length=max_length,
end_token_id=2,
)
@@ -117,12 +117,12 @@ def test_compare_xla_noxla_results(self):
inputs = [[1], [1]]
xla_sampler = Greedy(jit_compile=True)
outputs_xla = xla_sampler(
- self.token_probability_fn, inputs, max_length=5
+ inputs, self.token_probability_fn, max_length=5
)

xla_sampler = Greedy(jit_compile=False)
outputs_no_xla = xla_sampler(
- self.token_probability_fn, inputs, max_length=5
+ inputs, self.token_probability_fn, max_length=5
)

self.assertAllEqual(outputs_xla, outputs_no_xla)
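The XLA test above boils down to the following pattern; a self-contained sketch with a placeholder probability function (the assertion mirrors `assertAllEqual`):

```python
import tensorflow as tf
from keras_nlp.samplers.greedy import Greedy

def token_probability_fn(inputs, mask):
    # Placeholder: a uniform next-token distribution over a 10-token vocab.
    batch_size, seq_len = tf.shape(inputs)[0], tf.shape(inputs)[1]
    return tf.ones([batch_size, seq_len, 10]) / 10.0

inputs = tf.constant([[1], [1]])
# Same prompt-first call order, compiled and uncompiled.
outputs_xla = Greedy(jit_compile=True)(inputs, token_probability_fn, max_length=5)
outputs_no_xla = Greedy(jit_compile=False)(inputs, token_probability_fn, max_length=5)
tf.debugging.assert_equal(outputs_xla, outputs_no_xla)
```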
27 changes: 12 additions & 15 deletions keras_nlp/samplers/sampler.py
@@ -23,12 +23,12 @@
"""

call_args_docstring = """
- token_probability_fn: a function that generates the probability of
- the next token over the whole vocabulary for each input token.
prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
initial tokens to append generated tokens.
+ token_probability_fn: a function that generates the probability of
+ the next token over the whole vocabulary for each input token.
max_length: int. The max length of generated sequence.
- padding_mask: a tensor, defaults to None. The padding mask of the prompt.
+ mask: a tensor, defaults to None. The padding mask of the prompt.
end_token_id: int, defaults to None. The token marking the end of the
sequence, once encountered the generation is finished for the exact
sequence. If None, every sequence is generated up to `max_length`.
@@ -40,10 +40,10 @@
"""

sample_args_docstring = """
- token_probability_fn: a function that generates the probability of
- the next token over the whole vocabulary for each input token.
prompt: a dense int Tensor of shape [batch_size, max_length]. The
placeholder for generated sequence.
+ token_probability_fn: a function that generates the probability of
+ the next token over the whole vocabulary for each input token.
mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of
prompt.
num_steps: int. The remaining number of tokens to generate.
@@ -100,7 +100,7 @@ def token_probability_fn(inputs, mask):
sampler = keras_nlp.samplers.Greedy()
# Print the generated sequence (token ids).
- print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID))
+ print(sampler(prompt, token_probability_fn, 10, end_token_id=END_ID))
```
Use with string inputs:
@@ -131,8 +131,8 @@ def token_probability_fn(inputs, mask):
prompt = tokenizer("the quick brown fox")
sampler = keras_nlp.samplers.Greedy()
generated = sampler(
- token_probability_fn,
prompt,
+ token_probability_fn,
10,
end_token_id=tokenizer.token_to_id("[END]")
)
@@ -210,17 +210,14 @@ def _mask_tokens_after_end_token(

def __call__(
self,
- token_probability_fn,
prompt,
+ token_probability_fn,
max_length,
- padding_mask=None,
+ mask=None,
end_token_id=None,
from_logits=True,
):
- prompt, padding_mask = self._validate_prompt_and_mask(
- prompt,
- padding_mask,
- )
+ prompt, mask = self._validate_prompt_and_mask(prompt, mask)

input_is_1d = prompt.shape.rank == 1
if input_is_1d:
@@ -238,8 +235,8 @@ def __call__(
# `jit_compile` accordingly.
sample = tf.function(self.sample, jit_compile=self.jit_compile)
prompt = sample(
- token_probability_fn,
prompt,
+ token_probability_fn,
mask,
max_length - shortest_prompt_len,
from_logits,
@@ -257,7 +254,7 @@ def __call__(

@format_docstring(sample_args=sample_args_docstring)
def sample(
- self, token_probability_fn, prompt, mask, num_steps, from_logits=True
+ self, prompt, token_probability_fn, mask, num_steps, from_logits=True
):
"""Sampling logic implementation.
