
Add keras_nlp.samplers #563

Merged

Conversation

@chenmoneygithub (Contributor) commented Dec 10, 2022

Base sampler + GreedySampler

To see how this sampler is actually used, please see this notebook: https://colab.sandbox.google.com/gist/chenmoneygithub/5a1204a1888b1b56e37c3fa8101044b3/gpt2-generation-github-branch-version.ipynb

And here is a demo PR: #592

@jbischof (Contributor) left a comment

Thanks @chenmoneygithub! Some early comments:

  • Are we making a new keras_nlp.samplers API? Cool!
  • Why aren't we importing utils/text_generation.py?
  • You can use @mattdangerw's cool docstring decorator for formatting (link)

keras_nlp/samplers/greedy_sampler.py (resolved thread)
@chenmoneygithub (Contributor, Author)

@jbischof Thanks!

> Are we making a new keras_nlp.samplers API? Cool!

Yes! It's cleaner and more visible than our current offerings, and more importantly, it suits the generation task of pretrained models better.

> Why aren't we importing utils/text_generation.py?

Currently we have a few standalone functions, but wrapping them into classes suits the generation task better, and is also more readable and cleaner. After talking to Francois, we want these samplers to have their own namespace; that's why we are making the move.

> You can use @mattdangerw's cool docstring decorator for formatting (link)

Thanks!

@mattdangerw (Member)

Thanks! Will take a proper pass soon!

Re docstrings, I would say we should just replicate the repeated bits of docstring everywhere, and remove all the fancy formatting. The case for a formatted docstring is when we need to get something programmatic into the docstrings. Otherwise I think Keras as a whole would favor the simplicity of easily readable docstring blocks over the brevity of abstracting out certain argument sections.

@mattdangerw (Member)

Some high level thoughts on this...

1. Usage

I am somewhat tempted to keep the samplers as generic as possible and not require any knowledge of special tokens. After all, there are no real standards here (e.g. GPT-2 does not really have a pad token).

For inputs we could assume either a ragged input or ask users to pass their own mask. For outputs we could just return the entire sequence_length sequence, and if the user then wants to strip all tokens after the first end token, that could be done as a post-process. Then we could expose the "core" sampling algorithm in as generic a way as possible, while still covering an end-to-end workflow via generate(). Curious about people's thoughts on this.
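As a purely illustrative sketch of that post-process idea (the helper name and shapes below are my own, not part of this PR), stripping after the first end token can live entirely outside the sampler:

```python
import tensorflow as tf


def strip_after_end_token(token_ids, end_token_id):
    """Truncate each sequence at its first end token, as a post-process."""
    seq_len = tf.shape(token_ids)[-1]
    positions = tf.range(seq_len)
    is_end = tf.equal(token_ids, end_token_id)
    # Length = position of the first end token, or the full length if absent.
    lengths = tf.reduce_min(tf.where(is_end, positions, seq_len), axis=-1)
    return tf.RaggedTensor.from_tensor(token_ids, lengths=lengths)


dense = tf.constant([[5, 7, 2, 0, 0], [5, 7, 9, 9, 9]])
ragged = strip_after_end_token(dense, end_token_id=2)  # [[5, 7], [5, 7, 9, 9, 9]]
```

This would keep the sampler itself free of any end-token knowledge.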

2. Naming

Should we ditch Sampler in the class name? I find this pretty readable and more concise...

```python
keras_nlp.samplers.Greedy()
keras_nlp.samplers.Beam()
keras_nlp.samplers.Random()
keras_nlp.samplers.TopK()
```

We only did this repetitive naming for tokenizers because some of the class names looked weird, e.g. tokenizers.Byte. But I don't think we have that problem here! cc @fchollet in case you have any thoughts.

3. Serialization

I know we will have no direct need to serialize these in our own usage, but registering these objects as serializable and supporting from_config() seems like a consistent thing to do at a library level.

That is, even though we will not save these directly on a model, someone else should be able to in their own custom code (same as keras.optimizers, keras.activations, keras.initializers). Part of exposing these as a top-level keras_nlp.samplers API, to me, means we should conform to Keras-wide expectations around serialization.
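A minimal sketch of what that conformance could look like with the standard Keras registration mechanism; the constructor arguments here are illustrative, not the final API:

```python
from tensorflow import keras


@keras.utils.register_keras_serializable(package="keras_nlp")
class GreedySampler:
    def __init__(self, end_token_id=None, jit_compile=True):
        self.end_token_id = end_token_id
        self.jit_compile = jit_compile

    def get_config(self):
        # Plain Python types only, so the object round-trips through JSON.
        return {
            "end_token_id": self.end_token_id,
            "jit_compile": self.jit_compile,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Round trip, same as keras.optimizers / keras.initializers:
config = keras.utils.serialize_keras_object(GreedySampler(end_token_id=2))
restored = keras.utils.deserialize_keras_object(config)
```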

keras_nlp/samplers/greedy_sampler.py (resolved thread)


```python
class Sampler:
    """Base sampler class.
```
Member

I might document this class explaining how it could be subclassed, similar to Tokenizer. Even though we are not exposing it right now, we could in the future, and explaining its intended usage would be helpful.

@chenmoneygithub (Contributor, Author)

I can do this as a follow-up once we are satisfied with the API design. At this time we are still making API changes, so the guide would be unreliable.

@mattdangerw (Member) commented Jan 7, 2023

Sounds good! But I would just cut this docstring down then. If we aren't ready to describe subclassing, we probably should not put the base class in the release (which seems totally fine).

So we should focus on actual runnable example documentation for the exported symbols themselves.

keras_nlp/samplers/greedy_sampler.py (3 resolved threads)

Examples:
```python
BATCH_SIZE = 8
```
Contributor

I don't see most of these params used more than once, so you can inline them to match our other docstrings.

@chenmoneygithub (Contributor, Author)

This is just one style; some people prefer to define constants in one place. This is the same as our util docstrings, e.g. (link).

Contributor

utils.text_generation has a lot of style weirdness I've been cleaning up. I'd prefer the args inlined.

@chenmoneygithub (Contributor, Author)

I get your point! I am open to both solutions; sharing my opinion here. To me it's a bit different: defining all constants/hyperparameters before usage is cleaner, like in our guide https://keras.io/examples/nlp/neural_machine_translation_with_keras_nlp/#setup, where EPOCHS is only used once.

```python
):
    super().__init__(end_token_id, pad_token_id, jit_compile)

def sample(self, token_probability_fn, prompt, mask, num_steps):
```
Contributor

Could this method call our util functions? Otherwise we should probably delete them.

@chenmoneygithub (Contributor, Author)

Yeah, I am going to delete those utils and only keep the sampler classes.

Contributor

Why not delete as part of this PR?

@chenmoneygithub (Contributor, Author)

Usually I prefer a pure clean-up PR to remove all legacy code, to keep the history clean.

@chenmoneygithub (Contributor, Author)

@mattdangerw Thanks for the comments! Sharing my thoughts:

1. Usage

My take is that truncating by end_token is pretty common in real applications, so I would love for the generator / generate() to have the truncating functionality.

2. Naming

Yes, I have struggled a bit over whether to keep Sampler in the name. On the one hand, I feel "GreedySampler" sounds better than "Greedy", since the latter is a bit weird by itself. On the other hand, I agree "Sampler" is a duplication.

The short answer is I am not sure; we can leave the naming call to Francois and focus on other parts for now, I guess. It's easy to do the renaming.

3. Serialization

I was also hesitant on this one. In our proposed usage, the Sampler class is only to be used with the generate() method, while optimizers/initializers/activations are part of model serialization. Also, users can use restored optimizers/initializers/activations directly, but they would always need to reconstruct a Sampler.

I am also curious how Sampler serialization would actually happen; which code would call the get_config method? For optimizers, there is custom saving code in core Keras handling it, and my guess is it's similar for activations/initializers. If we simply attach the Sampler to a model, will saving be handled automatically?

@mattdangerw (Member) commented Jan 4, 2023

> My take is truncating by "end_token" is pretty common in real applications.

We should definitely have a way to do it. I am OK having it as a call-time argument, where if it is passed we return a truncated ragged tensor, and if it is not we return a dense tensor. I worry that this polymorphic return might be a little confusing.

Another possibility is to handle that during detokenization? Then we could always return a simple dense tensor of sequence_length shape. But IDK!

```python
token_ids = sampler(token_ids, mask, prob_fn)
tokenizer.detokenize(token_ids, truncate_after=tokenizer.token_to_id("<eos>"))
```

The padding token is also interesting to me. If we are taking in a dense tensor and passing both a prompt and a mask to the callback, it frankly seems clearer to ask the user to pass this in.

```python
def next_token_fn(prompt, mask):
    return  # Do stuff

prompt = tokenizer(inputs)
mask = prompt != tokenizer.token_to_id("<pad>")
# This seems quite readable, in that the sampler and callback have similar signatures!
sampler(prompt, mask)
```

@mattdangerw (Member) left a comment

Left another pass on mostly minor comments!

keras_nlp/samplers/__init__.py (3 resolved threads)
keras_nlp/samplers/greedy.py (resolved thread)

```python
def __init__(
    self,
    jit_compile=True,
```
Member

Given that we talked about moving compilation to model.generate(), I think we can remove all the jit_compile stuff here.

@chenmoneygithub (Contributor, Author)

We cannot, unfortunately; not all of the sampler's __call__ is XLA-compatible. The prompt preprocessing part is not XLA-compatible because it changes tensor shapes. The XLA-compatible part is the sample method, which is also the part that actually benefits from XLA.
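To make the split concrete, here is a rough sketch of the structure being described; the class and argument names are illustrative, not this PR's actual code:

```python
import tensorflow as tf


class SketchSampler:
    def __call__(self, token_probability_fn, prompt, max_length):
        # Shape-changing preprocessing (padding the prompt out to
        # `max_length`) is not XLA-compatible, so it runs as plain TF ops.
        batch_size = tf.shape(prompt)[0]
        length = tf.shape(prompt)[1]
        pad = tf.zeros((batch_size, max_length - length), prompt.dtype)
        mask = tf.concat(
            [
                tf.ones((batch_size, length), tf.bool),
                tf.zeros((batch_size, max_length - length), tf.bool),
            ],
            axis=-1,
        )
        prompt = tf.concat([prompt, pad], axis=-1)
        # Only the fixed-shape decoding loop is XLA-compiled.
        sample = tf.function(self.sample, jit_compile=True)
        return sample(token_probability_fn, prompt, mask, max_length - length)

    def sample(self, token_probability_fn, prompt, mask, num_steps):
        # `prompt` keeps shape (batch, max_length) on every iteration,
        # which is what makes this loop XLA-friendly.
        max_length = prompt.shape[1]

        def body(index, prompt):
            # `token_probability_fn` is assumed to return next-token
            # distributions of shape (batch, max_length, vocab).
            probs = token_probability_fn(prompt)
            next_token = tf.cast(
                tf.argmax(probs[:, index - 1, :], axis=-1), prompt.dtype
            )
            # Keep user-provided prompt tokens where the mask is set.
            next_token = tf.where(mask[:, index], prompt[:, index], next_token)
            update = tf.one_hot(index, max_length, dtype=prompt.dtype)
            prompt = prompt * (1 - update) + next_token[:, tf.newaxis] * update
            return index + 1, prompt

        index = max_length - num_steps
        _, prompt = tf.while_loop(lambda i, p: i < max_length, body, (index, prompt))
        return prompt
```

Whether this exact split survives is open, but it shows why jit_compile cannot simply wrap all of __call__.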

keras_nlp/samplers/greedy.py (resolved thread)



keras_nlp/samplers/sampler.py (resolved thread)

```python
# Convert `sample` method to a `tf.function`, and turn on
# `jit_compile` accordingly.
sample = tf.function(self.sample, jit_compile=self.jit_compile)
```
Member

remove compilation

@chenmoneygithub (Contributor, Author) commented Jan 9, 2023

Same as above. Let's keep the tf.function on sample() for now.

keras_nlp/samplers/sampler_test.py (resolved thread)
keras_nlp/samplers/__init__.py (resolved thread)

```python
prompt = tf.fill((BATCH_SIZE, 1), START_ID)

sampler = keras_nlp.samplers.Greedy()
```
Contributor

Good point. Now that we're introducing a lot of abstract base classes (ones no one should use directly), we may want to start noting this at the top of docstrings, e.g., "An abstract base class for samplers." I could imagine the same for Backbone and Preprocessor.

```python
# are aligned to the right side, the index is the same for all.
current_index = max_length - num_steps

def one_step(current_index, prompt, mask):
```
Contributor

Could we factor out one_step like train_step in keras.Model? Might improve encapsulation.

@chenmoneygithub (Contributor, Author)

That's actually a good idea. The issue is that the customization in child sampler classes happens at the sample() level, because some custom code is required before the while_loop, e.g., beam search needs to construct the beams before the loop.

We could expose an abstract sample_step() method in the base class as well, in which case we would have two abstract methods, sample() and sample_step(), to override, which looks a bit redundant; see the sketch below.
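For concreteness, a minimal sketch of what that two-method shape could look like; all names and signatures are illustrative, not this PR's actual code:

```python
import tensorflow as tf


class Sampler:
    def sample(self, token_probability_fn, prompt, mask, num_steps):
        # Default loop. Subclasses that need pre-loop setup (e.g. beam
        # construction) would override `sample()` itself; the rest only
        # override `sample_step()`.
        max_length = prompt.shape[1]
        for index in range(max_length - num_steps, max_length):
            prompt = self.sample_step(token_probability_fn, prompt, mask, index)
        return prompt

    def sample_step(self, token_probability_fn, prompt, mask, index):
        raise NotImplementedError


class Greedy(Sampler):
    def sample_step(self, token_probability_fn, prompt, mask, index):
        # Argmax next token, keeping user-provided tokens where the mask is set.
        probs = token_probability_fn(prompt)
        next_token = tf.cast(tf.argmax(probs[:, index - 1, :], axis=-1), prompt.dtype)
        next_token = tf.where(mask[:, index], prompt[:, index], next_token)
        update = tf.one_hot(index, prompt.shape[1], dtype=prompt.dtype)
        return prompt * (1 - update) + next_token[:, tf.newaxis] * update
```

The redundancy concern is visible here: a beam sampler would end up overriding both sample() (for beam construction) and sample_step().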

keras_nlp/samplers/sampler.py (2 resolved threads)
@mattdangerw (Member) left a comment

Approval from me, granted we might need to change how we handle compilation down the road, pending some more discussion.

@jbischof (Contributor) left a comment

Thanks for the great work!

@chenmoneygithub merged commit 37add8b into keras-team:master on Jan 10, 2023
@chenmoneygithub deleted the text-generation-class branch on February 2, 2023