
Solve #721 Deberta masklm model #732

Merged (7 commits) on Mar 3, 2023

Conversation

Plutone11011
Contributor

PR for the DeBERTa masked LM model and preprocessor.

Related to this, I noticed that DeBERTaV3 uses a different technique for masking; however, the backbone explicitly doesn't include it, so I don't know whether it is relevant here.

@Plutone11011 changed the title from "#721 Deberta masklm model" to "Solve #721 Deberta masklm model" on Feb 9, 2023
@mattdangerw
Member

@Plutone11011 thanks! Will take a pass soon.

> I noticed that DeBERTaV3 uses a different technique for masking

Can you elaborate? If it's all preprocessing, like this comment, I think it is fine to cover the fancier token masking schemes down the road.

If this is a whole different setup, or different "head" on the backbone, we might want to consider some changes.

@Plutone11011
Contributor Author

> @Plutone11011 thanks! Will take a pass soon.
>
> > I noticed that DeBERTaV3 uses a different technique for masking
>
> Can you elaborate? If it's all preprocessing, like this comment, I think it is fine to cover the fancier token masking schemes down the road.
>
> If this is a whole different setup, or different "head" on the backbone, we might want to consider some changes.

In https://arxiv.org/pdf/2111.09543.pdf the authors describe a replacement for MLM called RTD (replaced token detection), which is what mainly distinguishes DeBERTaV3 from previous versions. It is still somewhat based on MLM, but it uses a GAN-style approach, jointly training an MLM generator and a discriminator/classifier. See Sections 3.1 and 2.3.2 of the paper, specifically.
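For illustration, here is a minimal, self-contained sketch of that RTD data flow: an MLM generator proposes tokens at the masked positions, the sequence with the replacements becomes the discriminator input, and the discriminator labels mark which tokens were actually replaced. All names and values below are toy placeholders, not the paper's implementation.

```python
import numpy as np

# Toy batch: original token ids and the positions that were masked for the generator.
original_ids = np.array([[5, 12, 7, 9, 3]])  # (batch, seq_len)
mask_positions = np.array([[1, 3]])          # positions the MLM generator predicts

# Pretend the generator (a small MLM) sampled these ids at the masked positions.
generator_samples = np.array([[12, 4]])

# Discriminator input: original sequence with masked positions replaced by samples.
discriminator_ids = original_ids.copy()
batch_idx = np.arange(original_ids.shape[0])[:, None]
discriminator_ids[batch_idx, mask_positions] = generator_samples

# RTD labels: 1 where the token differs from the original ("replaced"), else 0.
# Note that a sampled token identical to the original counts as "original".
rtd_labels = (discriminator_ids != original_ids).astype(np.int32)

print(discriminator_ids)  # [[ 5 12  7  4  3]]
print(rtd_labels)         # [[0 0 0 1 0]]
```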

@mattdangerw
Member

Ah right! DeBERTaV3 uses ELECTRA-style (GAN-like) pre-training.

I think it is still totally valid to ship an MLM task for DeBERTa, but we can probably add a note in the docstring that this is not the task setup used by DeBERTa during pre-training.
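For illustration, a sketch of how that docstring note and basic task usage might read; the `DebertaV3MaskedLM` class name, the `from_preset` call, and fitting directly on raw strings follow the conventions used elsewhere in this PR, but treat the exact API as an assumption rather than the final implementation.

```python
import keras_nlp

# Docstring note (paraphrased): DeBERTaV3 itself was pre-trained with
# ELECTRA-style replaced token detection rather than this masked LM setup;
# the task is still useful for fine-tuning or continued training with MLM.
masked_lm = keras_nlp.models.DebertaV3MaskedLM.from_preset("deberta_v3_base_en")
masked_lm.fit(x=["The quick brown fox jumped."], batch_size=1)
```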


@mattdangerw left a comment


Thanks! This looks great! Left a few comments

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/facebookresearch/fairseq).
Member


switch this to the deberta repo

outputs = MaskedLMHead(
    vocabulary_size=backbone.vocabulary_size,
    embedding_weights=backbone.token_embedding.embeddings,
    intermediate_activation="gelu",

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.DebertaV3MaskedLMPreprocessor.from_preset("deberta_v3_base_en")
Member


format this to fit our lint limit

preprocessor = keras_nlp.models.DebertaV3MaskedLMPreprocessor.from_preset(
    "deberta_v3_base_en",
)

Contributor Author


The format script is not detecting it; is that because it's in a docstring?

@@ -72,6 +72,7 @@ def __init__(self, proto, **kwargs):
         cls_token = "[CLS]"
         sep_token = "[SEP]"
         pad_token = "[PAD]"
+        mask_token = "[MASK]"
         for token in [cls_token, pad_token, sep_token]:
Member


I think we should add a check for the mask token here, which might also mean you need to update some unit tests for the preprocessor and tokenizer layers (so they add a mask token to the vocabulary).
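A standalone sketch of the kind of check being suggested; the helper name and error wording here are hypothetical, and in the PR the check would live inside the tokenizer's `__init__` loop shown in the diff above.

```python
def check_special_tokens(vocabulary, tokens=("[CLS]", "[SEP]", "[PAD]", "[MASK]")):
    """Raise if any required special token is missing from the vocabulary."""
    missing = [token for token in tokens if token not in vocabulary]
    if missing:
        raise ValueError(
            f"Cannot find {missing} in the provided vocabulary. "
            "Please add them, or use a pretrained vocabulary that includes them."
        )

# A vocabulary without [MASK] now fails the check.
try:
    check_special_tokens(["[CLS]", "[SEP]", "[PAD]", "the", "quick"])
except ValueError as e:
    print(e)
```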

Contributor Author


So, I added the check, and as you can see it fails the preset tests because apparently there isn't a [MASK] token in the deberta_v3_extra_small_en vocabulary, which seems strange to me (I've tried deberta_v3_base_en too). Any advice?

    bos_piece="[CLS]",
    eos_piece="[SEP]",
    unk_piece="[UNK]",
    user_defined_symbols="[MASK]",
Member


Nice! Should this not be a list? Or is it just a comma-separated string?

Contributor Author


Yes, it seems to be a comma-separated string. From the SentencePiece options doc:

--user_defined_symbols (comma separated list of user defined symbols)  type: std::string  default: ""
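For reference, a minimal, self-contained sketch of training a toy SentencePiece vocabulary with a user-defined [MASK] symbol, roughly mirroring the test setup quoted above; the toy corpus and vocabulary size are illustrative.

```python
import io

import sentencepiece as spm

corpus = ["the quick brown fox", "the earth is round"]
bytes_io = io.BytesIO()
spm.SentencePieceTrainer.train(
    sentence_iterator=iter(corpus),
    model_writer=bytes_io,
    vocab_size=12,
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    pad_piece="[PAD]",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
    unk_piece="[UNK]",
    user_defined_symbols="[MASK]",  # a comma-separated string, not a Python list
)
proto = bytes_io.getvalue()  # serialized model, usable as DebertaV3Tokenizer(proto=proto)
```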

proto = bytes_io.getvalue()
self.preprocessor = DebertaV3MaskedLMPreprocessor(
    tokenizer=DebertaV3Tokenizer(proto=proto),
    # Simplify out testing by masking every available token.
Member


I don't think this comment applies here.

@Plutone11011
Contributor Author

Addressed the comments. I also added the mask check in the tokenizer, and I had to adjust the vocab size and some tests as a result.

@mattdangerw
Member

@Plutone11011 thanks! Overall this is looking good, but I hit a snag while trying to test this.

Here's a gist -> https://colab.research.google.com/gist/mattdangerw/550ca0fc007579353ec7d0f11ebee03b/deberta-masked-lm.ipynb

Essentially, the problem is that the [MASK] token does not appear in our version of the DeBERTaV3 vocabulary. Looking at the upstream code, it looks like they do have support for a mask token when using SentencePiece, but it may be layered on top of SentencePiece itself -> https://github.com/microsoft/DeBERTa/blob/11fa20141d9700ba2272b38f2d5fce33d981438b/DeBERTa/deberta/spm_tokenizer.py#L43

We might need to do a little digging here. What token id is assigned to [MASK] in the upstream implementation? For Hugging Face, it looks like it is appended as a final token, but would that mean we are attempting an embedding lookup that falls outside of our embedding size?

Once we figure out how the original implementation handles this, we can figure out what changes we need to make here.
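To make the concern concrete, a toy sketch of what "appended as a final token" implies for the embedding table; the numbers are placeholders, not the real DeBERTaV3 vocabulary sizes.

```python
import numpy as np

# Toy numbers: an spm vocabulary of 10 pieces (ids 0..9), with [MASK]
# appended afterwards as id 10 (the Hugging Face-style layout).
spm_vocab_size = 10
mask_token_id = spm_vocab_size

# If the embedding table is sized to the spm vocabulary alone, the mask id
# falls outside of it...
embeddings = np.zeros((spm_vocab_size, 4))
try:
    embeddings[mask_token_id]
except IndexError:
    print("mask id is out of range for an spm-sized embedding table")

# ...so the model's vocabulary_size has to cover the appended token(s).
embeddings = np.zeros((mask_token_id + 1, 4))
assert embeddings[mask_token_id].shape == (4,)
```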

@mattdangerw
Member

Everything is looking good here; we can just merge this with #759 when it is ready. We can't merge before that without breaking the main model usages.

@mattdangerw
Member

OK, #759 is merged, so we should be able to rebase this and get things working. You can use the Colab I linked above to test things out (I recommend a GPU runtime).

@mattdangerw
Member

@Plutone11011 is this ready for review again? If so, I can take a pass tomorrow.

@Plutone11011
Contributor Author

> @Plutone11011 is this ready for review again? If so, I can take a pass tomorrow.

Yes. I've checked the notebook; training and the preprocessor work. There is, however, still a problem when calling detokenize that I haven't delved into: it doesn't find the [MASK] id.

@Plutone11011
Contributor Author

Plutone11011 commented Mar 1, 2023

@mattdangerw in the Colab notebook, the call to detokenize yields an OutOfRangeError (invalid id 128000); it is basically unable to find the [MASK] id. I haven't followed #759 closely, but I guess this is because the [MASK] token is handled internally by DebertaV3Tokenizer in self.mask_token_id, whereas detokenize calls the SentencePiece method. Do you think we should implement a detokenize method inside DebertaV3Tokenizer?
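One possible shape for such an override, sketched as a hypothetical subclass so it stays self-contained; `mask_token_id` is the attribute mentioned above, and whether this is the right fix (or where it should live) is exactly the open question here.

```python
import tensorflow as tf

from keras_nlp.models import DebertaV3Tokenizer


class PatchedDebertaV3Tokenizer(DebertaV3Tokenizer):
    """Hypothetical subclass: drop [MASK] ids before SentencePiece detokenization."""

    def detokenize(self, inputs):
        # The [MASK] id lives outside the SentencePiece vocabulary, so filter
        # it out before delegating to the base SentencePiece detokenize.
        inputs = tf.ragged.boolean_mask(
            inputs, tf.not_equal(inputs, self.mask_token_id)
        )
        return super().detokenize(inputs)
```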

@mattdangerw
Member

@Plutone11011 thanks for checking this out! IMO we don't need to block on the detokenize functionality (it is not critical to any MLM workflow), but it could be worth handling in a follow-up.

I'll take a pass over the code again shortly.


@mattdangerw left a comment


Thank you! This was a bit of a journey! All looks good to me.
