Add feature Exponential Moving Average (EMA) #10914
Comments
Hi, |
@justusschock thank you for your reply. Is there a way to use ema using swa? I checked the link. For me, |
@hankyul2, I believe this is how it would be implemented:
import torch
from pytorch_lightning.callbacks import StochasticWeightAveraging


class EMA_Callback(StochasticWeightAveraging):
    def __init__(self, decay=0.9999):
        super().__init__()
        self.decay = decay

    # Takes `self` so that self.decay is available inside the averaging function.
    def avg_fn(
        self, averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
    ) -> torch.FloatTensor:
        e = averaged_model_parameter
        m = model_parameter
        return self.decay * e + (1. - self.decay) * m
Let me know if you have any problems. I just learned about SWA because of your issue! |
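To make the update rule in the snippet above concrete, here is a minimal sketch in plain Python (no PyTorch or Lightning needed). The function name `ema_update` and the sample values are hypothetical and chosen only for illustration; the formula itself is the one from the comment.

```python
def ema_update(ema_value: float, model_value: float, decay: float = 0.9999) -> float:
    """One EMA step: keep `decay` of the old average and blend in the rest from the new value."""
    return decay * ema_value + (1.0 - decay) * model_value


# Hypothetical setup: the average starts at 0.0 and tracks a constant target of 1.0.
ema = 0.0
for _ in range(3):
    ema = ema_update(ema, 1.0, decay=0.5)
# With this setup, after k steps the average equals 1 - decay**k.
```

With a decay close to 1 (for example 0.9999) the average reacts slowly: roughly the last 1 / (1 - decay) steps dominate it.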
@mathemusician thank you. |
@hankyul2 I don't think that @mathemusician's solution is equivalent. |
@hal-314 yeah. I think so. |
I have implemented an EMA callback with simple functionality.
I think many more options should be added, for example save_weight, ema_step_period, etc. If you find it helpful, please let me know and I will develop it further. If you don't, close this issue or leave a comment.
from copy import deepcopy

import torch
from pytorch_lightning import Callback


class EMACallback(Callback):
    def __init__(self, decay=0.995):
        self.decay = decay
        self.module_pair_list = []  # (ema_module, module) pairs

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Dispatch to the original forward while training and to the EMA copy otherwise.
        def forward_wrapper(module, org, ema):
            def forward(*args, **kwargs):
                return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
            return forward

        # Only wrap child modules that actually own parameters.
        modules = list(filter(lambda x: len(list(x[1].parameters())) > 0, pl_module.named_children()))
        for name, module in modules:
            ema_module = deepcopy(module)
            self.module_pair_list.append((ema_module, module))
            pl_module.add_module(f'EMA_{name}', ema_module)
            module.forward_bc = module.forward
            module.forward = forward_wrapper(module, module.forward_bc, ema_module.forward)

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        for ema_module, module in self.module_pair_list:
            self._update(ema_module, module, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def _update(self, ema_module, module, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(ema_module.state_dict().values(), module.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v)) |
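The forward-wrapping idea in the callback above can be illustrated without PyTorch. The sketch below is only an illustration under simplifying assumptions: `TinyModule` is a hypothetical stand-in for an nn.Module that has nothing but a `training` flag and a `forward` method, and the "EMA copy" is just a second instance with a different scale.

```python
class TinyModule:
    """Hypothetical stand-in for an nn.Module: a `training` flag and a `forward` method."""

    def __init__(self, scale: float):
        self.training = True
        self.scale = scale

    def forward(self, x: float) -> float:
        return self.scale * x


def forward_wrapper(module, org, ema):
    # Same dispatch as in the callback: original forward while training, EMA forward otherwise.
    def forward(*args, **kwargs):
        return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
    return forward


model = TinyModule(scale=2.0)
ema_copy = TinyModule(scale=3.0)  # pretend these are the averaged weights
model.forward = forward_wrapper(model, model.forward, ema_copy.forward)

train_out = model.forward(1.0)  # training mode: original weights are used
model.training = False
eval_out = model.forward(1.0)   # eval mode: the EMA copy is used
```

The trade-off is that the module's forward attribute is replaced on the instance, which is the side effect the following comments in the thread object to.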
@hankyul2 I prefer the approach from timm, as it doesn't involve changing the forward method. Users could not use the forward method in Finally, I managed to implement EMA as a callback by using state_dicts instead of the whole module, as in ModelEMAV2. I can try to make a PR or, at least, post the code here if you or anyone else is interested (and I get permission for it). It works for 1 GPU and likely for multi-GPU, though that isn't tested. |
@hal-314 Cool 😎. I want to see it. Can you share your code? |
@hankyul2 Here is the code. Be aware that you need Bits to be aware:
I hope it's useful.
from copy import deepcopy
from typing import Optional, Union, Dict, Any

import pytorch_lightning as pl
import torch
from overrides import overrides
from pytorch_lightning.utilities import rank_zero_only


class EMA(pl.Callback):
    """Implements EMA (exponential moving average) to any kind of model.
    EMA weights will be used during validation and stored separately from original model weights.

    How to use EMA:
        - Sometimes, the last EMA checkpoint isn't the best, as EMA weight metrics can show long oscillations in time. See
          https://github.com/rwightman/pytorch-image-models/issues/102
        - Batch Norm layers and likely any other type of norm layers don't need to be updated at the end. See
          discussions in: https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 and
          https://github.com/rwightman/pytorch-image-models/issues/224
        - For object detection, SWA usually works better. See https://github.com/timgaripov/swa/issues/16

    Implementation detail:
        - See EMA in Pytorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914
        - When multi gpu, we broadcast ema weights and the original weights in order to only hold 1 copy in memory.
          This is especially relevant when storing EMA weights on CPU + pinned memory as pinned memory is a limited
          resource. In addition, we want to avoid duplicated operations in ranks != 0 to reduce jitter and improve
          performance.
    """

    def __init__(self, decay: float = 0.9999, ema_device: Optional[Union[torch.device, str]] = None, pin_memory=True):
        super().__init__()
        self.decay = decay
        self.ema_device: str = f"{ema_device}" if ema_device else None  # perform ema on different device from the model
        self.ema_pin_memory = pin_memory if torch.cuda.is_available() else False  # Only works if CUDA is available
        self.ema_state_dict: Dict[str, torch.Tensor] = {}
        self.original_state_dict = {}
        self._ema_state_dict_ready = False

    @staticmethod
    def get_state_dict(pl_module: pl.LightningModule):
        """Returns state dictionary from pl_module. Override if you want to filter some parameters and/or buffers out.
        For example, if pl_module has metrics, you don't want to return their parameters.

        code:
            # Only consider modules that can be seen by optimizers. Lightning modules can have other nn.Modules
            # attached, like losses, metrics, etc.
            patterns_to_ignore = ("metrics1", "metrics2")
            return dict(filter(lambda i: not i[0].startswith(patterns_to_ignore), pl_module.state_dict().items()))
        """
        return pl_module.state_dict()

    @overrides
    def on_train_start(self, trainer: "pl.Trainer", pl_module: pl.LightningModule) -> None:
        # Only keep track of EMA weights in rank zero.
        if not self._ema_state_dict_ready and pl_module.global_rank == 0:
            self.ema_state_dict = deepcopy(self.get_state_dict(pl_module))
            if self.ema_device:
                self.ema_state_dict = {k: tensor.to(device=self.ema_device) for k, tensor in self.ema_state_dict.items()}

            if self.ema_device == "cpu" and self.ema_pin_memory:
                self.ema_state_dict = {k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()}

        self._ema_state_dict_ready = True

    @rank_zero_only
    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs) -> None:
        # Update EMA weights
        with torch.no_grad():
            for key, value in self.get_state_dict(pl_module).items():
                ema_value = self.ema_state_dict[key]
                ema_value.copy_(self.decay * ema_value + (1. - self.decay) * value, non_blocking=True)

    @overrides
    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.

        self.original_state_dict = deepcopy(self.get_state_dict(pl_module))

        pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
        assert self.ema_state_dict.keys() == self.original_state_dict.keys(), \
            f"There are some keys missing in the ema static dictionary broadcasted. " \
            f"They are: {self.original_state_dict.keys() - self.ema_state_dict.keys()}"
        pl_module.load_state_dict(self.ema_state_dict, strict=False)

        if pl_module.global_rank > 0:
            # Remove ema state dict from the memory. In rank 0, it could be in ram pinned memory.
            self.ema_state_dict = {}

    @overrides
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.

        # Replace EMA weights with training weights
        pl_module.load_state_dict(self.original_state_dict, strict=False)

    @overrides
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> dict:
        return {"ema_state_dict": self.ema_state_dict, "_ema_state_dict_ready": self._ema_state_dict_ready}

    @overrides
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
    ) -> None:
        self._ema_state_dict_ready = callback_state["_ema_state_dict_ready"]
        self.ema_state_dict = callback_state["ema_state_dict"]
@justusschock @mathemusician Do you think that Lightning should include an EMA callback? Maybe it could go to Bolts. |
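The get_state_dict docstring above suggests filtering entries such as metrics out of the state dictionary before averaging. A minimal sketch of that filtering on a plain dictionary follows; the helper name `filter_state_dict` and the keys are hypothetical, and a real state dictionary would map names to tensors rather than integers.

```python
from typing import Dict, Tuple


def filter_state_dict(state_dict: Dict[str, object], patterns_to_ignore: Tuple[str, ...]) -> Dict[str, object]:
    """Drop every entry whose key starts with one of the given prefixes."""
    # str.startswith accepts a tuple of prefixes, so a single call covers all patterns.
    return {k: v for k, v in state_dict.items() if not k.startswith(patterns_to_ignore)}


# Hypothetical keys for illustration only.
state = {"backbone.weight": 1, "head.bias": 2, "metrics1.total": 3, "metrics2.count": 4}
kept = filter_state_dict(state, ("metrics1", "metrics2"))
```

An overridden get_state_dict could return the result of such a filter so that only the entries of interest are averaged.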
@hal-314 wow... I think it is good.(but I am not a maintainer or something) |
Personally I think we should include this. However, not sure where it belongs (Currently I'd say not to lightning core but either to flash or bolts) cc @tchaton @ethanwharris for opinions on this |
@hal-314 thanks for your implementation! I agree that EMA would be very useful to have in Lightning. I tested your implementation with default parameters (so |
@flukeskywalker Glad to hear that it's useful for you :) If you fix the multi-GPU code, would you mind sharing the fix so that others can use it? |
@hal-314 Thank you for sharing your code. I tested it with 2 GPUs in DDP mode. The full error log is below.
Traceback (most recent call last):
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
self.fit_loop.run()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 146, in run
self.on_advance_end()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 242, in on_advance_end
self._run_validation()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 337, in _run_validation
self.val_loop.run()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 140, in run
self.on_run_start(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 95, in on_run_start
self._on_evaluation_start()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 179, in _on_evaluation_start
self.trainer.call_hook("on_validation_start", *args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1490, in call_hook
callback_fx(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 216, in on_validation_start
callback.on_validation_start(self, self.lightning_module)
File "/home/hankyul/private/SuperConvergence/src/ema.py", line 134, in on_validation_start
pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 411, in broadcast
broadcast_object_list(obj, src, group=_group.WORLD)
File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1840, in broadcast_object_list
object_list[i] = _tensor_to_object(obj_view, obj_size)
File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1532, in _tensor_to_object
return _unpickler(io.BytesIO(buf)).load()
File "/opt/conda/lib/python3.7/site-packages/torch/storage.py", line 161, in _load_from_bytes
return torch.load(io.BytesIO(b))
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 608, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 787, in _legacy_load
result = unpickler.load()
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 743, in persistent_load
deserialized_objects[root_key] = restore_location(obj, location)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
result = fn(storage, location)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 155, in _cuda_deserialize
return storage_type(obj.size())
File "/opt/conda/lib/python3.7/site-packages/torch/cuda/__init__.py", line 606, in _lazy_new
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1. |
@hankyul2 I tried @hal-314's code on my machine with two 1080 Ti GPUs, and it worked well. Maybe the out-of-memory error is just literally that... My experiment configuration is as below:
When I trained without EMA, the GPU memory usage was 3311MB each. |
@hal-314 @hankyul2 sorry, I made a mistake in my last code. And I found the
@overrides
def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    if not self._ema_state_dict_ready:
        return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.
    self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
    # Keep the return value: broadcast() hands back the received object instead of filling the argument in place.
    ema_state_dict = pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
    self.ema_state_dict = ema_state_dict
since the
and source code as follows
# pytorch_lightning/plugins/training_type/ddp_spawn.py
def broadcast(self, obj: object, src: int = 0) -> object:
    if not distributed_available():
        return obj
    return self.dist.broadcast(obj)

# pytorch_lightning/distributed/dist.py
class LightningDistributed:
    def __init__(self, rank=None, device=None):
        self.rank = rank
        self.device = device

    def broadcast(self, obj: Any, group=_group.WORLD):
        # always wrap into a list so it can be broadcasted.
        obj = [obj]
        if self.rank != 0:
            obj = [None] * len(obj)
        broadcast_object_list(obj, 0, group=group or _group.WORLD)
        return obj[0]
The GPU usage is shown below:
without ema:
with ema:
environment: |
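The fix above hinges on one detail visible in the quoted source: broadcast returns the received object and does not fill in the dictionary that was passed as an argument. The sketch below imitates that behaviour in a single process. `fake_broadcast` is a hypothetical stand-in written for this illustration and is not a Lightning or PyTorch API.

```python
from typing import Any, Dict


def fake_broadcast(obj: Any, rank: int, source_payload: Dict[str, float]) -> Any:
    """Hypothetical stand-in: rank 0 gets its own object back, other ranks receive a NEW object."""
    return obj if rank == 0 else dict(source_payload)


rank0_weights = {"layer.weight": 0.5}

# On rank 1 the local dictionary is empty before the broadcast.
local: Dict[str, float] = {}
fake_broadcast(local, rank=1, source_payload=rank0_weights)          # return value dropped
ignored = dict(local)                                                 # snapshot: still empty
local = fake_broadcast(local, rank=1, source_payload=rank0_weights)  # return value kept
```

Only the second call leaves rank 1 with the EMA weights, which matches the reassignment of self.ema_state_dict in the fix.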
I can confirm that @sevenights's fix works for me too with pytorch-lightning==1.5.8 |
I would love to see an implementation of EMA. I just migrated from ignite to lightning and lacking an EMA callback really holds me back. |
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team! |
I too would be quite interested in this feature, based upon a model I am replicating that uses EMA. @justusschock |
@Borda just curious where this is on the roadmap? |
I think we can add it as experimental callback :) |
@Borda I would be super happy to beta test this ASAP. The EMA callback in Nemo shared above is apache-2.0 |
@lantiga, what are your thoughts? Maybe implement it in Bolts? |
PyTorch added support for this with pytorch/pytorch#94820 (commit) |
@carmocca Perhaps a doc update then, given that lightning suggests SWA where EMA is sometimes superior? |
I haven't tried the newly added EMA in PyTorch. Just wanted to share the info. If anybody gives it a shot, we would accept a docs contribution showing how to use it with Lightning. |
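As a rough starting point for such a docs contribution, here is a minimal sketch of EMA with torch.optim.swa_utils.AveragedModel outside of Lightning. It assumes only the long-standing avg_fn argument; the decay value and the toy one-weight model are hypothetical, and the fill_ call merely stands in for an optimizer step.

```python
import torch
from torch.optim.swa_utils import AveragedModel

DECAY = 0.9  # hypothetical value, small so that the effect is visible after one step


def ema_avg_fn(averaged_param, current_param, num_averaged):
    # Signature expected by AveragedModel's avg_fn; num_averaged is not needed for EMA.
    return DECAY * averaged_param + (1.0 - DECAY) * current_param


model = torch.nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1.0)

ema_model = AveragedModel(model, avg_fn=ema_avg_fn)
ema_model.update_parameters(model)  # the first call copies the current weights

with torch.no_grad():
    model.weight.fill_(2.0)  # stand-in for an optimizer step
ema_model.update_parameters(model)  # blends the new weights into the average

ema_weight = ema_model.module.weight.item()
```

In a Lightning setting the update_parameters call would typically live in a hook such as on_train_batch_end. Recent PyTorch releases may also provide ready-made EMA averaging helpers in the same module, so it is worth checking the documentation of the installed version.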
After using the Nemo callback, on loading the checkpoint, does it automatically load the EMA parameters for inference? Or do I need to switch the parameters somehow manually? @SeanNaren and others... |
Just curious about the status here - Is there a recommended approach for this that the community has settled on? |
I am curious too. What's going on with this important EMA callback? |
Also would love this! SWA is great and feel like EMA should be even easier to add? |
That's great! Thank you. |
I hope this will be helpful. |
@flashszn @SeanNaren Thanks for sharing. When you save the top-k models, do you use the original model or the EMA model to evaluate the metrics? |
@zhong-yy There's a |
Any update on this? |
This should have been added long ago. @williamFalcon please try to get your team to move this along, It has been a stopping block with lightning for many users for years now. |
Our new project uses EMA as a callback for pytorch lightning. I hope it can be used as a reference. For details, see: https://github.com/nkicsl/Resfusion/blob/main/callback/ema.py |
Any update? |
In my opinion it would be good to write a callback that uses the AveragedModel from PyTorch instead of basically copying the same functionality to Lightning. Fewer places that have to be updated when something needs to be fixed or added. |
I mean something like this:
import itertools
from copy import deepcopy

from lightning.pytorch.callbacks.callback import Callback
from torch.optim.swa_utils import AveragedModel


class ModelAveragingCallback(Callback):
    def __init__(self, device, avg_fn):
        self._device = device
        self._avg_fn = avg_fn
        self._averaged_model = None
        self._latest_update_step = -1

    def setup(self, trainer, pl_module, stage) -> None:
        if stage == "fit":
            device = self._device or pl_module.device
            self._averaged_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Update the average at most once per optimizer step.
        if trainer.global_step > self._latest_update_step:
            self._averaged_model.update_parameters(pl_module)
            self._latest_update_step = trainer.global_step

    def on_fit_end(self, trainer, pl_module):
        # Copy the averaged parameters and buffers into the current model.
        average_params = itertools.chain(self._averaged_model.module.parameters(), self._averaged_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            current_param.data.copy_(average_param.data)

    def on_validation_epoch_start(self, trainer, pl_module):
        if self._averaged_model is not None:
            self._swap_models(pl_module)

    def on_validation_epoch_end(self, trainer, pl_module):
        if self._averaged_model is not None:
            self._swap_models(pl_module)

    def state_dict(self):
        return {"latest_update_step": self._latest_update_step}

    def load_state_dict(self, state_dict):
        self._latest_update_step = state_dict["latest_update_step"]

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Store the averaged weights as the main state_dict and keep the current weights separately.
        average_state = self._averaged_model.state_dict()
        checkpoint["current_model_state"] = checkpoint["state_dict"]
        checkpoint["state_dict"] = {
            name[7:]: value for name, value in average_state.items() if name.startswith("module.")
        }
        checkpoint["model_averaging_state"] = {
            name: value for name, value in average_state.items() if not name.startswith("module.")
        }

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        if ("current_model_state" in checkpoint) and ("model_averaging_state" in checkpoint):
            average_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
            average_state |= checkpoint["model_averaging_state"]
            self._averaged_model.load_state_dict(average_state)
            checkpoint["state_dict"] = checkpoint["current_model_state"]
        else:
            self._averaged_model.module.load_state_dict(deepcopy(checkpoint['state_dict']), strict=False)

    def _swap_models(self, pl_module):
        # Exchange the averaged and the current parameters and buffers in place.
        average_params = itertools.chain(self._averaged_model.module.parameters(), self._averaged_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            tmp = average_param.data.clone()
            average_param.data.copy_(current_param.data)
            current_param.data.copy_(tmp)
In principle it could be something like this. Things to consider:
I'm happy to create a pull request if this sounds good to you, but I would need some feedback, since I'm not sure how to verify that this works correctly. |
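The checkpoint handling in the callback above splits the averaged model's state into weights (keys prefixed with "module.") and averaging bookkeeping (everything else). A small sketch of that round trip on plain dictionaries follows; the keys and values are hypothetical placeholders, whereas a real state would hold tensors.

```python
PREFIX = "module."

# Hypothetical flattened state of an averaged model: wrapped weights plus the averaging counter.
average_state = {"module.layer.weight": 0.5, "module.layer.bias": 0.1, "n_averaged": 3}

# Saving: strip the prefix so the weights look like an ordinary model state_dict.
model_state = {k[len(PREFIX):]: v for k, v in average_state.items() if k.startswith(PREFIX)}
averaging_state = {k: v for k, v in average_state.items() if not k.startswith(PREFIX)}

# Loading: put the prefix back and merge the two parts again.
restored = {PREFIX + k: v for k, v in model_state.items()}
restored.update(averaging_state)
```

Storing the averaged weights under the usual state_dict key means a checkpoint loaded without the callback would pick up the averaged model, which appears to be the intent of the design above.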
Hey @senarvi this looks great. Happy to take the PR, it will be important to have an integration test to make sure things are computed correctly, I'd just start from BoringModel or something more involved among the test models we have. |
Reading this thread, I am surprised that this is not yet implemented in 2025 |
I created a draft pull request. I would be happy to hear your opinion, especially on the two questions in the PR. |
🚀 Feature
How about adding EMA as a callback?
Motivation
I have had difficulty applying EMA. I think it would be nice to have EMA available as a callback.
Pitch
If the user adds EMA as a callback, the EMA weights are applied for validation and test.
Alternatives
Of course, you could add EMA as a tutorial instead, like the snippets below.
Additional context
cc @Borda