Wandb integration (#205)
* add more arguments for wandb integration

* add more arguments to wandblogger

* edit explanation and delete unnecessary variables

* modification

* modification

* recommit

* arrange import parts

* Fix CLI parsing

---------

Co-authored-by: John St John <[email protected]>
2 people authored and tshimko-nv committed Oct 2, 2024
1 parent aac66c9 commit 3eff47c
Showing 3 changed files with 97 additions and 36 deletions.
59 changes: 44 additions & 15 deletions scripts/protein/esm2/esm2_pretrain.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
from pathlib import Path
-from typing import Optional, Sequence, get_args
+from typing import List, Optional, Sequence, get_args

from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
@@ -52,8 +51,6 @@ def main(
min_seq_length: Optional[int],
max_seq_length: int,
result_dir: Path,
-wandb_project: Optional[str],
-wandb_offline: bool,
num_steps: int,
warmup_steps: int,
limit_val_batches: int,
@@ -66,7 +63,14 @@
experiment_name: str,
resume_if_exists: bool,
precision: PrecisionTypes,
-wandb_entity: str = "clara-discovery",
+wandb_entity: Optional[str] = None,
+wandb_project: Optional[str] = None,
+wandb_offline: bool = False,
+wandb_tags: Optional[List[str]] = None,
+wandb_group: Optional[str] = None,
+wandb_id: Optional[str] = None,
+wandb_anonymous: Optional[bool] = False,
+wandb_log_model: bool = False,
pipeline_model_parallel_size: int = 1,
tensor_model_parallel_size: int = 1,
create_tensorboard_logger: bool = False,
@@ -94,8 +98,14 @@ def main(
devices (int): number of devices
seq_length (int): sequence length
result_dir (Path): directory to store results, logs and checkpoints
-wandb_project (Optional[str]): weights and biases project name
-wandb_offline (bool): if wandb should happen in offline mode
+wandb_entity (str): The team posting this run (default: your username or your default team)
+wandb_project (str): The name of the project to which this run will belong.
+wandb_tags (List[str]): Tags associated with this run.
+wandb_group (str): A unique string shared by all runs in a given group
+wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
+wandb_id (str): Sets the version, mainly used to resume a previous run.
+wandb_anonymous (bool): Enables or explicitly disables anonymous logging.
+wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers.
num_steps (int): number of steps to train the model for
limit_val_batches (int): limit the number of validation global batches to this many
val_check_interval (int): number of steps to periodically check the validation loss and save num_dataset_workers (
@@ -106,7 +116,6 @@
experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
result_dir that stores the logs and checkpoints.
resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
-wandb_entity (str): the group to use for the wandb run, sometimes called a team, could also be your username
create_tensorboard_logger (bool): create the tensorboard logger
restore_from_checkpoint_path (path): If set, restores the model from the directory passed in. Expects the
checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
@@ -132,16 +141,23 @@
ckpt_include_optimizer=True,
)

+# for wandb integration
+# Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html"
wandb_options: Optional[WandbLoggerOptions] = (
None
if wandb_project is None
else WandbLoggerOptions(
offline=wandb_offline,
project=wandb_project,
entity=wandb_entity,
-log_model=False,
+tags=wandb_tags,
+group=wandb_group,
+id=wandb_id,
+anonymous=wandb_anonymous,
+log_model=wandb_log_model,
)
)

trainer = nl.Trainer(
devices=devices,
max_steps=num_steps,
@@ -295,14 +311,21 @@ def main(
"--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory."
)
parser.add_argument("--experiment-name", type=str, required=False, default="esm2", help="Name of the experiment.")
parser.add_argument("--wandb-offline", action="store_true", default=False, help="Use wandb in offline mode.")

parser.add_argument("--wandb-entity", type=str, default=None, help="The team posting this run")
parser.add_argument("--wandb-project", type=str, default=None, help="Wandb project name ")
parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run")
parser.add_argument(
"--wandb-project",
type=str,
required=False,
default=None,
help="Wandb project name. Wandb will only happen if this is set.",
"--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group"
)
parser.add_argument(
"--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run"
)
parser.add_argument("--wandb-anonymous", action="store_true", help="Enable or explicitly disable anonymous logging")
parser.add_argument(
"--wandb-log-model", action="store_true", help="Save checkpoints in wandb dir to upload on W&B servers"
)
parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode")
parser.add_argument(
"--num-gpus",
type=int,
@@ -490,7 +513,13 @@ def main(
min_seq_length=args.min_seq_length,
max_seq_length=args.max_seq_length,
result_dir=args.result_dir,
+wandb_entity=args.wandb_entity,
wandb_project=args.wandb_project,
+wandb_tags=args.wandb_tags,
+wandb_group=args.wandb_group,
+wandb_id=args.wandb_id,
+wandb_anonymous=args.wandb_anonymous,
+wandb_log_model=args.wandb_log_model,
wandb_offline=args.wandb_offline,
num_steps=args.num_steps,
warmup_steps=args.warmup_steps,
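A minimal sketch of the gating pattern used above, in which W&B options are built only when a project name is supplied. The helper name and values below are illustrative placeholders, not code from this commit:

from typing import Any, Dict, Optional


def build_wandb_options(
    wandb_project: Optional[str],
    wandb_entity: Optional[str] = None,
    wandb_offline: bool = False,
) -> Optional[Dict[str, Any]]:
    # Mirrors the `None if wandb_project is None else WandbLoggerOptions(...)`
    # expression in the diff: without a project name, no logger options are built.
    if wandb_project is None:
        return None
    return {"project": wandb_project, "entity": wandb_entity, "offline": wandb_offline}


assert build_wandb_options(None) is None
assert build_wandb_options("esm2-pretrain")["project"] == "esm2-pretrain"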
56 changes: 42 additions & 14 deletions scripts/singlecell/geneformer/train.py
@@ -23,7 +23,7 @@
import argparse
import math
from pathlib import Path
-from typing import Dict, Optional, Sequence, Type, get_args
+from typing import Dict, List, Optional, Sequence, Type, get_args

from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
@@ -56,8 +56,6 @@ def main(
devices: int,
seq_length: int,
result_dir: Path,
-wandb_project: Optional[str],
-wandb_offline: bool,
num_steps: int,
limit_val_batches: int,
val_check_interval: int,
@@ -71,7 +69,14 @@
experiment_name: str,
resume_if_exists: bool,
precision: PrecisionTypes,
-wandb_entity: str = "clara-discovery",
+wandb_entity: Optional[str] = None,
+wandb_project: Optional[str] = None,
+wandb_offline: bool = False,
+wandb_tags: Optional[List[str]] = None,
+wandb_group: Optional[str] = None,
+wandb_id: Optional[str] = None,
+wandb_anonymous: Optional[bool] = False,
+wandb_log_model: bool = False,
create_tensorboard_logger: bool = False,
nemo1_init_path: Path | None = None,
restore_from_checkpoint_path: Path | None = None,
@@ -90,8 +95,6 @@ def main(
devices (int): number of devices
seq_length (int): sequence length
result_dir (Path): directory to store results, logs and checkpoints
-wandb_project (Optional[str]): weights and biases project name
-wandb_offline (bool): if wandb should happen in offline mode
num_steps (int): number of steps to train the model for
limit_val_batches (int): limit the number of validation global batches to this many
val_check_interval (int): number of steps to periodically check the validation loss and save num_dataset_workers (
@@ -104,7 +107,14 @@
experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
result_dir that stores the logs and checkpoints.
resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
-wandb_entity (str): the group to use for the wandb run, sometimes called a team, could also be your username
+wandb_entity (str): The team posting this run (default: your username or your default team)
+wandb_project (str): The name of the project to which this run will belong.
+wandb_tags (List[str]): Tags associated with this run.
+wandb_group (str): A unique string shared by all runs in a given group
+wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
+wandb_id (str): Sets the version, mainly used to resume a previous run.
+wandb_anonymous (bool): Enables or explicitly disables anonymous logging.
+wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers.
create_tensorboard_logger (bool): create the tensorboard logger
restore_from_checkpoint_path (path): If set, restores the model from the directory passed in. Expects the
checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
@@ -137,14 +147,20 @@
ckpt_include_optimizer=True,
)

+# for wandb integration
+# Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html"
wandb_options: Optional[WandbLoggerOptions] = (
None
if wandb_project is None
else WandbLoggerOptions(
offline=wandb_offline,
project=wandb_project,
entity=wandb_entity,
-log_model=False,
+tags=wandb_tags,
+group=wandb_group,
+id=wandb_id,
+anonymous=wandb_anonymous,
+log_model=wandb_log_model,
)
)
trainer = nl.Trainer(
@@ -319,14 +335,20 @@
parser.add_argument(
"--experiment-name", type=str, required=False, default="geneformer", help="Name of the experiment."
)
parser.add_argument("--wandb-offline", action="store_true", default=False, help="Use wandb in offline mode.")
parser.add_argument("--wandb-entity", type=str, default=None, help="The team posting this run")
parser.add_argument("--wandb-project", type=str, default=None, help="Wandb project name ")
parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run")
parser.add_argument(
"--wandb-project",
type=str,
required=False,
default=None,
help="Wandb project name. Wandb will only happen if this is set..",
"--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group"
)
parser.add_argument(
"--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run"
)
parser.add_argument("--wandb-anonymous", action="store_true", help="Enable or explicitly disable anonymous logging")
parser.add_argument(
"--wandb-log-model", action="store_true", help="Save checkpoints in wandb dir to upload on W&B servers"
)
parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode")
parser.add_argument(
"--cosine-rampup-frac",
type=float,
@@ -490,7 +512,13 @@ def config_class_type(desc: str) -> Type[BioBertGenericConfig]:
devices=args.num_gpus,
seq_length=args.seq_length,
result_dir=args.result_dir,
+wandb_entity=args.wandb_entity,
wandb_project=args.wandb_project,
+wandb_tags=args.wandb_tags,
+wandb_group=args.wandb_group,
+wandb_id=args.wandb_id,
+wandb_anonymous=args.wandb_anonymous,
+wandb_log_model=args.wandb_log_model,
wandb_offline=args.wandb_offline,
num_steps=args.num_steps,
limit_val_batches=args.limit_val_batches,
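Note that the new --wandb-tags flag uses argparse's nargs="+", which collects one or more space-separated values into a list. A self-contained check of that behavior (the tag values are placeholders):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run")

# One or more space-separated values are collected into a list.
args = parser.parse_args(["--wandb-tags", "geneformer", "dev"])
assert args.wandb_tags == ["geneformer", "dev"]

# Omitting the flag leaves the default of None.
assert parser.parse_args([]).wandb_tags is None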
18 changes: 11 additions & 7 deletions sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
-from typing import Any, Dict, Optional, Sequence, TypedDict
+from typing import Any, Dict, List, Optional, Sequence, TypedDict

from nemo.lightning.nemo_logger import NeMoLogger
from nemo.lightning.pytorch import callbacks as nemo_callbacks
@@ -32,12 +32,16 @@ class WandbLoggerOptions(TypedDict):
`directory` is also omitted since it is set by the NeMoLogger.
""" # noqa: D205

-offline: bool # offline mode
-project: str # project name
-entity: str # group name or user name
-# name: str # experiment name, this is handled by NeMoLogger
-# the directory is also set by NeMoLogger
-log_model: bool # log model
+entity: str # The team posting this run (default: your username or your default team)
+project: str # The name of the project to which this run will belong.
+# name: #Display name for the run. "This is handled by NeMoLogger"
+# save_dir: #Path where data is saved. "This is handled by NeMoLogger"
+tags: List[str] # Tags associated with this run.
+group: str # A unique string shared by all runs in a given group
+offline: bool # Run offline (data can be streamed later to wandb servers).
+id: str # Sets the version, mainly used to resume a previous run.
+anonymous: bool # Enables or explicitly disables anonymous logging.
+log_model: bool # Save checkpoints in wandb dir to upload on W&B servers.


def setup_nemo_lightning_logger(
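This commit does not show how setup_nemo_lightning_logger consumes these options, but one plausible pattern, assumed here rather than taken from the repo, is to expand the mapping into PyTorch Lightning's WandbLogger, with name and save_dir supplied separately as the comments above note. All values below are placeholders:

from pytorch_lightning.loggers import WandbLogger

# Stand-in for a populated WandbLoggerOptions TypedDict; values are placeholders.
wandb_options = {
    "entity": "my-team",
    "project": "demo-project",
    "tags": ["demo"],
    "group": "exp-group-1",
    "offline": True,  # offline mode avoids contacting W&B servers in this sketch
    "id": "run-0001",
    "anonymous": False,
    "log_model": False,
}

# `name` and `save_dir` come from NeMoLogger in the real code path; the remaining
# keys pass straight through (entity/tags/group are forwarded to wandb.init).
logger = WandbLogger(name="demo-run", save_dir="./results", **wandb_options)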
