Add BART Backbone #661

Merged: 21 commits into keras-team:master on Jan 18, 2023

Conversation

@abheesht17 (Collaborator) commented on Jan 14, 2023

@abheesht17 changed the title from "Add BartBackbone" to "Add BART Backbone" on Jan 14, 2023
@jbischof (Contributor) left a comment:

Looking great, some quick comments. What is going on with our TransformerDecoder!?

@@ -165,7 +165,7 @@ def _build(self, input_shape, has_cross_attention):
         self._cross_attention_layer = keras.layers.MultiHeadAttention(
             num_heads=self.num_heads,
             key_dim=head_dim,
-            value_dim=hidden_dim,
+            value_dim=head_dim,
Contributor:

Whoa, did you find a bug here? If so, let's discuss and break this into a different PR.

Collaborator Author:

Yeah, this is possibly a bug. The thing is that we haven't actually used the cross-attention layer for any of our models (we have used TransformerDecoder for GPT-2, but since it is a decoder-only model, we don't use the cross-attention layer) so far...so, this bug escaped our attention xD. Discussed this with @mattdangerw on Friday. I'll open a separate PR for this, we might want to patch it up to 0.4.0 ASAP, I suppose?
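
(Aside, not part of the PR: a minimal sketch of the per-head convention the fix restores. keras.layers.MultiHeadAttention expects key_dim and value_dim to be per-head sizes; the hyperparameter values below are illustrative, not KerasNLP's.)

import tensorflow as tf
from tensorflow import keras

# Illustrative sizes, not taken from the PR.
hidden_dim = 768
num_heads = 12
head_dim = hidden_dim // num_heads  # 64

# Cross-attention with per-head key/value projections. Passing
# value_dim=hidden_dim instead would make each head's value projection
# num_heads times larger than intended.
cross_attention = keras.layers.MultiHeadAttention(
    num_heads=num_heads,
    key_dim=head_dim,
    value_dim=head_dim,
)

decoder_sequence = tf.random.normal((2, 10, hidden_dim))  # queries
encoder_sequence = tf.random.normal((2, 16, hidden_dim))  # keys/values
outputs = cross_attention(query=decoder_sequence, value=encoder_sequence)
print(outputs.shape)  # (2, 10, 768)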

Contributor:

Great catch!

Contributor:

This is merged, so let's rebase

Collaborator Author:

Done!

@mattdangerw (Member) left a comment:

Just some quick initial comments!



@keras.utils.register_keras_serializable(package="keras_nlp")
class BartBackbone(Backbone):
Member:

I wish they had chosen a name that wasn't so visually similar to Bert. This is going to confuse me so much :P

Collaborator Author:

😂

Contributor:

I wish they explained what BART stands for....

(image attached)

Collaborator Author:

BERT -> Bidirectional Encoder Representations from Transformers
BART -> Bidirectional Auto-encoder Representations from Transformers?

Guessing...:P

Contributor:

Interesting! Couldn't find it in the paper

)

# Token embedding layer. This layer is shared by encoder and decoder.
token_embedding_layer = keras.layers.Embedding(
Member:

When BART is used as a language model, are the embedding weights shared for the output projection? Or is there a separate set of parameters used?


Contributor:

Wouldn't checkpoint conversion confirm this?

Collaborator Author:

Yep, confirmed in the Colab notebook

Member:

Great! This is the easy case for us, where the backbone is self-sufficient for language modeling tasks. We have prior art for this with BERT pretraining and Chen's upcoming GPT-2 language model.
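
(Aside, not part of the PR: one way the weight-sharing check discussed above could look against the Hugging Face checkpoint. The attribute names follow the transformers BART implementation as I recall it and should be double-checked against the actual conversion notebook.)

import torch
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

shared = model.model.shared.weight                       # token embedding matrix
encoder_embed = model.model.encoder.embed_tokens.weight
decoder_embed = model.model.decoder.embed_tokens.weight
lm_head = model.lm_head.weight                           # output projection

print(torch.equal(shared, encoder_embed))  # expect True if shared with the encoder
print(torch.equal(shared, decoder_embed))  # expect True if shared with the decoder
print(torch.equal(shared, lm_head))        # expect True if tied to the LM head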

@abheesht17 (Collaborator Author) left a comment:

@mattdangerw, I replied to the comment about embedding layers.



@jbischof (Contributor) left a comment:

Just need to rebase on #667 and we should be able to merge


This class implements a Transformer-based encoder-decoder model as
described in
["BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension"](https://arxiv.org/abs/1910.13461).
It includes the embedding lookups and transformer layers.
Contributor:

Is this line needed?

Collaborator Author:

Hmmm, just stuck to what's there in other backbone models. Can remove

@abheesht17 (Collaborator Author) commented on Jan 18, 2023:

I don't think BART has any "pretraining head". For all pretraining tasks, it just autoregressively generates/recovers the original input from the corrupted input.

For other models, the whole point of this line was to emphasise that the model class does not have pretraining heads. So, having just "It includes the embedding lookups and transformer layers." kinda seems repetitive for BART, and can be removed.
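
(Aside, not from the PR: a toy illustration of that denoising setup, with made-up strings. The encoder reads the corrupted text and the decoder autoregressively reconstructs the original.)

# Toy BART-style pretraining example (made-up data, not real preprocessing).
original = "The quick brown fox jumps over the lazy dog"

# Text infilling: a sampled span is replaced by a single <mask> token.
corrupted = "The quick <mask> over the lazy dog"

# The encoder consumes the corrupted text; the decoder is trained with
# teacher forcing to emit the original text shifted by one position.
encoder_input = corrupted
decoder_input = "<s> " + original    # previous gold tokens
decoder_target = original + " </s>"  # next-token targets
print(encoder_input, decoder_input, decoder_target, sep="\n")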

Member:

Fine to remove!

FWIW, there will still be a pretraining head, even if that head has no parameters. You still need to map from the dense hidden_dim output to LM logits.
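
(Aside, not part of the PR: a minimal sketch of such a parameter-free head, reusing the shared token embedding matrix to map hidden_dim outputs to vocabulary logits. Sizes are illustrative.)

import tensorflow as tf
from tensorflow import keras

# Illustrative sizes, not the real BART configuration.
vocab_size = 50265
hidden_dim = 768

token_embedding_layer = keras.layers.Embedding(vocab_size, hidden_dim)
token_embedding_layer.build((None,))  # creates the (vocab_size, hidden_dim) matrix

# Backbone decoder output: (batch, sequence_length, hidden_dim).
decoder_output = tf.random.normal((2, 12, hidden_dim))

# "Head with no parameters": project onto the transposed embedding matrix.
lm_logits = tf.einsum(
    "bsh,vh->bsv", decoder_output, token_embedding_layer.embeddings
)
print(lm_logits.shape)  # (2, 12, 50265)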

Collaborator Author:

Oh, yeah. Brain fart, indeed 🤦🏼‍♂️


@mattdangerw (Member) left a comment:

Just some minor comments. I can fix as I merge.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/facebookresearch/fairseq/tree/main/examples/bart).
Member:

We should actually link to the location with the LICENSE file here, which is the base of the repo.

),
"decoder_token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
"decoder_padding_mask": tf.constant(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
Member:

Might be nice to show these padding masks can in fact be different. (Just switch to 0s at a different point.)
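
(Aside, a sketch of what that docstring tweak could look like, not the exact change; the encoder_* key names are assumed to mirror the decoder_* keys shown in the diff.)

import tensorflow as tf

# Encoder and decoder inputs padded at different points, so the two
# padding masks are visibly different.
input_data = {
    "encoder_token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
    "encoder_padding_mask": tf.constant(
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], shape=(1, 12)
    ),
    "decoder_token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
    "decoder_padding_mask": tf.constant(
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
    ),
}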


# Embed tokens and positions.
token_embedding = token_embedding_layer(encoder_token_id_input)
position_embedding = PositionEmbedding(
Member:

Let's add a comment explaining that the position embedding is not shared, but the token embedding is.
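
(Aside, not the PR's actual code: a sketch of the sharing scheme the requested comment would describe. One Embedding instance is reused for encoder and decoder tokens, while each side gets its own PositionEmbedding; keras_nlp.layers.PositionEmbedding and the sizes below are assumptions for illustration.)

import tensorflow as tf
from tensorflow import keras
import keras_nlp

# Illustrative sizes.
vocab_size = 50265
hidden_dim = 768
max_sequence_length = 1024

# One token embedding layer, shared by encoder and decoder.
token_embedding_layer = keras.layers.Embedding(vocab_size, hidden_dim)

# Position embeddings are NOT shared: each side has its own layer.
encoder_position_embedding = keras_nlp.layers.PositionEmbedding(
    sequence_length=max_sequence_length
)
decoder_position_embedding = keras_nlp.layers.PositionEmbedding(
    sequence_length=max_sequence_length
)

encoder_token_ids = tf.ones((1, 12), dtype=tf.int64)
decoder_token_ids = tf.ones((1, 12), dtype=tf.int64)

x_enc = token_embedding_layer(encoder_token_ids)
x_enc = x_enc + encoder_position_embedding(x_enc)

x_dec = token_embedding_layer(decoder_token_ids)  # same layer reused
x_dec = x_dec + decoder_position_embedding(x_dec)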

@mattdangerw merged commit c9e5040 into keras-team:master on Jan 18, 2023
Merging this pull request may close the issue: Add a BART backbone

3 participants