Add compile_fn parameter for Trainer #20269

Open · wants to merge 16 commits into base: master
22 changes: 18 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
@@ -36,6 +36,7 @@
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.types import _PATH
from lightning.fabric.wrappers import _to_compiled, _unwrap_compiled
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
from lightning.pytorch.core.datamodule import LightningDataModule
@@ -530,19 +531,25 @@ def fit(
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

"""
model = _maybe_unwrap_optimized(model)
# if a compiled model was passed in, unwrap it here and re-apply compilation after the strategy is set up
model, compile_kwargs = (
_unwrap_compiled(model)
if isinstance(model, torch._dynamo.OptimizedModule)
else (_maybe_unwrap_optimized(model), None)
)
self.strategy._lightning_module = model
_verify_strategy_supports_compile(model, self.strategy)
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
call._call_and_handle_interrupt(
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
self, self._fit_impl, model, compile_kwargs, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)

def _fit_impl(
self,
model: "pl.LightningModule",
compile_kwargs: Optional[dict[str, Any]] = None,
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
@@ -572,7 +579,7 @@ def _fit_impl(
model_provided=True,
model_connected=self.lightning_module is not None,
)
self._run(model, ckpt_path=ckpt_path)
self._run(model, compile_kwargs, ckpt_path=ckpt_path)

assert self.state.stopped
self.training = False
@@ -903,7 +910,10 @@ def _predict_impl(
return results

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
self,
model: "pl.LightningModule",
compile_kwargs: Optional[dict[str, Any]] = None,
ckpt_path: Optional[_PATH] = None,
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if self.state.fn == TrainerFn.FITTING:
min_epochs, max_epochs = _parse_loop_limits(
@@ -957,6 +967,10 @@ def _run(
# strategy will configure model and move it to the device
self.strategy.setup(self)

# a compiled model was unwrapped in `fit`; re-apply compilation now that the strategy has set up the model
if compile_kwargs is not None:
self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)

# hook
if self.state.fn == TrainerFn.FITTING:
call._call_callback_hooks(self, "on_fit_start")
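With this change, a model that has already been passed through `torch.compile` can be given directly to `Trainer.fit`: the trainer unwraps the `OptimizedModule`, lets the strategy wrap the original module (e.g. in DDP or FSDP), and then re-applies `torch.compile` with the captured kwargs on top of that wrapper. A minimal usage sketch follows; the model, data, and trainer arguments below are illustrative and not part of this diff.

import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # default collate wraps the single tensor from TensorDataset in a list
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = LitModel()
# Compile first; Trainer.fit re-applies these kwargs over the strategy-wrapped module.
compiled_model = torch.compile(model, mode="reduce-overhead")

train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)
trainer.fit(compiled_model, train_dataloaders=train_loader)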
28 changes: 28 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -17,6 +17,7 @@

import pytest
import torch
from torch._dynamo import OptimizedModule
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.multiprocessing import ProcessRaisedException
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -448,3 +449,30 @@ def creates_processes_externally(self):
RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
):
trainer.fit(model)


@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Trainer can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)

model = BoringModel()
compile_kwargs = {"mode": "reduce-overhead"}
compiled_model = torch.compile(model, **compile_kwargs)
torch.compile.reset_mock()

trainer.fit(compiled_model)
trainer_model = trainer.strategy.model

assert isinstance(trainer_model, OptimizedModule)
assert isinstance(trainer_model._orig_mod, DistributedDataParallel)
# Assert we called compile again with the same arguments, but on the DDP-wrapped module
torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)

assert trainer_model._orig_mod.module == model

# Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
28 changes: 28 additions & 0 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -12,6 +12,7 @@
import pytest
import torch
import torch.nn as nn
from torch._dynamo import OptimizedModule
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap
from torchmetrics import Accuracy
@@ -971,3 +972,30 @@ def configure_optimizers(self):
max_steps=4,
)
trainer.fit(model, ckpt_path=checkpoint_path_full)


@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Trainer can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_steps=2, logger=False)

model = BoringModel()
compile_kwargs = {"mode": "reduce-overhead"}
compiled_model = torch.compile(model, **compile_kwargs)
torch.compile.reset_mock()

trainer.fit(compiled_model)
trainer_model = trainer.strategy.model

assert isinstance(trainer_model, OptimizedModule)
assert isinstance(trainer_model._orig_mod, FullyShardedDataParallel)
# Assert we called compile again with the same arguments, but on the FSDP-wrapped module
torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)

assert trainer_model._orig_mod.module == model

# Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
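
For reference, the two helpers imported from `lightning.fabric.wrappers` behave roughly as sketched below. This is a simplified approximation rather than the library code: the `_sketch` names are hypothetical, and it assumes the compile kwargs are recorded on the `OptimizedModule` (e.g. as `_compile_kwargs`) when `torch.compile` is invoked through Lightning's wrapper module (the tests above wrap `lightning.fabric.wrappers.torch.compile` in a `Mock` so that the re-applied call can be asserted on).

from typing import Any, Optional

import torch
from torch import nn
from torch._dynamo import OptimizedModule


def _unwrap_compiled_sketch(obj: Any) -> tuple[Any, Optional[dict[str, Any]]]:
    """Strip the OptimizedModule wrapper and return (original module, compile kwargs)."""
    if isinstance(obj, OptimizedModule):
        # assumed: the kwargs were stashed on the compiled module when torch.compile ran
        return obj._orig_mod, getattr(obj, "_compile_kwargs", {})
    return obj, None


def _to_compiled_sketch(module: nn.Module, compile_kwargs: dict[str, Any]) -> OptimizedModule:
    """Re-apply torch.compile with the kwargs captured at unwrap time."""
    return torch.compile(module, **compile_kwargs)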