
ESM2 Tutorial Updates (#426)
- Deprecate `infer.py` example for finetuning and add Inference tutorial.
- Add [ESM2 Mutant Design](https://github.com/NVIDIA/bionemo-framework/blob/bionemo1/docs/bionemo/notebooks/esm2nv-mutant-design.ipynb) notebook.

---------

Signed-off-by: Farhad Ramezanghorbani <[email protected]>
Co-authored-by: Peter St. John <[email protected]>
farhadrgh and pstjohn authored Nov 22, 2024
1 parent f5a4d81 commit 4c0a071
Showing 11 changed files with 1,784 additions and 160 deletions.
84 changes: 54 additions & 30 deletions docs/docs/user-guide/examples/bionemo-esm2/finetune.md
@@ -151,7 +151,7 @@ data_module = ESM2FineTuneDataModule(

# Fine-Tuning the Regressor Task Head for ESM2

-Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained ESM-2 model (`pretrain_ckpt_path`). We can take advantage of a simple training loop in `bionemo.esm2.model.finetune.train` and use the `train_model()` function to start the fine-tuning process in the following.
+Now we can put these five requirements together to fine-tune a regressor task head starting from a pre-trained 650M ESM-2 model (`pretrain_ckpt_path`). We can take advantage of the simple training loop in `bionemo.esm2.model.finetune.train` and use its `train_model()` function to start the fine-tuning process, as follows.

```python
# create a List[Tuple] with (sequence, target) values
@@ -174,33 +174,35 @@ data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]
dataset = InMemorySingleValueDataset(data)
data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)

-with tempfile.TemporaryDirectory() as experiment_tempdir_name:
-    experiment_dir = Path(experiment_tempdir_name)
-    experiment_name = "finetune_regressor"
-    n_steps_train = 50
-    seed = 42
+experiment_name = "finetune_regressor"
+n_steps_train = 50
+seed = 42

-    config = ESM2FineTuneSeqConfig(
-        # initial_ckpt_path=str(pretrain_ckpt_path)
-    )
+# To download a 650M pre-trained ESM2 model
+pretrain_ckpt_path = load("esm2/650m:2.0")

-    checkpoint, metrics, trainer = train_model(
-        experiment_name=experiment_name,
-        experiment_dir=experiment_dir,  # new checkpoint will land in a subdir of this
-        config=config,  # same config as before since we are just continuing training
-        data_module=data_module,
-        n_steps_train=n_steps_train,
-    )
+config = ESM2FineTuneSeqConfig(
+    initial_ckpt_path=str(pretrain_ckpt_path)
+)
+
+checkpoint, metrics, trainer = train_model(
+    experiment_name=experiment_name,
+    experiment_dir=Path(experiment_results_dir),  # new checkpoint will land in a subdir of this
+    config=config,  # same config as before since we are just continuing training
+    data_module=data_module,
+    n_steps_train=n_steps_train,
+)
```
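Once `train_model()` returns, it can be useful to confirm where the fine-tuned checkpoint landed before moving on to inference. A minimal sketch, assuming the return values of `train_model()` shown above:

```python
# `checkpoint` points at the newly written checkpoint under experiment_dir,
# `metrics` holds the metrics collected during training, and `trainer` is the
# underlying trainer object, which is not needed further here.
print(f"fine-tuned checkpoint: {checkpoint}")
print(f"final metrics: {metrics}")
```

The printed checkpoint path is the value to pass to the inference step described below.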

This example is fully implemented in `bionemo.esm2.model.finetune.train` and can be executed by:
```bash
-python /workspace/bionemo2/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
+python -m bionemo.esm2.model.finetune.train
```

## Notes
-1. The above example is fine-tuning a randomly initialized ESM-2 model for demonstration purposes. In order to fine-tune a pre-trained ESM-2 model, please download the ESM-2 650M checkpoint from NGC resources using the following bash command
+1. The above example fine-tunes a 650M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load` as shown above (see also the sketch after these notes).
```bash
-download_bionemo_data esm2/650m:2.0 --source ngc
+download_bionemo_data esm2/650m:2.0
```
and pass the output path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`) as the `initial_ckpt_path` argument when setting up the config object:
```python
config = ESM2FineTuneSeqConfig(
    # hypothetical placeholder: use the untarred path printed by download_bionemo_data
    initial_ckpt_path="<downloaded checkpoint path>",
)
```
@@ -219,21 +221,43 @@
3. We are using a small dataset of artificial sequences as our fine-tuning data in this example. You may experience over-fitting and observe no change in the validation metrics.
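
For reference, here is a minimal sketch of the programmatic download route mentioned in note 1, assuming the `load` helper from `bionemo.core.data.load` used earlier and network access to the NGC registry:

```python
from bionemo.core.data.load import load

# download (or fetch from the local cache) the 650M pre-trained ESM-2 checkpoint;
# load() returns the local path of the unpacked resource
pretrain_ckpt_path = load("esm2/650m:2.0")
print(pretrain_ckpt_path)
```

The resulting path can then be passed as `initial_ckpt_path`, exactly as in the fine-tuning snippet above.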

# Fine-Tuned ESM-2 Model Inference
-Once we have a checkpoint, we can create a config object by pointing `initial_ckpt_path` to it and use that for inference. Since we need to load all the parameters from this checkpoint (and don't skip the head), we reset `initial_ckpt_skip_keys_with_these_prefixes` in this config. Now we can use `bionemo.esm2.model.finetune.train.infer` to run inference on a prediction dataset.
+Now we can use `bionemo.esm2.model.finetune.train.infer` to run inference on an example prediction dataset.
+Record the checkpoint path reported at the end of the fine-tuning run, after executing `python -m bionemo.esm2.model.finetune.train` (e.g. `/tmp/tmp1b5wlnba/finetune_regressor/checkpoints/finetune_regressor--reduced_train_loss=0.0016-epoch=0-last`), and pass it to the inference script as the `--checkpoint-path` argument.

-```python
-config = ESM2FineTuneSeqConfig(
-    initial_ckpt_path=finetuned_checkpoint,
-    # pass an empty list (the equivalent of field(default_factory=list)) so no parameters are skipped
-    initial_ckpt_skip_keys_with_these_prefixes=[],
-)
-```
+We download a CSV example dataset of artificial sequences for this inference example. Please refer to the [ESM-2 Inference](./inference) tutorial for a detailed explanation of the arguments and how to create your own CSV file (a rough sketch follows the command below).

+```bash
+mkdir -p $WORKDIR/esm2_finetune_tutorial
+
+# download sample data CSV for inference
+DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0 --source ngc)
+RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/inference_results.pt
+
+infer_esm2 --checkpoint-path <finetune checkpoint path> \
+           --data-path $DATA_PATH \
+           --results-path $RESULTS_PATH \
+           --config-class ESM2FineTuneSeqConfig
+```
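
If you would rather run inference on your own sequences instead of the sample data, the input is a CSV file. A rough sketch of building one, assuming a single `sequences` column (the column name is an assumption here; see the ESM-2 Inference tutorial for the exact format expected by `infer_esm2`):

```python
import pandas as pd

# two made-up artificial sequences for illustration
sequences = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    "MSATAEQNARNPKGKGGF",
]

# write a CSV with a single "sequences" column (assumed column name)
pd.DataFrame({"sequences": sequences}).to_csv("my_sequences.csv", index=False)
```

The resulting file can then be passed via `--data-path` in place of `$DATA_PATH`.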

-This example is implemented in `bionemo.esm2.model.finetune.infer` and can be executed by:
+This will create a results `.pt` file under `$WORKDIR/esm2_finetune_tutorial/inference_results.pt`, which can be loaded with the PyTorch library in a Python environment:

-```bash
-python /workspace/bionemo2/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/infer.py
-```
+```python
+import torch
+
+# set this to the results file written above, e.g. /workspace/bionemo2/esm2_finetune_tutorial/inference_results.pt
+results_path = "/workspace/bionemo2/esm2_finetune_tutorial/inference_results.pt"
+results = torch.load(results_path)
+
+# results is a Python dict which includes the following result tensors for this example:
+# results['regression_output'] is a tensor with shape: torch.Size([10, 1])
+```
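
Since `results['regression_output']` is a plain tensor, a quick sanity check over the ten example predictions might look like the following sketch:

```python
# drop the trailing singleton dimension: torch.Size([10, 1]) -> torch.Size([10])
preds = results["regression_output"].squeeze(-1)
for i, value in enumerate(preds.tolist()):
    print(f"sequence {i}: predicted value {value:.4f}")
```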

+## Notes
+1. For demonstration purposes, executing the above command will run inference with a randomly initialized `ESM2FineTuneSeqModel` unless `initial_ckpt_path` is specified and set to an already trained model.
+2. If a fine-tuned checkpoint is provided as `initial_ckpt_path`, `initial_ckpt_skip_keys_with_these_prefixes` should be reset to `field(default_factory=list)` so that no parameters are skipped.
+   - The ESM-2 inference module takes the `--checkpoint-path` and `--config-class` arguments to create a config object whose `initial_ckpt_path` points to the given path. Since we need to load all the parameters from this checkpoint (and not skip the head), we reset `initial_ckpt_skip_keys_with_these_prefixes` in this config:
+   ```python
+   config = ESM2FineTuneSeqConfig(
+       initial_ckpt_path="<finetuned checkpoint path>",
+       # an empty list (the instance-level equivalent of field(default_factory=list)) skips no parameters
+       initial_ckpt_skip_keys_with_these_prefixes=[],
+   )
+   ```
