Add feature Exponential Moving Average (EMA) #10914

Open
hankyul2 opened this issue Dec 3, 2021 · 59 comments · May be fixed by #20545
Labels
feature Is an improvement or enhancement

Comments

@hankyul2
Contributor

hankyul2 commented Dec 3, 2021

🚀 Feature

How about adding EMA as a callback?

Motivation

I have had difficulty applying EMA. I think it would be nice to have EMA as a callback.

Pitch

If the user adds EMA as a callback, the EMA weights are used for validation and testing.

Alternatives

Alternatively, EMA could be added as a tutorial, like the snippet below:

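from copy import deepcopy

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from torch.optim import SGD
from torchmetrics import Accuracy, MetricCollection
from torchvision import models


class EMA(nn.Module):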
    """ Model Exponential Moving Average V2 from timm"""
    def __init__(self, model, decay=0.9999):
        super(EMA, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)
    

class BasicModule(LightningModule):
    def __init__(self, lr=0.01, use_ema=False):
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        self.model_ema = EMA(self.model, decay=0.9) if use_ema else None
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        
        metric = MetricCollection({'top@1': Accuracy(top_k=1), 'top@5': Accuracy(top_k=5)})
        self.train_metric = metric.clone(prefix='train_')
        self.valid_metric = metric.clone(prefix='valid_')
    
    def training_step(self, batch, batch_idx, optimizer_idx=None):
        return self.shared_step(*batch, self.train_metric)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(*batch, self.valid_metric)

    def shared_step(self, x, y, metric):
        y_hat = self.model(x) if self.training or self.model_ema is None else self.model_ema.module(x)
        loss = self.criterion(y_hat, y)
        self.log_dict(metric(y_hat, y), prog_bar=True)
        return loss

    def configure_optimizers(self):
        return SGD(self.model.parameters(), lr=self.lr)

    def on_before_backward(self, loss: torch.Tensor) -> None:
        if self.model_ema:
            self.model_ema.update(self.model)
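Written out, the `update()` method above applies the recursion below, with $\beta$ = `decay`, $\theta_t$ the live weights after update $t$, and $\hat\theta_t$ the EMA copy. This is the standard reading of EMA, added for orientation rather than taken from timm. Unrolling it shows that older weights are down-weighted geometrically, so the average roughly spans the last $1/(1-\beta)$ updates: about 10 for the `decay=0.9` used in `BasicModule`, about 10,000 for the default 0.9999.

```latex
\hat\theta_t = \beta\,\hat\theta_{t-1} + (1-\beta)\,\theta_t
\quad\Longrightarrow\quad
\hat\theta_t = \beta^{t}\,\hat\theta_0 + (1-\beta)\sum_{k=1}^{t}\beta^{\,t-k}\,\theta_k
```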

Additional context



cc @Borda

@hankyul2 hankyul2 added the feature Is an improvement or enhancement label Dec 3, 2021
@justusschock
Member

Hi,
As stated in #8100 (comment), this can be done by replacing just one part of our SWA callback.

@hankyul2
Contributor Author

hankyul2 commented Dec 3, 2021

@justusschock thank you for your reply.

Is there a way to implement EMA using SWA?

I checked the link. To me, SWA seems to update the LR scheduler and the model weights together. Am I right?

@mathemusician
Contributor

mathemusician commented Dec 4, 2021

@hankyul2, I believe this is how it would be implemented:

import torch
from pytorch_lightning.callbacks import StochasticWeightAveraging

class EMA_Callback(StochasticWeightAveraging):
    def __init__(self, decay=0.9999):
        super().__init__()
        self.decay = decay

    def avg_fn(
        self, averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
    ) -> torch.FloatTensor:
        e = averaged_model_parameter
        m = model_parameter
        return self.decay * e + (1. - self.decay) * m

Let me know if you have any problems. I just learned about SWA because of your issue!

@hankyul2
Contributor Author

hankyul2 commented Dec 4, 2021

@mathemusician thank you.

@hankyul2 hankyul2 closed this as completed Dec 4, 2021
@hal-314

hal-314 commented Dec 10, 2021

@hankyul2 I don't think @mathemusician's solution is equivalent: avg_fn is called once per epoch, while EMA updates happen at every training step. I don't think EMA can be implemented with the SWA callback.
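A toy illustration of that point, using a made-up trajectory for a single weight (3 epochs of 4 optimizer steps each) and the same decay in both cases: applying the EMA formula only to the epoch-end values, as an epoch-level avg_fn would, gives a different result from applying it after every step. The numbers and the helper name ema_of are hypothetical and only meant to show that the two schedules are not interchangeable.

```python
from typing import Sequence


def ema_of(values: Sequence[float], decay: float) -> float:
    """EMA of a sequence; the first value initializes the average."""
    avg = values[0]
    for v in values[1:]:
        avg = decay * avg + (1.0 - decay) * v
    return avg


# Hypothetical value of one weight after each of 12 optimizer steps (3 epochs x 4 steps).
per_step = [0.1 * (i + 1) for i in range(12)]
# The same weight as seen only at the end of each epoch.
per_epoch = per_step[3::4]

step_ema = ema_of(per_step, decay=0.5)
epoch_ema = ema_of(per_epoch, decay=0.5)
print(f"per-step EMA: {step_ema:.4f}, per-epoch EMA: {epoch_ema:.4f}")
```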

@hankyul2
Contributor Author

hankyul2 commented Dec 10, 2021

@hal-314 Yeah, I think so.

@hankyul2 hankyul2 reopened this Dec 10, 2021
@hankyul2
Contributor Author

hankyul2 commented Dec 11, 2021

@hal-314 @mathemusician

I have implemented an EMA callback with simple functionality:

  • update weights
  • change forward() in validation_loop

I think many more options should be added, for example save_weight, ema_step_period, etc.

If you find it helpful, please let me know, and I will develop it further. If not, feel free to close this issue or leave a comment.

from copy import deepcopy

import torch
from pytorch_lightning import Callback


class EMACallback(Callback):
    def __init__(self, decay=0.995):
        self.decay = decay
        self.module_pair_list = []

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        def forward_wrapper(module, org, ema):
            def forward(*args, **kwargs):
                return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
            return forward

        modules = list(filter(lambda x: len(list(x[1].parameters())) > 0, pl_module.named_children()))

        for name, module in modules:
            ema_module = deepcopy(module)
            self.module_pair_list.append((ema_module, module))
            pl_module.add_module(f'EMA_{name}', ema_module)
            module.forward_bc = module.forward
            module.forward = forward_wrapper(module, module.forward_bc, ema_module.forward)

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        for ema_module, module in self.module_pair_list:
            self._update(ema_module, module, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def _update(self, ema_module, module, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(ema_module.state_dict().values(), module.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))

@hal-314

hal-314 commented Dec 13, 2021

@hankyul2 I prefer the approach from timm, as it doesn't involve changing the forward method; with your callback, users could not use the forward method in their xxx_step methods. So I would recommend your first snippet, but updating the EMA weights in on_train_batch_end instead of on_before_backward, to match timm's train.py.

Finally, I managed to implement EMA as a callback using state_dicts instead of keeping a copy of the whole module like ModelEMAV2 does. I can try to make a PR or, at least, post the code here if you or anyone else is interested (and I get permission for it). It works on 1 GPU and likely on multiple GPUs, though that isn't tested.

@hankyul2
Contributor Author

@hal-314 Cool 😎. I want to see it. Can you share your code?

@hal-314

hal-314 commented Dec 14, 2021

@hankyul2 Here is the code. Be aware that you need the overrides package installed (pip install overrides). If you don't want it, comment out the import and the @overrides decorators. I only use it to make sure that I'm actually overriding the methods correctly.

Things to be aware of:

  • I didn't test the multi-GPU path, though it should work. If it doesn't, a couple of asserts will fail. If your setup is multi-GPU, could you check that it works?
  • You can override the get_state_dict method to filter out parameters that you don't want to include in the EMA, for example those of metrics. See its docstring.
  • EMA weights are saved as "ema_state_dict" in the callback state, so you can retrieve them manually outside Lightning.
  • Tested with PyTorch 1.8 & 1.10 and PyTorch Lightning 1.5.4
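As an illustration of the second bullet, here is a minimal sketch of the kind of filtering an overridden get_state_dict could do. The helper name filter_state_dict and the "metrics1" prefix are hypothetical; the idea is simply to drop every state_dict entry whose key starts with an ignored prefix.

```python
from typing import Dict, Tuple

import torch
from torch import nn


def filter_state_dict(module: nn.Module, patterns_to_ignore: Tuple[str, ...]) -> Dict[str, torch.Tensor]:
    """Return module.state_dict() without entries whose key starts with an ignored prefix."""
    # str.startswith accepts a tuple and is True if any of the prefixes matches
    return {k: v for k, v in module.state_dict().items() if not k.startswith(patterns_to_ignore)}


# Hypothetical module: a trainable part plus a metric-like submodule that should not be averaged
model = nn.ModuleDict({"net": nn.Linear(2, 2), "metrics1": nn.Linear(2, 2)})
print(sorted(filter_state_dict(model, ("metrics1",))))
```

Matching on "metrics1." (with the trailing dot) would avoid also dropping a sibling named, say, metrics10.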

I hope it's useful.

from copy import deepcopy
from typing import Optional, Union, Dict, Any

import pytorch_lightning as pl
import torch
from overrides import overrides
from pytorch_lightning.utilities import rank_zero_only


class EMA(pl.Callback):
    """Implements EMA (exponential moving average) to any kind of model.
    EMA weights will be used during validation and stored separately from original model weights.

    How to use EMA:
        - Sometimes, last EMA checkpoint isn't the best as EMA weights metrics can show long oscillations in time. See
          https://github.com/rwightman/pytorch-image-models/issues/102
        - Batch Norm layers and likely any other type of norm layers doesn't need to be updated at the end. See
          discussions in: https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 and
          https://github.com/rwightman/pytorch-image-models/issues/224
        - For object detection, SWA usually works better. See   https://github.com/timgaripov/swa/issues/16

    Implementation detail:
        - See EMA in Pytorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914
        - When multi gpu, we broadcast ema weights and the original weights in order to only hold 1 copy in memory.
          This is specially relevant when storing EMA weights on CPU + pinned memory as pinned memory is a limited
          resource. In addition, we want to avoid duplicated operations in ranks != 0 to reduce jitter and improve
          performance.
    """
    def __init__(self, decay: float = 0.9999, ema_device: Optional[Union[torch.device, str]] = None, pin_memory=True):
        super().__init__()
        self.decay = decay
        self.ema_device: str = f"{ema_device}" if ema_device else None  # perform ema on different device from the model
        self.ema_pin_memory = pin_memory if torch.cuda.is_available() else False  # Only works if CUDA is available
        self.ema_state_dict: Dict[str, torch.Tensor] = {}
        self.original_state_dict = {}
        self._ema_state_dict_ready = False

    @staticmethod
    def get_state_dict(pl_module: pl.LightningModule):
        """Returns the state dictionary of pl_module. Override it if you want to filter some parameters and/or buffers out.
        For example, if pl_module has metrics, you don't want to return their parameters.

        code:
            # Only consider modules that can be seen by optimizers. Lightning modules can have other nn.Modules attached
            # like losses, metrics, etc.
            patterns_to_ignore = ("metrics1", "metrics2")
            return dict(filter(lambda i: not i[0].startswith(patterns_to_ignore), pl_module.state_dict().items()))
        """
        return pl_module.state_dict()
        
    @overrides
    def on_train_start(self, trainer: "pl.Trainer", pl_module: pl.LightningModule) -> None:
        # Only keep track of EMA weights in rank zero.
        if not self._ema_state_dict_ready and pl_module.global_rank == 0:
            self.ema_state_dict = deepcopy(self.get_state_dict(pl_module))
            if self.ema_device:
                self.ema_state_dict = {k: tensor.to(device=self.ema_device) for k, tensor in self.ema_state_dict.items()}

            if self.ema_device == "cpu" and self.ema_pin_memory:
                self.ema_state_dict = {k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()}

        self._ema_state_dict_ready = True

    @rank_zero_only
    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs) -> None:
        # Update EMA weights
        with torch.no_grad():
            for key, value in self.get_state_dict(pl_module).items():
                ema_value = self.ema_state_dict[key]
                ema_value.copy_(self.decay * ema_value + (1. - self.decay) * value, non_blocking=True)

    @overrides
    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.

        self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
        pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
        assert self.ema_state_dict.keys() == self.original_state_dict.keys(), \
            f"There are some keys missing in the broadcast EMA state dictionary. " \
            f"They are: {self.original_state_dict.keys() - self.ema_state_dict.keys()}"
        pl_module.load_state_dict(self.ema_state_dict, strict=False)

        if pl_module.global_rank > 0:
            # Remove ema state dict from the memory. In rank 0, it could be in ram pinned memory.
            self.ema_state_dict = {}

    @overrides
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.

        # Replace EMA weights with training weights
        pl_module.load_state_dict(self.original_state_dict, strict=False)

    @overrides
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> dict:
        return {"ema_state_dict": self.ema_state_dict, "_ema_state_dict_ready": self._ema_state_dict_ready}

    @overrides
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
    ) -> None:
        self._ema_state_dict_ready = callback_state["_ema_state_dict_ready"]
        self.ema_state_dict = callback_state["ema_state_dict"]

@justusschock @mathemusician Do you think Lightning should include an EMA callback? Maybe it could go into Bolts.

@hankyul2
Contributor Author

@hal-314 Wow... I think it is good (but I am not a maintainer or anything).

@justusschock
Member

Personally, I think we should include this. However, I'm not sure where it belongs (currently I'd say not in Lightning core, but in either Flash or Bolts).

cc @tchaton @ethanwharris for opinions on this

@flukeskywalker

@hal-314 thanks for your implementation! I agree that EMA would be very useful to have in Lightning.

I tested your implementation with default parameters (so ema_device=None etc) and it seems to work well on a single GPU. In multi-gpu, the assertion in on_validation_start fails on GPUs other than 0. It appears that the ema_state_dict is not broadcast successfully to all devices (it is an empty dict).

@hal-314

hal-314 commented Dec 21, 2021

@flukeskywalker Glad to see that it's useful for you :)

If you fix the multi-GPU code, would you mind sharing the fix so others can use it?

@hankyul2
Contributor Author

hankyul2 commented Dec 28, 2021

@hal-314 Thank you for sharing your code. I tested it with 2 GPUs in DDP mode.
When pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0) is executed, an OOM error occurs.
Can I ask how validation steps work in DDP mode, or is there any document I can reference?

The full error log is below.

Traceback (most recent call last):
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
    return self._run_train()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 146, in run
    self.on_advance_end()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 242, in on_advance_end
    self._run_validation()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 337, in _run_validation
    self.val_loop.run()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 140, in run
    self.on_run_start(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 95, in on_run_start
    self._on_evaluation_start()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 179, in _on_evaluation_start
    self.trainer.call_hook("on_validation_start", *args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1490, in call_hook
    callback_fx(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 216, in on_validation_start
    callback.on_validation_start(self, self.lightning_module)
  File "/home/hankyul/private/SuperConvergence/src/ema.py", line 134, in on_validation_start
    pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 411, in broadcast
    broadcast_object_list(obj, src, group=_group.WORLD)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1840, in broadcast_object_list
    object_list[i] = _tensor_to_object(obj_view, obj_size)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1532, in _tensor_to_object
    return _unpickler(io.BytesIO(buf)).load()
  File "/opt/conda/lib/python3.7/site-packages/torch/storage.py", line 161, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 608, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 787, in _legacy_load
    result = unpickler.load()
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 743, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
    result = fn(storage, location)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 155, in _cuda_deserialize
    return storage_type(obj.size())
  File "/opt/conda/lib/python3.7/site-packages/torch/cuda/__init__.py", line 606, in _lazy_new
    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

@hal-314

hal-314 commented Dec 28, 2021

@hankyul2 Sorry, but I don't have experience with multi-GPU in Lightning. From the stack trace, it seems that the OOM occurs when broadcasting the state. Here is where I found the broadcast docs in Lightning.

@sevenights

@hankyul2 I tried @hal-314's code on my machine with two 1080 Ti GPUs, and it worked well. Maybe the out-of-memory error is literally a lack of memory...

My experiment configuration is as follows:

Model: BertModelForSequenceClassification
max_length: 50
padding_to_max_length: True
batch_size: 4(per device)

When I trained without EMA, GPU memory usage was 3311MB on each GPU.
With EMA turned on, GPU 0 used 3851MB and GPU 1 used 3311MB.

@sevenights

@hal-314 @hankyul2 Sorry, I made a mistake in my last comment. I found that the broadcast in on_validation_start didn't work as expected, so I changed it as follows:

@overrides
def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    if not self._ema_state_dict_ready:
        return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.

    self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
    ema_state_dict = pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
    self.ema_state_dict = ema_state_dict

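This is needed because broadcast does not modify its argument in place: in dist.py, when self.rank != 0, the wrapper list is replaced by a new [None] list before broadcast_object_list fills it, so the received object is only available as the return value. The call stack I found is as follows: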

pl_module.training_type_plugin.broadcast(self.ema_state_dict, 0)
-> ddp_spawn.broadcast(obj, src)
-> LightningDistributed.broadcast(obj, group)

and the source code is as follows:

# pytorch_lightning/plugins/training_type/ddp_spawn.py
def broadcast(self, obj: object, src: int = 0) -> object:
    if not distributed_available():
        return obj
    return self.dist.broadcast(obj)
# pytorch_lightning/distributed/dist.py
class LightningDistributed:
    def __init__(self, rank=None, device=None):
        self.rank = rank
        self.device = device

    def broadcast(self, obj: Any, group=_group.WORLD):
        # always wrap into a list so it can be broadcasted.
        obj = [obj]

        if self.rank != 0:
            obj = [None] * len(obj)

        broadcast_object_list(obj, 0, group=group or _group.WORLD)

        return obj[0]
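The behaviour can be reproduced without a process group. The function below is a hypothetical stand-in that only mimics the list-wrapping logic of the excerpt above (the assignment to wrapper[0] plays the role of broadcast_object_list filling the list on a receiving rank). It shows why a caller on rank != 0 has to use the return value: its own object is never modified.

```python
from typing import Any, List, Optional


def simulated_broadcast(obj: Any, rank: int, payload_from_rank0: Any) -> Any:
    """Hypothetical stand-in for LightningDistributed.broadcast (no real communication)."""
    wrapper: List[Optional[Any]] = [obj]
    if rank != 0:
        # Receiving ranks start from a fresh list, so `obj` is no longer referenced.
        wrapper = [None] * len(wrapper)
        # Stands in for broadcast_object_list writing rank 0's object into the list.
        wrapper[0] = payload_from_rank0
    return wrapper[0]


ema_state_on_rank1 = {}  # what the callback holds on ranks > 0 before validation
received = simulated_broadcast(ema_state_on_rank1, rank=1, payload_from_rank0={"weight": 1.0})
print(ema_state_on_rank1, received)
```

This matches @flukeskywalker's observation that ema_state_dict stayed empty on GPUs other than 0, and it is why assigning the return value, as in the fix above, resolves it.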

The GPU usage is shown below:

without ema:

|    0   N/A  N/A     15733      C   ...da3/envs/pl149/bin/python     4645MiB |
|    1   N/A  N/A     15750      C   ...da3/envs/pl149/bin/python     4645MiB |

with ema:

|    0   N/A  N/A     12561      C   ...da3/envs/pl149/bin/python     5509MiB |
|    0   N/A  N/A     12577      C   ...da3/envs/pl149/bin/python     1227MiB |
|    1   N/A  N/A     12577      C   ...da3/envs/pl149/bin/python     5063MiB |

environment:
pytorch==1.8.0
pytorch-lightning==1.4.9
gpus: 1080ti * 2

@flukeskywalker

I can confirm that @sevenights's fix works for me too with pytorch-lightning==1.5.8

@yoyolicoris

I would love to see an implementation of EMA. I just migrated from ignite to lightning and lacking an EMA callback really holds me back.

@stale

stale bot commented Feb 18, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Feb 18, 2022
@stale stale bot closed this as completed Apr 28, 2022
@turian
Contributor

turian commented Mar 26, 2023

I too would be quite interested in this feature, based upon a model I am replicating that uses EMA. @justusschock

@turian
Contributor

turian commented Apr 16, 2023

@Borda just curious where this is on the roadmap?

@Borda
Member

Borda commented Apr 17, 2023

@Borda just curious where this is on the roadmap?

I think we can add it as experimental callback :)

@turian
Contributor

turian commented Apr 19, 2023

@Borda I would be super happy to beta test this ASAP.

The EMA callback in NeMo shared above is licensed under Apache-2.0.

@Borda
Member

Borda commented Apr 20, 2023

@lantiga, what are your thoughts? Maybe implement it in Bolts?

@carmocca
Contributor

carmocca commented Apr 27, 2023

PyTorch added support for this with pytorch/pytorch#94820 (commit)

@turian
Contributor

turian commented Apr 27, 2023

@carmocca Perhaps a docs update then, given that Lightning suggests SWA where EMA is sometimes superior?

@carmocca
Contributor

I haven't tried the newly added EMA in PyTorch. Just wanted to share the info. If anybody gives it a shot, we would accept a docs contribution showing how to use it with Lightning.

@shivammehta25
Contributor

shivammehta25 commented Jul 26, 2023

After using the Nemo callback, on loading the checkpoint, does it automatically load the EMA parameters for inference? Or do I need to switch the parameters somehow manually? @SeanNaren and others...

@jlotthammer

Just curious about the status here - Is there a recommended approach for this that the community has settled on?

@Hans-digit

I am curious as well. What's going on with this important EMA callback...?

@lukasschmit

Would also love this! SWA is great, and I feel like EMA should be even easier to add?

@KasuganoLove

Sorry for the late response here, within NeMo I have added an EMA callback which we have tested/used. This is based on the PyTorch Lightning Callback, can be seen here: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py

We're doing some performance improvements that may require more involvement from NeMo, so using it separately would require some stripping down. Progress can be seen here: NVIDIA/NeMo#5169

That's great! Thank you.

@KasuganoLove

After using the Nemo callback, on loading the checkpoint, does it automatically load the EMA parameters for inference? Or do I need to switch the parameters somehow manually? @SeanNaren and others...

NVIDIA/NeMo#5169 (comment)

I hope this will be helpful.

@zhong-yy

@flashszn @SeanNaren Thanks for sharing. When you save the top-k models, do you use the original model or the EMA model to evaluate the metrics?

@KasuganoLove

@flashszn @SeanNaren Thanks for sharing. When you save the top-k models, do you use the original model or the EMA model to evaluate the metrics?

@zhong-yy There's an evaluate_ema_weights_instead: bool = False option in the EMA callback that chooses whether the original model or the EMA model is used to evaluate the metrics.

@je-santos

Any update on this?

@lyndonlauder

This should have been added long ago. @williamFalcon please try to get your team to move this along; it has been a stumbling block with Lightning for many users for years now.

@KasuganoLove

Our new project uses EMA as a callback for PyTorch Lightning. I hope it can be used as a reference. For details, see: https://github.com/nkicsl/Resfusion/blob/main/callback/ema.py

@alberthli

Any update?

@senarvi
Contributor

senarvi commented Dec 23, 2024

In my opinion it would be good to write a callback that uses the AveragedModel from PyTorch instead of basically copying the same functionality to Lightning. Fewer places that have to be updated when something needs to be fixed or added.

@senarvi
Contributor

senarvi commented Jan 8, 2025

I mean something like this:

import itertools
from copy import deepcopy

from lightning.pytorch.callbacks.callback import Callback
from torch.optim.swa_utils import AveragedModel


class ModelAveragingCallback(Callback):
    def __init__(self, device, avg_fn):
        self._device = device
        self._avg_fn = avg_fn
        self._averaged_model = None
        self._latest_update_step = -1

    def setup(self, trainer, pl_module, stage) -> None:
        if stage == "fit":
            device = self._device or pl_module.device
            self._averaged_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.global_step > self._latest_update_step:
            self._averaged_model.update_parameters(pl_module)
            self._latest_update_step = trainer.global_step

    def on_fit_end(self, trainer, pl_module):
        average_params = itertools.chain(self._averaged_model.module.parameters(), self._averaged_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            current_param.data.copy_(average_param.data)

    def on_validation_epoch_start(self, trainer, pl_module):
        if self._averaged_model is not None:
            self._swap_models(pl_module)

    def on_validation_epoch_end(self, trainer, pl_module):
        if self._averaged_model is not None:
            self._swap_models(pl_module)

    def state_dict(self):
        return {"latest_update_step": self._latest_update_step}

    def load_state_dict(self, state_dict):
        self._latest_update_step = state_dict["latest_update_step"]

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        average_state = self._averaged_model.state_dict()
        checkpoint["current_model_state"] = checkpoint["state_dict"]
        checkpoint["state_dict"] = {
            name[7:]: value for name, value in average_state.items() if name.startswith("module.")
        }
        checkpoint["model_averaging_state"] = {
            name: value for name, value in average_state.items() if not name.startswith("module.")
        }

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        if ("current_model_state" in checkpoint) and ("model_averaging_state" in checkpoint):
            average_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
            average_state |= checkpoint["model_averaging_state"]
            self._averaged_model.load_state_dict(average_state)
            checkpoint["state_dict"] = checkpoint["current_model_state"]
        else:
            self._averaged_model.module.load_state_dict(deepcopy(checkpoint['state_dict']), strict=False)

    def _swap_models(self, pl_module):
        average_params = itertools.chain(self._averaged_model.module.parameters(), self._averaged_model.module.buffers())
        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
        for average_param, current_param in zip(average_params, current_params):
            tmp = average_param.data.clone()
            average_param.data.copy_(current_param.data)
            current_param.data.copy_(tmp)

In principle it could be something like this. Things to consider:

  • We could have a separate class for stochastic weight averaging, or one class that is able to perform averaging after each training step or after each epoch.
  • This code averages the buffers too. We could add an option to use torch.optim.swa_utils.update_bn() instead, but that would make things more complicated, since we would need to load more data after the last epoch.
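For reference, the update_bn alternative from the second bullet would look roughly like this in plain PyTorch. This is a sketch assuming the averaged model was built with the default use_buffers=False (so BatchNorm statistics are not averaged) and that the training inputs can be iterated again after the last epoch; recompute_bn_stats is a hypothetical helper name. update_bn resets the running statistics of every BatchNorm layer and recomputes them in one pass over the data.

```python
from typing import Iterable

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn


def recompute_bn_stats(averaged: AveragedModel, loader: Iterable) -> None:
    # update_bn accepts any iterable of input tensors (or of lists/tuples whose
    # first element is the input), resets BatchNorm running stats and runs one
    # forward pass per batch to recompute them as a cumulative average.
    update_bn(loader, averaged)


model = nn.Sequential(nn.BatchNorm1d(1))
averaged = AveragedModel(model)  # default: parameters are averaged, buffers are not
batches = [torch.tensor([[4.0], [6.0]]), torch.tensor([[4.0], [6.0]])]
recompute_bn_stats(averaged, batches)
```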

I'm happy to create a pull request if this sounds good to you, but I would need some feedback, since I'm not sure how to verify that this works correctly.

@lantiga
Collaborator

lantiga commented Jan 9, 2025

Hey @senarvi, this looks great. Happy to take the PR. It will be important to have an integration test to make sure things are computed correctly; I'd just start from BoringModel or something more involved among the test models we have.

@turian
Contributor

turian commented Jan 12, 2025

Reading this thread, I am surprised that this is not yet implemented in 2025

@senarvi senarvi linked a pull request Jan 14, 2025 that will close this issue
@senarvi
Contributor

senarvi commented Jan 14, 2025

I created a draft pull request. I would be happy to hear your opinion, especially on the two questions in the PR.
