
ProtT5 model training loss function #164

Open
Alex2975 opened this issue Dec 11, 2024 · 5 comments

@Alex2975

Dear Authors,

Thank you very much for the great work. I got a question and would appreciate your insights.

For ProtT5 training, the model predicts the full sequence, not just the masked tokens. What is the loss function used for ProtT5 training: torch CrossEntropyLoss with reduction="sum", or torch CrossEntropyLoss with reduction="mean"?

@mheinzinger
Collaborator

I would always recommend using mean, as you want the loss to be independent of the number of tokens in your batch.
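
For illustration, here is a minimal sketch (dummy logits, not the actual ProtT5 training code) of how the two reductions behave with torch.nn.CrossEntropyLoss:

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 32                                  # illustrative only; ProtT5's real vocabulary is small (amino acids + special tokens)
logits = torch.randn(10, vocab_size)             # decoder logits, flattened over batch and sequence length
targets = torch.randint(0, vocab_size, (10,))    # ground-truth token ids

loss_mean = nn.CrossEntropyLoss(reduction="mean")(logits, targets)  # per-token average, independent of batch/sequence size
loss_sum = nn.CrossEntropyLoss(reduction="sum")(logits, targets)    # scales with the number of tokens in the batch

print(loss_mean.item(), loss_sum.item())
assert torch.allclose(loss_sum, loss_mean * 10)  # sum = mean * number_of_tokens

As far as I know, Hugging Face's T5ForConditionalGeneration also applies CrossEntropyLoss with the default reduction="mean" over the label tokens when you pass labels.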

@Alex2975
Author

Alex2975 commented Jan 3, 2025

Thank you for the insights, @mheinzinger. If I want to compute the perplexity of a protein sequence with the ProtT5 model, how do I do it? Since it uses an MLM-style objective, I think the following will not work. Could you please share some insights?

import torch

input_ids = tokenizer(sequence, return_tensors="pt")["input_ids"]
labels = tokenizer(sequence, return_tensors="pt")["input_ids"]

# Get outputs (feed the full sequence as both input and label)
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss  # mean cross-entropy over the label tokens

# Compute perplexity
perplexity = torch.exp(loss).item()
print(f"Perplexity: {perplexity}")

@mheinzinger
Collaborator

Indeed, the above won't work, especially since perplexity is ill-defined for models trained via MLM.
I guess the best you can get is the pseudo-perplexity, where you mask one token at a time, reconstruct it, compute the loss against the ground truth, and repeat over the full sequence before averaging and taking the exponent.
We have some implementation for step-wise masking of ProtT5 here; maybe this helps.
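
As a rough sketch of that pseudo-perplexity loop (this assumes the Hugging Face checkpoint Rostlab/prot_t5_xl_uniref50 and single-token masking with the <extra_id_0> sentinel; the linked implementation may differ in its details):

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "Rostlab/prot_t5_xl_uniref50"       # assumed checkpoint; swap in the one you actually use
tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device).eval()

sequence = "MKTAYIAKQR"                          # toy example sequence
residues = list(sequence)

nlls = []
with torch.no_grad():
    for i in range(len(residues)):
        # Mask one residue at a time with the first sentinel token.
        masked = residues.copy()
        masked[i] = "<extra_id_0>"
        src = " ".join(masked)                   # ProtT5 expects space-separated residues
        tgt = "<extra_id_0> " + residues[i]      # decoder target: sentinel followed by the true residue

        input_ids = tokenizer(src, return_tensors="pt").input_ids.to(device)
        labels = tokenizer(tgt, return_tensors="pt").input_ids.to(device)

        # outputs.loss is the mean cross-entropy over the label tokens for this masked position.
        nlls.append(model(input_ids=input_ids, labels=labels).loss)

pseudo_perplexity = torch.exp(torch.stack(nlls).mean()).item()
print(f"Pseudo-perplexity: {pseudo_perplexity:.3f}")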

@Alex2975
Author

Alex2975 commented Jan 7, 2025

Thank you so much for the tips, @mheinzinger. I found the "mask one token at a time" implementation in the link you shared. Will try that.

@mheinzinger
Collaborator

Just in case: we now have an example for continuing pre-training.
