Add keras_nlp.samplers #563
Conversation
Thanks @chenmoneygithub! Some early comments:
- Are we making a new keras_nlp.samplers API? Cool!
- Why aren't we importing utils/text_generation.py?
- You can use @mattdangerw's cool docstring decorator for formatting (link)
@jbischof Thanks!
Yes! It's cleaner and more visible than our current offerings, and more importantly, it suits the generation task of pretrained models better.
Currently we have a few standalone functions, but wrapping them into classes suits the generation task better, and is also more readable and cleaner. After talking to Francois, we want these samplers to have their own namespace; that's why we are doing the move.
Thanks!
Thanks! Will take a proper pass soon! Re docstrings, I would say we should just replicate the repeated bits of docstring everywhere, and remove all the fancy formatting. The reason to use a format-docstring decorator is when we have something programmatic we need to get into the docstrings. But otherwise I think Keras as a whole would favor the simplicity of easily readable docstring blocks over the brevity of abstracting out certain argument sections.
Some high level thoughts on this...

1. Usage - I am somewhat tempted to keep the samplers as generic as possible and not require any knowledge of special tokens. After all, there are no real standards here (e.g. GPT-2 does not really have a pad token). For inputs we could assume either a ragged input or ask users to pass their own mask. For outputs we could just return the entire dense output.
2. Naming - Should we ditch the current naming? We only did this repetitive naming for tokenizers because some of the class names looked weird otherwise.
3. Serialization - I know we will have no direct need to serialize these in our own usage, but registering these objects as serializable and supporting config round-trips seems worth doing. That is, even though we will not save these directly on a model, someone else should be able to for their own custom code (same as optimizers and initializers).
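As a loose sketch of what that registration could look like (not the keras_nlp implementation; the class name, constructor arguments, and the use of the generic Keras serialization utilities are all assumptions):

```python
from tensorflow import keras


@keras.utils.register_keras_serializable(package="keras_nlp")
class GreedySamplerSketch:
    """Hypothetical sampler that supports config-based serialization."""

    def __init__(self, end_token_id=None, pad_token_id=0):
        self.end_token_id = end_token_id
        self.pad_token_id = pad_token_id

    def get_config(self):
        return {
            "end_token_id": self.end_token_id,
            "pad_token_id": self.pad_token_id,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Someone else's custom code could then round-trip the sampler,
# much like they would an optimizer.
sampler = GreedySamplerSketch(end_token_id=2)
config = keras.utils.serialize_keras_object(sampler)
restored = keras.utils.deserialize_keras_object(config)
```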
class Sampler:
    """Base sampler class.
I might document this class explaining how it could be subclassed, similar to Tokenizer. Even though we are not exposing it right now, we could in the future, and explaining its intended usage would be helpful.
I can do this as a followup after we are satisfied with the API design. At this time we are still making API changes, so the guide will be unreliable.
Sounds good! But I would just cut this docstring down then. If we aren't ready to describe subclassing, we probably should not put the base class in the release (which seems totally fine).
So we should focus our actual runnable example documentation of the exported symbols themselves.
keras_nlp/samplers/greedy_sampler.py
Examples:
    BATCH_SIZE = 8
I don't see most of these params used more than once so you can inline them to match our other docstrings.
This is just one style - some people prefer to define macros in one place. This is the same as our util docstrings, e.g. link.
utils.text_generation has a lot of style weirdness I've been cleaning up. I'd prefer the args inlined.
I get your idea! I am open to both solutions; sharing my opinion here: to me it's a bit different - defining all constants/hyperparameters before usage is cleaner, like in our guide (https://keras.io/examples/nlp/neural_machine_translation_with_keras_nlp/#setup), where EPOCHS is only used once.
keras_nlp/samplers/greedy_sampler.py
):
    super().__init__(end_token_id, pad_token_id, jit_compile)

def sample(self, token_probability_fn, prompt, mask, num_steps):
Could this method call our util functions? Otherwise we should probably delete them.
Yea, I am going to delete those utils and only keep the sampler class.
Why not delete as part of this PR?
Usually I prefer a pure clean-up PR to remove all legacy code, to keep the history clean.
@mattdangerw Thanks for the comments! Sharing my thoughts:

1. My take is that truncating by "end_token" is pretty common in real applications, so I would love to have the generator/sampler support it.
2. Yes, I have struggled a bit over whether to keep the current naming. The short answer is I am not sure; we can leave the naming call to Francois and focus on other parts for now. It's easy to do the renaming.
3. I was also hesitant on this one. Our proposed usage is that this "Sampler" class is only used with the "generate()" method, while optimizers/initializers/activations are part of model serialization. Also, users can use the restored optimizers/initializers/activations directly, but they always need to reconstruct a "Sampler" class. I am also curious how "Sampler" serialization could happen - which code would be calling the deserialization?
We should definitely have a way to do it. I am ok with having it as a call-time argument, where if it is passed we return a truncated ragged, and if it is not passed we return a dense. I worry that this polymorphic return might be a little confusing. Another possibility is to handle that during detokenization? Then we could always return a simple dense with sequence_length shape. But IDK!

```python
token_ids = sampler(token_ids, mask, prob_fn)
tokenizer.detokenize(token_ids, truncate_after=tokenizer.token_to_id("<eos>"))
```

Padding token is also interesting to me. If we are taking in a dense, and passing both a prompt and a mask, the sampler and the callback can share a signature:

```python
def next_token_fn(prompt, mask):
    return  # Do stuff

prompt = tokenizer(inputs)
mask = prompt != tokenizer.token_to_id("<pad>")
# These seem quite readable, in that the sampler and callback have similar signatures!
sampler(prompt, mask)
```
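For what it's worth, here is a small sketch of the "truncated ragged" option above in plain TensorFlow - the helper name and the choice to drop the end token itself are assumptions, not anything from this PR:

```python
import tensorflow as tf


def truncate_at_end_token(token_ids, end_token_id):
    """Convert dense [batch, length] ids to a ragged tensor cut at end_token_id."""
    is_end = tf.equal(token_ids, end_token_id)
    full_length = tf.shape(token_ids)[1]
    # Index of the first end token per row, or the full length if there is none.
    lengths = tf.where(
        tf.reduce_any(is_end, axis=-1),
        tf.argmax(tf.cast(is_end, tf.int32), axis=-1, output_type=tf.int32),
        tf.fill(tf.shape(token_ids)[:1], full_length),
    )
    return tf.RaggedTensor.from_tensor(token_ids, lengths=lengths)


token_ids = tf.constant([[5, 6, 2, 0, 0], [7, 8, 9, 2, 0]])
print(truncate_at_end_token(token_ids, end_token_id=2))
# <tf.RaggedTensor [[5, 6], [7, 8, 9]]>
```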
Left another pass on mostly minor comments!
def __init__(
    self,
    jit_compile=True,
Given that we talked about moving compilation to model.generate, I think we can remove all the jit_compile stuff here.
We cannot, unfortunately - not all of the sampler's __call__ is XLA-compatible. The prompt preprocessing part is not XLA-compatible because it changes the shape. The XLA-compatible part is the sample method, which is the part that actually benefits from XLA.
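To make that split concrete, here is a rough sketch of the pattern being described - eager, shape-changing prompt preprocessing in __call__, with only the fixed-shape sampling loop wrapped in a jit-compiled tf.function. The class and argument names are illustrative, not the actual keras_nlp code:

```python
import tensorflow as tf


class SamplerSketch:
    """Illustrative only: eager preprocessing, XLA-compiled sampling loop."""

    def __init__(self, jit_compile=True):
        # Only `sample` is compiled; `__call__` must stay eager because the
        # preprocessing below changes tensor shapes.
        self.sample = tf.function(self._sample, jit_compile=jit_compile)

    def __call__(self, token_probability_fn, prompt, max_length):
        batch_size, prompt_length = prompt.shape[0], prompt.shape[1]
        num_steps = max_length - prompt_length
        # Shape-changing preprocessing (not XLA-compatible): pad the prompt
        # to `max_length` and build a mask marking the real tokens.
        padded = tf.pad(prompt, [[0, 0], [0, num_steps]])
        mask = tf.concat(
            [tf.ones((batch_size, prompt_length), tf.bool),
             tf.zeros((batch_size, num_steps), tf.bool)],
            axis=1,
        )
        # Fixed-shape decoding loop: this is the part that benefits from XLA.
        return self.sample(token_probability_fn, padded, mask, num_steps)

    def _sample(self, token_probability_fn, prompt, mask, num_steps):
        # Stub; the real sampler runs a while_loop of fixed-shape updates here.
        return prompt
```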
# Convert `sample` method to a `tf.function`, and turn on
# `jit_compile` accordingly.
sample = tf.function(self.sample, jit_compile=self.jit_compile)
remove compilation
Same as above. Let's keep the tf.function on sample() for now.
prompt = tf.fill((BATCH_SIZE, 1), START_ID)

sampler = keras_nlp.samplers.Greedy()
Good point. Now that we're introducing a lot of abstract base classes (ones no one should use directly), we may want to start noting at the top of docstrings e.g. "An abstract base class for samplers." I could imagine the same for Backbone and Preprocessor.
# are aligned to the right side, the index is the same for all.
current_index = max_length - num_steps

def one_step(current_index, prompt, mask):
Could we factor out one_step like train_step in keras.Model? Might improve encapsulation.
That's actually a good idea. The issue is that the customization of child sampler classes happens at the sample() level, because some custom code is required before the while_loop, e.g., beam search needs to construct the beams before the loop.
We could expose an abstract sample_step() method in the base class as well, in which case we would have two abstract methods, sample() and sample_step(), to override, which looks a bit redundant.
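For reference, here is a sketch of the alternative being floated - a per-step hook in the base class so that subclasses only override the token choice. This is not what the PR implements (the PR keeps the loop inside each subclass's sample()), and the method name get_next_token, the Python loop, and the token_probability_fn output shape are all assumptions:

```python
import tensorflow as tf


class SamplerBaseSketch:
    """Sketch: decoding loop in the base class, per-step choice in subclasses."""

    def get_next_token(self, next_token_probs):
        # Subclasses override this: map (batch, vocab) probs to token ids.
        raise NotImplementedError

    def sample(self, token_probability_fn, prompt, mask, num_steps):
        # Assume the prompt occupies the first (max_length - num_steps)
        # positions (at least one token), and generation fills the rest.
        max_length = prompt.shape[1]
        index = max_length - num_steps
        while index < max_length:
            # Assumed contract: per-position next-token distributions,
            # shape (batch, length, vocab).
            probs = token_probability_fn(prompt, mask)
            next_token = tf.cast(
                self.get_next_token(probs[:, index - 1, :]), prompt.dtype
            )
            # Write the chosen token at `index` and mark it as a real token.
            prompt = tf.concat(
                [prompt[:, :index], next_token[:, None], prompt[:, index + 1:]],
                axis=1,
            )
            mask = tf.concat(
                [mask[:, :index],
                 tf.ones_like(next_token[:, None], dtype=tf.bool),
                 mask[:, index + 1:]],
                axis=1,
            )
            index += 1
        return prompt


class GreedySketch(SamplerBaseSketch):
    def get_next_token(self, next_token_probs):
        return tf.argmax(next_token_probs, axis=-1)
```

The trade-off raised above still applies: beam search needs extra setup before the loop, which is why the PR keeps sample() as the override point.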
Approval from me, granted we might need to change how we handle compilation down the road pending some more discussion.
Thanks for the great work!
Base sampler + GreedySampler
For how this sampler is actually used, please see this link: https://colab.sandbox.google.com/gist/chenmoneygithub/5a1204a1888b1b56e37c3fa8101044b3/gpt2-generation-github-branch-version.ipynb
And here is a demo PR: #592