Solve #721 Deberta masklm model #732
Conversation
@Plutone11011 thanks! Will take a pass soon.
Can you elaborate? If it's all preprocessing, like this comment, I think it is fine to cover the fancier token masking schemes down the road. If this is a whole different setup, or a different "head" on the backbone, we might want to consider some changes.
In https://arxiv.org/pdf/2111.09543.pdf the authors describe a replacement for MLM called RTD (replaced token detection), which is what mainly distinguishes DeBERTaV3 from previous versions. It is still somewhat based on MLM, but it uses a GAN-style approach, jointly training an MLM generator and a discriminator/classifier. See Sections 3.1 and 2.3.2 of the paper, specifically.
Ah right! DeBERTaV3 uses ELECTRA-style (GAN-like) pre-training. I think it is still totally valid to ship an MLM task for DeBERTa, but we can probably make a note in the docstring that this is not the "task setup" used by DeBERTa during pre-training.
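For readers landing here, a rough sketch of the ELECTRA-style objective from Section 3.1 of the paper (notation simplified; this is paraphrased from the paper, not from this thread): a generator is trained with the usual MLM loss over the masked positions, a discriminator is trained to detect which tokens the generator replaced, and the two losses are combined with a weight lambda (50 in ELECTRA).

```latex
% Generator: standard MLM over the set of masked positions C.
\mathcal{L}_{\mathrm{MLM}} =
  \mathbb{E}\Big[-\sum_{i \in \mathcal{C}}
    \log p_{\theta_G}\big(\tilde{x}_i = x_i \,\big|\, \tilde{X}_G\big)\Big]

% Discriminator (RTD): per-token binary "was this token replaced?" prediction.
\mathcal{L}_{\mathrm{RTD}} =
  \mathbb{E}\Big[-\sum_{i}
    \log p_{\theta_D}\big(\mathbb{1}(\tilde{x}_i = x_i) \,\big|\, \tilde{X}_D, i\big)\Big]

% Joint objective, with discriminator weight \lambda (50 in ELECTRA).
\mathcal{L} = \mathcal{L}_{\mathrm{MLM}} + \lambda \, \mathcal{L}_{\mathrm{RTD}}
```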
Thanks! This looks great! Left a few comments
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/facebookresearch/fairseq).
Switch this to the DeBERTa repo.
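Presumably that just means pointing the license link at the official DeBERTa repository instead of fairseq, e.g. something like:

```
third party and subject to a separate license, available
[here](https://github.com/microsoft/DeBERTa).
```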
outputs = MaskedLMHead(
    vocabulary_size=backbone.vocabulary_size,
    embedding_weights=backbone.token_embedding.embeddings,
    intermediate_activation="gelu",
We should probably use the same "approximate gelu" as the backbone itself, here -> https://github.com/keras-team/keras-nlp/blob/4c1c6ae9e5a3adcf80271ba206b6835caf1b39f2/keras_nlp/models/deberta_v3/deberta_v3_backbone.py#L155
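For illustration, a minimal sketch of the suggested change, assuming the backbone's feedforward activation is the approximate GeLU built with `keras.activations.gelu` (the preset name is just an example; check the linked backbone line for the exact flag used):

```python
import keras_nlp
from tensorflow import keras

backbone = keras_nlp.models.DebertaV3Backbone.from_preset(
    "deberta_v3_extra_small_en"
)

# Match the backbone's feedforward activation rather than the plain string
# "gelu", which resolves to the default (non-approximate) variant.
approx_gelu = lambda x: keras.activations.gelu(x, approximate=True)

masked_lm_head = keras_nlp.layers.MaskedLMHead(
    vocabulary_size=backbone.vocabulary_size,
    embedding_weights=backbone.token_embedding.embeddings,
    intermediate_activation=approx_gelu,
)
```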
Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.DebertaV3MaskedLMPreprocessor.from_preset("deberta_v3_base_en")
```
Format this to fit our lint limit:
preprocessor = keras_nlp.models.DebertaV3MaskedLMPreprocessor.from_preset(
    "deberta_v3_base_en",
)
The format script is not detecting it; is it because it's inside a docstring?
@@ -72,6 +72,7 @@ def __init__(self, proto, **kwargs):
    cls_token = "[CLS]"
    sep_token = "[SEP]"
    pad_token = "[PAD]"
    mask_token = "[MASK]"
    for token in [cls_token, pad_token, sep_token]:
I think we should add a check for the mask token here, which might also mean you need to update some unit tests for the preprocessor and tokenizer layers (so they add a mask token to the vocabulary).
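A minimal sketch of the kind of check being suggested, mirroring the diff above (the `get_vocabulary()` accessor and the error message wording are assumptions, not the actual implementation):

```python
# Inside DebertaV3Tokenizer.__init__, after loading the SentencePiece proto.
cls_token = "[CLS]"
sep_token = "[SEP]"
pad_token = "[PAD]"
mask_token = "[MASK]"
# Fail fast if any special token, including the new mask token, is missing
# from the provided vocabulary.
for token in [cls_token, pad_token, sep_token, mask_token]:
    if token not in self.get_vocabulary():
        raise ValueError(
            f"Cannot find token `'{token}'` in the provided `vocabulary`. "
            f"Please provide `'{token}'` in your `vocabulary` or use a "
            "pretrained `vocabulary` name."
        )
```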
So, I added the check and, as you can see, it fails the preset tests, because apparently there isn't a `[MASK]` token in the `deberta_v3_extra_small_en` vocabulary, which seems strange to me (I've tried with `deberta_v3_base_en` too). Any advice?
bos_piece="[CLS]",
eos_piece="[SEP]",
unk_piece="[UNK]",
user_defined_symbols="[MASK]",
Nice! Should this not be a list? Is it just a comma-separated string?
Yes, it seems to be a comma-separated string. From the SentencePiece options doc:
`--user_defined_symbols (comma separated list of user defined symbols)  type: std::string  default: ""`
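For reference, a minimal sketch of how the unit tests can get a `[MASK]` piece into a toy vocabulary by passing this option to the SentencePiece trainer (the corpus, ids, and vocab size here are made up for illustration):

```python
import io

import sentencepiece

bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=iter(["the quick brown fox", "the earth is round"]),
    model_writer=bytes_io,
    # 7 unique words + 4 control pieces + 1 user-defined symbol.
    vocab_size=12,
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    pad_piece="[PAD]",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
    unk_piece="[UNK]",
    # Comma-separated string; e.g. "[MASK],[X]" would add two symbols.
    user_defined_symbols="[MASK]",
)
proto = bytes_io.getvalue()  # serialized model, usable as `proto=` below
```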
proto = bytes_io.getvalue()
self.preprocessor = DebertaV3MaskedLMPreprocessor(
    tokenizer=DebertaV3Tokenizer(proto=proto),
    # Simplify our testing by masking every available token.
I don't think this comment applies here.
Addressed the comments. I also added the mask check in the tokenizer, and I had to adjust the vocab size and some tests as a result.
Force-pushed from 0aa36db to 8dfd375.
@Plutone11011 thanks! Overall this is looking good, but I hit a snag while trying to test this. Here's a gist -> https://colab.research.google.com/gist/mattdangerw/550ca0fc007579353ec7d0f11ebee03b/deberta-masked-lm.ipynb Essentially the problem is the `[MASK]` token, which is missing from the preset vocabularies as noted above. We might need to do a little digging here. What token id is assigned for `[MASK]` by the original implementation? Once we figure out how the original implementation handles this, we can figure out what changes we need to make here.
Everything is looking good here; we can just merge this with #759 when it is ready. We can't merge before that without breaking the main model usages.
OK, #759 is merged, so we should be able to rebase this and get things working. You can use the colab I linked above to test things out (I recommend a GPU runtime).
Force-pushed from 095d46d to 07115f8.
@Plutone11011 is this ready for review again? If so, I can take a pass tomorrow.
Yes, I've checked the notebook; training and the preprocessor work. There is, however, still a problem when calling `detokenize`.
@mattdangerw in the colab notebook, the call to `detokenize` still raises an error.
@Plutone11011 thanks for checking this out! IMO we don't need to block on the detokenize functionality (it is not critical to any MLM workflow), but it could be worth handling in a follow-up. I'll take a pass over the code again shortly.
Thank you! This was a bit of a journey! All looks good to me.
PR for the DeBERTa masked LM model and preprocessor.
Related to this, I noticed that DeBERTaV3 uses a different technique for masking; however, the backbone explicitly doesn't include it, so I don't know whether it is relevant here.
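Not part of the thread above, but a rough usage sketch of the task this PR adds, assuming the final API mirrors the other `*MaskedLM` tasks in the library (the compile arguments and the `from_logits=True` assumption should be checked against the merged code):

```python
import keras_nlp
from tensorflow import keras

# Raw strings in; the attached preprocessor handles tokenization and random
# masking, and produces labels and sample weights for the masked positions.
masked_lm = keras_nlp.models.DebertaV3MaskedLM.from_preset("deberta_v3_base_en")

masked_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
masked_lm.fit(
    x=["The quick brown fox jumped.", "I forgot my homework."],
    batch_size=2,
)
```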