Reapply torch.compile in Fabric.setup() #19280

Merged: 59 commits, merged on Jan 24, 2024

Commits (59)
3fc088c
rewrap
awaelchli Dec 20, 2023
42c4f41
update
awaelchli Dec 21, 2023
561edd3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
8676ece
add integration test
awaelchli Dec 21, 2023
59e2808
Merge branch 'feature/rewrap-optimized-module' of github.com:Lightnin…
awaelchli Dec 21, 2023
c3ab037
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
9e85169
update
awaelchli Dec 21, 2023
61684ab
Merge branch 'master' into feature/rewrap-optimized-module
awaelchli Dec 26, 2023
c8fc487
reapply
awaelchli Dec 26, 2023
2c41ae6
update
awaelchli Dec 26, 2023
60388fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2023
8e28d32
ticks
awaelchli Dec 26, 2023
23100f1
update test
awaelchli Dec 26, 2023
d03edc0
pytorch version
awaelchli Dec 26, 2023
2519dd6
update ddp test
awaelchli Dec 26, 2023
6ed234e
Merge branch 'feature/rewrap-optimized-module' of github.com:Lightnin…
awaelchli Dec 26, 2023
9335f5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2023
ab3d72d
tests
awaelchli Dec 26, 2023
9ed7ffc
Merge branch 'feature/rewrap-optimized-module' of github.com:Lightnin…
awaelchli Dec 26, 2023
d74e145
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2023
3b7a42b
mypy
awaelchli Dec 26, 2023
a308951
add another test
awaelchli Dec 26, 2023
f3252d9
Merge branch 'feature/rewrap-optimized-module' of github.com:Lightnin…
awaelchli Dec 26, 2023
4949504
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2023
56d050e
unwrap
awaelchli Dec 26, 2023
f1d80ea
tilde
awaelchli Dec 26, 2023
61299be
fix
awaelchli Dec 26, 2023
ea72a7d
type
awaelchli Dec 26, 2023
9d3f271
fix return
awaelchli Dec 26, 2023
6d8e641
fix test
awaelchli Dec 26, 2023
48b0452
test
awaelchli Dec 26, 2023
637766e
unwrap
awaelchli Dec 26, 2023
e8736cb
update tests
awaelchli Dec 26, 2023
379b9eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2023
5d7f856
Merge branch 'master' into feature/rewrap-optimized-module
Borda Jan 10, 2024
e5cb498
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
4e5ab86
capture compile args
awaelchli Jan 13, 2024
9e34e7c
update
awaelchli Jan 13, 2024
f6d9538
update tests
awaelchli Jan 13, 2024
5f152d9
tests
awaelchli Jan 13, 2024
e89c69f
update tests
awaelchli Jan 13, 2024
3ffab0b
Merge branch 'feature/rewrap-optimized-module2' of github.com:Lightni…
awaelchli Jan 13, 2024
f12de76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2024
d3f63e7
fix patch
awaelchli Jan 13, 2024
89b4aa7
conditional patch
awaelchli Jan 13, 2024
348713a
mypy
awaelchli Jan 13, 2024
4abefd3
docs
awaelchli Jan 13, 2024
75b3c62
docs
awaelchli Jan 13, 2024
56bec29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2024
b111f58
update doc
awaelchli Jan 14, 2024
aa7f146
todo
awaelchli Jan 15, 2024
b0546b9
do backward
awaelchli Jan 15, 2024
0b7da4c
chlog
awaelchli Jan 15, 2024
b71b10b
chlog
awaelchli Jan 15, 2024
d27876d
Merge branch 'master' into feature/rewrap-optimized-module2
awaelchli Jan 15, 2024
50efab0
error if compile used before import lightning
awaelchli Jan 15, 2024
4b0b840
Merge branch 'master' into feature/rewrap-optimized-module2
awaelchli Jan 19, 2024
1121284
Add comment about limitation
awaelchli Jan 19, 2024
c3ed341
add forgotten changelog
awaelchli Jan 23, 2024
1 change: 0 additions & 1 deletion docs/source-fabric/api/fabric_methods.rst
@@ -108,7 +108,6 @@ This is useful if your model experiences *exploding gradients* during training.
fabric.clip_gradients(model, optimizer, max_norm=2.0, norm_type="inf")

The :meth:`~lightning.fabric.fabric.Fabric.clip_gradients` method is agnostic to the precision and strategy being used.
Note: Gradient clipping with FSDP is not yet fully supported.


to_device
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for clipping gradients by value with FSDP ([#19236](https://github.com/Lightning-AI/lightning/pull/19236))


- (Experimental) Added support for re-compiling the model inside `Fabric.setup()` over the FSDP/DDP wrappers ([#19280](https://github.com/Lightning-AI/lightning/pull/19280))


### Changed

- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
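For orientation, here is a minimal usage sketch of the feature described in the changelog entry above. This is not part of the diff; `_reapply_compile` is an experimental, private argument, and the example assumes PyTorch >= 2.1 with a two-GPU DDP setup.

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
fabric.launch()

model = torch.nn.Linear(32, 2)  # stand-in for any torch.nn.Module
compiled_model = torch.compile(model, mode="reduce-overhead")

# Fabric removes the OptimizedModule wrapper, lets the strategy wrap the raw
# module (DDP here), and then calls torch.compile again with the same kwargs
# on the strategy-wrapped module.
fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
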
36 changes: 29 additions & 7 deletions src/lightning/fabric/fabric.py
@@ -76,6 +76,7 @@
_FabricDataLoader,
_FabricModule,
_FabricOptimizer,
_to_compiled,
_unwrap_compiled,
_unwrap_objects,
)
@@ -213,6 +214,7 @@ def setup(
module: nn.Module,
*optimizers: Optimizer,
move_to_device: bool = True,
_reapply_compile: Optional[bool] = None,
) -> Any: # no specific return because the way we want our API to look does not play well with mypy
r"""Set up a model and its optimizers for accelerated training.

@@ -221,12 +223,17 @@
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.

Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.

"""
self._validate_setup(module, optimizers)
module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
original_module = module

module = self._precision.convert_module(module)
@@ -242,6 +249,8 @@
else:
module = self._strategy.setup_module(module)

if compile_kwargs is not None:
module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._precision, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
@@ -258,8 +267,8 @@ def setup(
self._models_setup += 1

if hasattr(original_module, "_fabric"): # this is probably a LightningModule
original_module._fabric = self # type: ignore[assignment]
original_module._fabric_optimizers = optimizers # type: ignore[assignment]
original_module._fabric = self
original_module._fabric_optimizers = optimizers
if original_module not in self._callbacks:
self._callbacks.append(original_module)

@@ -270,7 +279,9 @@ def setup(
return (module, *optimizers)
return module

def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _FabricModule:
def setup_module(
self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None
) -> _FabricModule:
r"""Set up a model for accelerated training or inference.

This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
@@ -281,12 +292,17 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
module: A :class:`torch.nn.Module` to set up
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.

Returns:
The wrapped model.

"""
self._validate_setup_module(module)
module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
original_module = module

module = self._precision.convert_module(module)
@@ -296,6 +312,9 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri

# Let strategy wrap and connect the module alone
module = self._strategy.setup_module(module)

if compile_kwargs is not None:
module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._precision, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
@@ -305,7 +324,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
)

if hasattr(original_module, "_fabric"): # this is probably a LightningModule
original_module._fabric = self # type: ignore[assignment]
original_module._fabric = self
if original_module not in self._callbacks:
self._callbacks.append(original_module)

@@ -410,6 +429,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] =

"""
module = model._forward_module if model is not None else model
module, _ = _unwrap_compiled(module)
if isinstance(self._strategy, DeepSpeedStrategy):
if model is None:
if self._models_setup == 0:
@@ -641,7 +661,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
skip.

"""
module = _unwrap_compiled(module)
module, _ = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
raise TypeError(
"You need to set up the model first before you can call `fabric.no_backward_sync()`:"
@@ -656,7 +676,9 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
category=PossibleUserWarning,
)
return nullcontext()
return self._strategy._backward_sync_control.no_backward_sync(module._forward_module)

forward_module, _ = _unwrap_compiled(module._forward_module)
return self._strategy._backward_sync_control.no_backward_sync(forward_module)

def sharded_model(self) -> ContextManager:
r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
@@ -772,7 +794,7 @@ def load(
# We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
# (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
for k in list(unwrapped_state.keys()):
obj = _unwrap_compiled(state[k])
obj, _ = _unwrap_compiled(state[k])
if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
continue
state[k] = unwrapped_state[k]
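Condensed, the new `Fabric.setup()` flow for compiled modules shown across the hunks above reads roughly as follows. This is an illustrative paraphrase of the internals, not the verbatim implementation; it reuses the private helpers introduced by this PR.

from torch import nn
from lightning.fabric.wrappers import _FabricModule, _to_compiled, _unwrap_compiled

def _setup_flow_sketch(strategy, precision, module: nn.Module, reapply_compile: bool) -> _FabricModule:
    # 1. Strip the OptimizedModule wrapper (if any) and remember the compile kwargs
    module, compile_kwargs = _unwrap_compiled(module) if reapply_compile else (module, None)
    original_module = module
    # 2. Let precision and strategy prepare/wrap the raw module (e.g. DDP, FSDP)
    module = precision.convert_module(module)
    module = strategy.setup_module(module)
    # 3. Reapply torch.compile with the captured kwargs, now over the strategy wrapper
    if compile_kwargs is not None:
        module = _to_compiled(module, compile_kwargs)
    return _FabricModule(module, precision, original_module=original_module)
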
66 changes: 58 additions & 8 deletions src/lightning/fabric/wrappers.py
@@ -12,8 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterator,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
overload,
)

import torch
from lightning_utilities.core.apply_func import apply_to_collection
@@ -31,6 +46,9 @@
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable

if TYPE_CHECKING:
from torch._dynamo import OptimizedModule

T_destination = TypeVar("T_destination", bound=Dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")

@@ -279,8 +297,8 @@ def _unwrap_objects(collection: Any) -> Any:
def _unwrap(
obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader]
) -> Union[nn.Module, Optimizer, DataLoader]:
if isinstance(unwrapped := _unwrap_compiled(obj), _FabricModule):
return unwrapped._forward_module
if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule):
return _unwrap_compiled(unwrapped._forward_module)[0]
if isinstance(obj, _FabricOptimizer):
return obj.optimizer
if isinstance(obj, _FabricDataLoader):
@@ -296,19 +314,28 @@ def _unwrap(
return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)


def _unwrap_compiled(obj: Any) -> Any:
def _unwrap_compiled(obj: Union[Any, "OptimizedModule"]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]:
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.

Use this function before instance checks against e.g. :class:`_FabricModule`.

"""
if not _TORCH_GREATER_EQUAL_2_0:
return obj
# obj can't be an `OptimizedModule` anyway
return obj, None

from torch._dynamo import OptimizedModule

if isinstance(obj, OptimizedModule):
return obj._orig_mod
return obj
return obj._orig_mod, getattr(obj, "_compile_kwargs", None)
return obj, None


def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> "OptimizedModule":
if not _TORCH_GREATER_EQUAL_2_0:
raise RuntimeError("Converting to a compiled module is only supported in PyTorch >= 2.0.0")

return torch.compile(module, **compile_kwargs) # type: ignore[return-value]


def is_wrapped(obj: object) -> bool:
@@ -322,5 +349,28 @@ def is_wrapped(obj: object) -> bool:
obj: The object to test.

"""
obj = _unwrap_compiled(obj)
obj, _ = _unwrap_compiled(obj)
return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))


def _capture_compile_kwargs(compile_fn: Callable) -> Callable:
"""Wraps the ``torch.compile`` function and captures the compile arguments.

We extract the compile arguments so that we can reapply ``torch.compile`` in ``Fabric.setup()`` with the
same arguments as the user passed to the original call. The arguments get stored in a dictionary
``_compile_kwargs`` on the returned compiled module.

"""

@wraps(compile_fn)
def _capture(model: Any, **kwargs: Any) -> Any:
compiled_model = compile_fn(model, **kwargs)
if isinstance(model, nn.Module):
compiled_model._compile_kwargs = deepcopy(kwargs)
return compiled_model

return _capture


if _TORCH_GREATER_EQUAL_2_0:
torch.compile = _capture_compile_kwargs(torch.compile)
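The module-level patch above (`torch.compile = _capture_compile_kwargs(torch.compile)`) is what makes the compile arguments recoverable inside `Fabric.setup()`. Here is a small standalone sketch of the same idea, independent of Lightning (illustrative only; assumes PyTorch >= 2.0):

from copy import deepcopy
from functools import wraps

import torch

_original_compile = torch.compile

@wraps(_original_compile)
def _compile_and_record(model, **kwargs):
    compiled = _original_compile(model, **kwargs)
    if isinstance(model, torch.nn.Module):
        # Stash the kwargs on the compiled module so a later step can re-compile
        # a (now strategy-wrapped) module with exactly the same settings.
        compiled._compile_kwargs = deepcopy(kwargs)
    return compiled

torch.compile = _compile_and_record

net = torch.nn.Linear(32, 2)
opt_net = torch.compile(net, mode="reduce-overhead")
print(getattr(opt_net, "_compile_kwargs", None))  # expected: {'mode': 'reduce-overhead'}
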
40 changes: 40 additions & 0 deletions tests/tests_fabric/strategies/test_ddp_integration.py
@@ -11,14 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from torch.nn.parallel.distributed import DistributedDataParallel

from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
from tests_fabric.test_fabric import BoringModel


@pytest.mark.parametrize(
@@ -64,6 +70,40 @@ def assert_params_equal(params0, params1):
assert_params_equal(params_before, wrapped_model.parameters())


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True)
@mock.patch(
"lightning.fabric.wrappers.torch.compile",
Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
)
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Fabric can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
from torch._dynamo import OptimizedModule

fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
fabric.launch()

model = BoringModel()
compile_kwargs = {"mode": "reduce-overhead"}
compiled_model = torch.compile(model, **compile_kwargs)
torch.compile.reset_mock()

fabric_model = fabric.setup(compiled_model, _reapply_compile=True)

assert isinstance(fabric_model._forward_module, OptimizedModule)
assert isinstance(fabric_model._forward_module._orig_mod, DistributedDataParallel)
# Assert we called compile again with the same arguments, but on the DDP-wrapped module
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)

assert fabric_model._original_module == model
assert fabric_model._forward_module._orig_mod.module == model
assert fabric_model.device == fabric.device

# Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()


@pytest.mark.parametrize(
("clip_type", "accelerator", "precision"),
[
44 changes: 26 additions & 18 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -15,6 +15,7 @@
from copy import deepcopy
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
@@ -344,33 +345,40 @@ def test_setup_with_orig_params_and_multiple_param_groups():
assert not isinstance(layer.weight, FlatParameter)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, dynamo=True)
@mock.patch.dict(os.environ, {})
@pytest.mark.parametrize(
"compile_after_setup",
[
False,
# https://github.com/pytorch/pytorch/issues/97811
pytest.param(True, marks=RunIf(min_python="3.9")),
],
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True)
@mock.patch(
"lightning.fabric.wrappers.torch.compile",
Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
)
def test_compile(compile_after_setup):
"""Test that the model can be compiled before and after the model is wrapped in FSDP."""
model = BoringModel()
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Fabric can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
from torch._dynamo import OptimizedModule

strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

if not compile_after_setup:
model = torch.compile(model)
model = BoringModel()
compile_kwargs = {"mode": "reduce-overhead"}
compiled_model = torch.compile(model, **compile_kwargs)
torch.compile.reset_mock()

fabric_model = fabric.setup(compiled_model, _reapply_compile=True)

assert isinstance(fabric_model._forward_module, OptimizedModule)
assert isinstance(fabric_model._forward_module._orig_mod, FullyShardedDataParallel)

model = fabric.setup(model)
# Assert we called compile again with the same arguments, but on the FSDP-wrapped module
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)

if compile_after_setup:
model = torch.compile(model)
assert fabric_model._original_module == model
assert fabric_model._forward_module._orig_mod.module == model
assert fabric_model.device == fabric.device

# Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
model(torch.rand(2, 32, device=fabric.device)).sum().backward()
fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)