[torchtune][dcp] Refactor checkpointing code from the recipe
saumishr committed Dec 2, 2024
1 parent 4669ac3 commit e5a51b5
Showing 5 changed files with 425 additions and 128 deletions.
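
For orientation, here is a minimal sketch (not part of this commit) of how a recipe drives checkpointing after the refactor, reconstructed from the call sites visible in the diff below. The helper function and the cfg/model/optimizer arguments it takes are illustrative assumptions, not code from the repository.

from omegaconf import DictConfig
from torch import nn
from torch.optim import Optimizer

from torchtune.training import CheckpointClient, TrainingProgress


def checkpointing_flow_sketch(
    cfg: DictConfig,
    model: nn.Module,
    optimizer: Optimizer,
    resume_from_checkpoint: bool,
    enable_async_checkpointing: bool,
) -> None:
    # The recipe now owns a single CheckpointClient built from the config,
    # instead of instantiating checkpointers and calling DCP APIs directly.
    checkpoint_client = CheckpointClient(cfg)

    # Base weights (and, when resuming without async checkpointing, the
    # recipe state) come from the configured checkpointer.
    checkpoint_dict = checkpoint_client.load_base_checkpoint()

    # When resuming with async checkpointing enabled, the latest intermediate
    # distributed checkpoint -- which already carries the recipe state -- is
    # loaded directly into the model and optimizer.
    if resume_from_checkpoint and enable_async_checkpointing:
        checkpoint_dict = checkpoint_client.load_distributed_checkpoint(
            model, optimizer
        )

    # ... training loop ...

    # At the end of each epoch the client persists the model, the optimizer,
    # and the recipe state bundled in TrainingProgress.
    checkpoint_client.save_checkpoint(
        model=model,
        optimizer=optimizer,
        training_progress=TrainingProgress(
            seed=0,
            epochs_run=1,
            total_epochs=1,
            max_steps_per_epoch=None,
        ),
        epoch=0,
    )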
170 changes: 52 additions & 118 deletions recipes/full_finetune_distributed.py
@@ -6,7 +6,6 @@

import sys
import time

from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn
@@ -15,11 +14,6 @@
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.checkpoint.state_dict import (
_init_optim_state,
set_model_state_dict,
set_state_dict,
)

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
@@ -28,10 +22,14 @@
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DistributedCheckpointer, DummyProfiler, PROFILER_KEY
from torchtune.training import (
CheckpointClient,
DummyProfiler,
PROFILER_KEY,
TrainingProgress,
)
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.lr_schedulers import get_lr
from torchtune.utils._logging import log_rank_zero
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -146,6 +144,7 @@ def __init__(self, cfg: DictConfig) -> None:
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._checkpoint_client = CheckpointClient(cfg)

# Optimizer in backward is not compatible with gradient accumulation or gradient clipping
if self._optimizer_in_bwd:
@@ -193,101 +192,6 @@ def __init__(self, cfg: DictConfig) -> None:
self.total_epochs = cfg.epochs
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0
self._dcp_checkpointer = None

def get_dcp_checkpointer(self):
if not self._dcp_checkpointer:
self._dcp_checkpointer = DistributedCheckpointer(
checkpoint_dir=self._checkpointer._checkpoint_dir,
output_dir=self._checkpointer._output_dir,
)

return self._dcp_checkpointer

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
Extract the checkpoint state from file and validate.
If resume_from_checkpoint is True, this also includes the recipe state.
"""
# If async checkpointing is enabled, set the should_load_recipe_state as false for the checkpointer.
# We will use DistributedCheckpointer instead to load the intermediate checkpoints which includes the
# recipe state as well.
should_load_recipe_state: bool = (
False if self._enable_async_checkpointing else self._resume_from_checkpoint
)

self._checkpointer = config.instantiate(
cfg_checkpointer,
should_load_recipe_state=should_load_recipe_state,
)

return self._checkpointer.load_checkpoint()

def resume_from_checkpoint(self, checkpoint_dict: Dict[str, any]) -> Dict[str, Any]:
"""
Load the intermediate checkpoint to restore the training state.
If async checkpointing is enabled, user does not need to provide any recipe state.
DistributedCheckpoint loads the latest intermediate checkpoint state which includes the recipe state already.
Otherwise, update the recipe state as loaded in the checkpoint already via the load_checkpoint call.
"""
if self._enable_async_checkpointing:
if self._is_rank_zero:
dcp_load_start = time.perf_counter()

model_state_dict = self._model.state_dict()

if not self._optimizer_in_bwd:
_init_optim_state(self._optimizer)
optim_state_dict = self._optimizer.state_dict()
else:
optim_state_dict = self._optim_ckpt_wrapper.state_dict()

state_dict: Dict[str:Any] = {}
state_dict.update(
{
training.MODEL_KEY: model_state_dict,
training.OPT_KEY: optim_state_dict,
training.SEED_KEY: 0,
training.EPOCHS_KEY: 0,
training.TOTAL_EPOCHS_KEY: 0,
training.MAX_STEPS_KEY: 0,
}
)

dcp_checkpointer = self.get_dcp_checkpointer()
state_dict = dcp_checkpointer.load_checkpoint(state_dict)

if self._is_rank_zero:
log.info(
f"DistributedCheckpointer loaded the checkpoint in {time.perf_counter() - dcp_load_start:.2f} seconds."
)

if not self._optimizer_in_bwd:
set_state_dict(
self._model,
self._optimizer,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
)
else:
set_model_state_dict(
model=self._model,
model_state_dict=model_state_dict,
)
self._optim_ckpt_wrapper.load_state_dict(optim_state_dict)

# Update the recipe state from the loaded checkpoint
self._update_recipe_state(state_dict)
log_rank_zero(
log,
msg=f"Loaded intermediate distributed checkpoint for the epoch: {state_dict[training.EPOCHS_KEY]}",
)
else:
self._update_recipe_state(checkpoint_dict)
log_rank_zero(
log,
msg=f"Loaded intermediate checkpoint for the epoch: {checkpoint_dict[training.EPOCHS_KEY]}",
)

def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
"""
@@ -341,7 +245,7 @@ def setup(self, cfg: DictConfig) -> None:
self._metric_logger.log_config(cfg)

# Load the base model
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

self._compile = cfg.get("compile", False)
self._model = self._setup_model(
@@ -362,13 +266,24 @@
optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint and not self._enable_async_checkpointing
if training.OPT_KEY in checkpoint_dict
else None
),
)

if self._resume_from_checkpoint:
self.resume_from_checkpoint(checkpoint_dict)
if self._enable_async_checkpointing:
checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
self._model,
(
self._optim_ckpt_wrapper
if self._optimizer_in_bwd
else self._optimizer
),
)

# Update recipe state from checkpoint
self._update_recipe_state(checkpoint_dict)

# initialize loss
self._loss_fn = config.instantiate(cfg.loss)
@@ -549,7 +464,7 @@ def _setup_model(
enable_activation_offloading: bool,
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
model_state_dict: Dict[str, Any],
model_state_dict: Dict[str, Any] = None,
custom_sharded_layers: Optional[List[str]] = None,
ac_mode: Optional[str] = None,
ac_option: Optional[int] = None,
@@ -613,16 +528,20 @@
if hasattr(m, "rope_init"):
m.rope_init()

# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
if model_state_dict:
# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)

# Ensure no params and buffers are on meta device
training.validate_no_params_on_meta_device(model)

# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
@@ -636,6 +555,7 @@
log,
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
)

if self._is_rank_zero:
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
@@ -1069,7 +989,21 @@ def train(self) -> None:
self._profiler.step()

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)
self._checkpoint_client.save_checkpoint(
model=self._model,
optimizer=(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
training_progress=TrainingProgress(
seed=self.seed,
epochs_run=self.epochs_run,
total_epochs=self.total_epochs,
max_steps_per_epoch=self.max_steps_per_epoch,
),
epoch=curr_epoch,
)

self._profiler.stop()

4 changes: 4 additions & 0 deletions torchtune/training/__init__.py
@@ -36,6 +36,7 @@
from torchtune.training.checkpointing import (
ADAPTER_CONFIG,
ADAPTER_KEY,
CheckpointClient,
Checkpointer,
DistributedCheckpointer,
EPOCHS_KEY,
@@ -51,6 +52,7 @@
SEED_KEY,
STEPS_KEY,
TOTAL_EPOCHS_KEY,
TrainingProgress,
update_state_dict_for_classifier,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup, get_lr
@@ -78,6 +80,7 @@
"get_dtype",
"set_default_dtype",
"validate_expected_param_dtype",
"CheckpointClient",
"FullModelHFCheckpointer",
"FullModelMetaCheckpointer",
"DistributedCheckpointer",
@@ -131,4 +134,5 @@
"OffloadActivations",
"FormattedCheckpointFiles",
"scale_grads",
"TrainingProgress",
]
7 changes: 7 additions & 0 deletions torchtune/training/checkpointing/__init__.py
@@ -5,6 +5,11 @@
# LICENSE file in the root directory of this source tree.
from typing import Union

from torchtune.training.checkpointing._checkpoint_client import (
CheckpointClient,
TrainingProgress,
)

from torchtune.training.checkpointing._checkpointer import (
DistributedCheckpointer,
FullModelHFCheckpointer,
@@ -35,6 +40,7 @@
]

__all__ = [
"CheckpointClient",
"FullModelHFCheckpointer",
"FullModelMetaCheckpointer",
"FullModelTorchTuneCheckpointer",
@@ -53,4 +59,5 @@
"STEPS_KEY",
"TOTAL_EPOCHS_KEY",
"FormattedCheckpointFiles",
"TrainingProgress",
]
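
As a usage note (not part of the diff), the new symbols are re-exported at both package levels after this change, so either import path below should resolve:

# Hypothetical usage of the newly exported names; both are added to __all__
# by this commit.
from torchtune.training import CheckpointClient, TrainingProgress
# or, equivalently, from the checkpointing subpackage:
# from torchtune.training.checkpointing import CheckpointClient, TrainingProgress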