ESM2 Tutorial Updates #426

Merged Nov 22, 2024 · 39 commits (view shows changes from 32 commits)

Commits:
50551ef move scripts (farhadrgh, Nov 1, 2024)
a4ff843 Merge branch 'main' into farhadr/refactor (farhadrgh, Nov 6, 2024)
08ca0f9 resolve conflicts (farhadrgh, Nov 6, 2024)
2eaf9af move test file (farhadrgh, Nov 6, 2024)
dfcbb37 undo mv (farhadrgh, Nov 7, 2024)
557b62e reformat into package (farhadrgh, Nov 7, 2024)
cb3116c Merge branch 'main' into farhadr/refactor (farhadrgh, Nov 7, 2024)
306ee68 update docs (farhadrgh, Nov 7, 2024)
30695ba Merge branch 'main' into farhadr/refactor (farhadrgh, Nov 7, 2024)
72cb566 Merge branch 'main' into farhadr/refactor (farhadrgh, Nov 8, 2024)
4bfae90 resolve conflicts (farhadrgh, Nov 8, 2024)
abb7b17 Merge branch 'farhadr/refactor' of https://github.com/NVIDIA/bionemo-… (farhadrgh, Nov 8, 2024)
f869901 add new infer api (farhadrgh, Nov 12, 2024)
d594ab1 Merge branch 'main' into farhadr/infer_docs (farhadrgh, Nov 13, 2024)
a6ca8f5 update (farhadrgh, Nov 15, 2024)
696a2b9 return mask (farhadrgh, Nov 15, 2024)
b7c649b Merge branch 'main' into farhadr/infer_docs (farhadrgh, Nov 15, 2024)
ad25171 include input_ids (farhadrgh, Nov 15, 2024)
84545d3 AAlpha NoteBook (farhadrgh, Nov 15, 2024)
01c3201 check token_ids (farhadrgh, Nov 18, 2024)
8146dae Merge branch 'main' into farhadr/infer_docs (farhadrgh, Nov 18, 2024)
071776b fix tag (farhadrgh, Nov 18, 2024)
16080cd alow partial batches (farhadrgh, Nov 19, 2024)
91fbd5e Merge branch 'main' into farhadr/infer_docs (farhadrgh, Nov 19, 2024)
05c23b6 alow partial batches (farhadrgh, Nov 19, 2024)
fc35c75 Merge branch 'farhadr/infer_docs' of https://github.com/NVIDIA/bionem… (farhadrgh, Nov 19, 2024)
73e5e75 add link to brevdev lunchable (farhadrgh, Nov 20, 2024)
eb8ba90 add inference notebook (farhadrgh, Nov 20, 2024)
0c2151f add launchable link (farhadrgh, Nov 20, 2024)
cf99920 update Note block (farhadrgh, Nov 21, 2024)
e5658f0 update (farhadrgh, Nov 21, 2024)
02b0261 Merge branch 'main' into farhadr/infer_docs (farhadrgh, Nov 21, 2024)
3e906ee Update docs/docs/user-guide/examples/bionemo-esm2/finetune.md (farhadrgh, Nov 21, 2024)
7900708 add include_input_ids option (farhadrgh, Nov 21, 2024)
a76d1ce Merge branch 'farhadr/infer_docs' of https://github.com/NVIDIA/bionem… (farhadrgh, Nov 21, 2024)
9d89373 add include_input_ids (farhadrgh, Nov 21, 2024)
d5b1d59 add include_input_ids (farhadrgh, Nov 21, 2024)
9d5b650 undo bump (farhadrgh, Nov 21, 2024)
2510eab fix nbval failure (farhadrgh, Nov 22, 2024)
84 changes: 54 additions & 30 deletions in docs/docs/user-guide/examples/bionemo-esm2/finetune.md
@@ -151,7 +151,7 @@ data_module = ESM2FineTuneDataModule(

# Fine-Tuning the Regressor Task Head for ESM2

-Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained ESM-2 model (`pretrain_ckpt_path`). We can take advantage of a simple training loop in ```bionemo.esm2.model.fnetune.train``` and use the ```train_model()`` function to start the fine-tuning process in the following.
+Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained 650M ESM-2 model (`pretrain_ckpt_path`). We can take advantage of a simple training loop in `bionemo.esm2.model.finetune.train` and use the `train_model()` function to start the fine-tuning process as follows.

```python
# create a List[Tuple] with (sequence, target) values
@@ -174,33 +174,35 @@ data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]
dataset = InMemorySingleValueDataset(data)
data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)

-with tempfile.TemporaryDirectory() as experiment_tempdir_name:
-    experiment_dir = Path(experiment_tempdir_name)
-    experiment_name = "finetune_regressor"
-    n_steps_train = 50
-    seed = 42
-
-    config = ESM2FineTuneSeqConfig(
-        # initial_ckpt_path=str(pretrain_ckpt_path)
-    )
-
-    checkpoint, metrics, trainer = train_model(
-        experiment_name=experiment_name,
-        experiment_dir=experiment_dir,  # new checkpoint will land in a subdir of this
-        config=config,  # same config as before since we are just continuing training
-        data_module=data_module,
-        n_steps_train=n_steps_train,
-    )
+experiment_name = "finetune_regressor"
+n_steps_train = 50
+seed = 42
+
+# To download a 650M pre-trained ESM2 model
+pretrain_ckpt_path = load("esm2/650m:2.0")
+
+config = ESM2FineTuneSeqConfig(
+    initial_ckpt_path=str(pretrain_ckpt_path)
+)
+
+checkpoint, metrics, trainer = train_model(
+    experiment_name=experiment_name,
+    experiment_dir=Path(experiment_results_dir),  # new checkpoint will land in a subdir of this
+    config=config,  # same config as before since we are just continuing training
+    data_module=data_module,
+    n_steps_train=n_steps_train,
+)
```
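As a quick sanity check, you can print what `train_model()` hands back. This is a sketch rather than part of the tutorial: it assumes `checkpoint` holds the checkpoint path and `metrics` is a dict of final metric values, as suggested by the unpacked return value above.

```python
# Sketch only: report where the fine-tuned checkpoint landed and the final
# metrics. Assumes `checkpoint` is path-like and `metrics` is a dict, as
# suggested by the tuple unpacked from train_model() above.
print(f"Fine-tuned checkpoint: {checkpoint}")
for name, value in metrics.items():
    print(f"{name}: {value}")
```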

This example is fully implemented in `bionemo.esm2.model.finetune.train` and can be executed by:
```bash
-python /workspace/bionemo2/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
+python -m bionemo.esm2.model.finetune.train
```

## Notes
-1. The above example is fine-tuning a randomly initialized ESM-2 model for demonstration purposes. In order to fine-tune a pre-trained ESM-2 model, please download the ESM-2 650M checkpoint from NGC resources using the following bash command
+1. The above example fine-tunes a 650M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load`, as shown above (a short sketch also follows these notes).
```bash
-download_bionemo_data esm2/650m:2.0 --source ngc
+download_bionemo_data esm2/650m:2.0
```
and pass the output path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`) to `initial_ckpt_path` when setting up the config object:
```python
@@ -219,21 +221,43 @@ python /workspace/bionemo2/sub-packages/bionemo-esm2/src/bionemo/esm2/model/fine
3. We are using a small dataset of artificial sequences as our fine-tuning data in this example. You may experience over-fitting and observe no change in the validation metrics.
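A minimal sketch of the programmatic route from note 1, combining the pieces shown above (the `ESM2FineTuneSeqConfig` import path is an assumption based on the repository layout; the `load` tag matches the bash command):

```python
# Sketch only: download (or reuse from the local cache) the 650M checkpoint
# and point the fine-tuning config at it. The ESM2FineTuneSeqConfig import
# path below is an assumption; adjust it to your bionemo installation.
from bionemo.core.data.load import load
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig

pretrain_ckpt_path = load("esm2/650m:2.0")  # same tag as the bash command above
config = ESM2FineTuneSeqConfig(initial_ckpt_path=str(pretrain_ckpt_path))
```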

# Fine-Tuned ESM-2 Model Inference
-Once we have a checkpoint we can create a config object by pointing the path in `initial_ckpt_path` and use that for inference. Since we need to load all the parameters from this checkpoint (and don't skip the head) we reset the `nitial_ckpt_skip_keys_with_these_prefixes` in this config. Now we can use the ```bionemo.esm2.model.fnetune.train.infer``` to run inference on prediction dataset.
+Now we can use `bionemo.esm2.model.finetune.infer` to run inference on an example prediction dataset.
+Record the checkpoint path reported at the end of the fine-tuning run, after executing `python -m bionemo.esm2.model.finetune.train` (e.g. `/tmp/tmp1b5wlnba/finetune_regressor/checkpoints/finetune_regressor--reduced_train_loss=0.0016-epoch=0-last`), and use it as the `--checkpoint-path` argument to the inference script.

-```python
-config = ESM2FineTuneSeqConfig(
-    initial_ckpt_path = finetuned_checkpoint,
-    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
-)
-```
+We download an example CSV dataset of artificial sequences for this inference example. Please refer to the [ESM-2 Inference](./inference) tutorial for a detailed explanation of the arguments and how to create your own CSV file.

+```bash
+mkdir -p $WORKDIR/esm2_finetune_tutorial
+
+# download sample data CSV for inference
+DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0 --source ngc)
+RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/inference_results.pt
+
+infer_esm2 --checkpoint-path <finetune checkpoint path> \
+    --data-path $DATA_PATH \
+    --results-path $RESULTS_PATH \
+    --config-class ESM2FineTuneSeqConfig
+```
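To run the same command on your own sequences instead of the downloaded sample, a minimal sketch for writing a compatible CSV follows; the single `sequences` column name and the sequences themselves are assumptions, so confirm the expected schema in the [ESM-2 Inference](./inference) tutorial.

```python
# Sketch only: write a small CSV of protein sequences for infer_esm2.
# The "sequences" column name is an assumption taken from the linked
# inference tutorial; the sequences are arbitrary examples.
import pandas as pd

example_sequences = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    "MSGSKRLLLLALLSAAASA",
]
pd.DataFrame({"sequences": example_sequences}).to_csv("sequences.csv", index=False)
# Then: infer_esm2 --data-path sequences.csv --results-path results.pt ...
```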

-This example is implemented in ```bionemo.esm2.model.finetune.infer``` and can be executed by:
-```bash
-python /workspace/bionemo2/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/infer.py
-```
+This will create a result `.pt` file under `$WORKDIR/esm2_finetune_tutorial/inference_results.pt`, which can be loaded with PyTorch in a Python environment:
+
+```python
+import torch
+
+# Set the path to the results file, e.g.:
+results_path = "/workspace/bionemo2/esm2_finetune_tutorial/inference_results.pt"
+results = torch.load(results_path)

+# results is a python dict which includes the following result tensors for this example:
+# results['regression_output'] is a tensor with shape: torch.Size([10, 1])
+```
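A hedged follow-up sketch: list every tensor in the loaded dict. Only `regression_output` is guaranteed by the text above; any other keys (for example embeddings or input IDs, when requested at inference time) are assumptions.

```python
# Sketch only: iterate over the loaded results dict and report tensor shapes.
for key, value in results.items():
    if torch.is_tensor(value):
        print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}")
```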

## Notes
1. For demonstration purposes, executing the above command will infer a randomly initialized `ESM2FineTuneSeqModel` unless `initial_ckpt_path` is specified and set to an already trained model.
2. If a fine-tuned checkpoint is provided via `initial_ckpt_path`, then `initial_ckpt_skip_keys_with_these_prefixes` should be reset to `field(default_factory=list)` so that no parameters are skipped.
- The ESM-2 inference module takes the `--checkpoint-path` and `--config-class` arguments and creates a config object whose `initial_ckpt_path` points to that checkpoint. Since we need to load all the parameters from this checkpoint (and not skip the head), we reset `initial_ckpt_skip_keys_with_these_prefixes` in this config:

```python
config = ESM2FineTuneSeqConfig(
    initial_ckpt_path=<finetuned checkpoint>,  # placeholder for the fine-tuned checkpoint path
    initial_ckpt_skip_keys_with_these_prefixes=[],  # equivalent to field(default_factory=list): skip nothing
)
```