Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add generic torchmetric ppl logging on esm2 #557

Draft
wants to merge 32 commits into
base: main
Choose a base branch
from

Conversation

sichu2023
Copy link
Collaborator

@sichu2023 sichu2023 commented Dec 28, 2024

Summary

Logging perplexity through torchmetric.text.Perplexity

Details

Pipeline parallel last stage is handled in training_step and validation_step. Generic torchmetric seems to handle tensor parallelism without any noticeable difference, where the logits are already provided in training/validation_step and no backward pass is needed.

However, valid_ppl does not seem to work properly.

Usage

python sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py --log_train_ppl --log_val_ppl

W B Chart 12_28_2024, 12_09_19 PM
W B Chart 12_28_2024, 12_09_25 PM
W B Chart 12_28_2024, 12_12_35 PM
W B Chart 12_28_2024, 12_12_40 PM

@sichu2023 sichu2023 self-assigned this Dec 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant