
Pytorch Lightning Checkpoint saving and Earlystopping #979

Open
omarequalmars opened this issue Nov 15, 2024 · 4 comments

@omarequalmars

Recently I have been trying to train an SMP model using PyTorch Lightning's checkpoint saving and early stopping features in my training loop to save progress. Lightning uses torch.save() to save the model as a .ckpt/.pth file whenever a user-specified condition is met. I then tried to load the file:

checkpoint_path = r"DeepLabV3Plus_Training\checkpointsaves\last.pth"

model = SegmentationModel.load_from_checkpoint(checkpoint_path, strict=True)

and I got an error saying that many keys in the state dict do not match. I tried it with strict=False and, of course, the model behaved as if no weights had been loaded, producing only the ImageNet-pretrained results.

Is there a way to make saving and loading the model with torch.save and torch.load work? I have no problem using SMP's own methods myself, I just want to integrate them into the PyTorch Lightning workflow. I made a custom callback that saves the model every x epochs:

    class SaveModelEveryNepochs(pl.Callback):
        def __init__(self, save_dir: str, save_interval: int = 10):
            super().__init__()
            self.save_dir = save_dir
            self.save_interval = save_interval

        def on_train_epoch_end(self, trainer, pl_module):
            """
            Called at the end of each training epoch.

            Args:
                trainer: The trainer instance.
                pl_module: The Lightning module instance.
            """
            current_epoch = trainer.current_epoch
            if current_epoch % self.save_interval == 0:
                # Save the model using SMP's method
                model = pl_module.model  # Assuming your model is stored in pl_module
                model.save_pretrained(f"{self.save_dir}/model_epoch_{current_epoch}")

However, I'd still like the model to be saved the same way when EarlyStopping triggers. Additionally, the default Lightning checkpoint callback offers options such as saving the top k models according to user criteria, saving the last k models, saving every x epochs, etc.

Is there a way to integrate the existing save_pretrained() and its loading counterpart into the existing EarlyStopping and ModelCheckpoint objects? This would be immensely useful for long experiments like mine; a rough sketch of the kind of integration I mean is below.
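To illustrate, the following is only a sketch of what I have in mind. It assumes the SMP model is stored as pl_module.model (as in my callback above) and uses Lightning's on_save_checkpoint hook, so save_pretrained() runs whenever the trainer writes a .ckpt (for example via ModelCheckpoint); the class name and folder layout are placeholders:

    import pytorch_lightning as pl

    class SavePretrainedAlongsideCkpt(pl.Callback):
        """Mirror every Lightning checkpoint with an SMP save_pretrained() copy."""

        def __init__(self, save_dir: str):
            super().__init__()
            self.save_dir = save_dir

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Fires whenever the trainer persists a .ckpt (e.g. from ModelCheckpoint).
            # Assumes the SMP model is stored as pl_module.model.
            pl_module.model.save_pretrained(
                f"{self.save_dir}/epoch_{trainer.current_epoch}"
            )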

@qubvel
Collaborator

qubvel commented Nov 29, 2024

Hey @omarequalmars! Indeed, it might be useful to be able to load a PyTorch Lightning checkpoint. However, the way you structure your Lightning module can differ, so it might be helpful to be able to provide a "prefix" to remove from (or add to) the state dict keys when loading from the Lightning checkpoint. For example, prefix_to_remove="model.". What do you think?
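Purely hypothetical usage sketch for now (neither Lightning-checkpoint support nor this keyword exists in from_pretrained yet, and the path is a placeholder):

    import segmentation_models_pytorch as smp

    # Proposed: load weights straight from a Lightning .ckpt, stripping the
    # "model." prefix Lightning adds when the SMP model is an attribute of the
    # LightningModule (e.g. self.model = smp.Unet(...)).
    model = smp.from_pretrained(
        "path/to/lightning/checkpoint.ckpt",  # placeholder path
        prefix_to_remove="model.",            # proposed keyword, name not final
    )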

@omarequalmars
Author

omarequalmars commented Dec 4, 2024

I'm not sure; I'm not entirely familiar with how the existing SMP method saves and loads a model in terms of the state_dict key layout. If possible, I'd like some direction on where exactly to find that. I see that it saves a config.json containing all the architecture hyperparameters, which it loads separately, and then uses the .safetensors file to load the actual weights. The native PyTorch approach, on the other hand, consists of the user declaring the model class a priori and then using torch.load(weights_only=True) and model.load_state_dict().

Based on your comment and on running a small trial script:

import torch
import segmentation_models_pytorch as smp

# Keys in the Lightning checkpoint
checkpoint = torch.load(r'DeepLabV3Plus_Training\runs\experiment_diceloss_bigconvs\mobileone_s0\version_0\checkpoints\epoch=88-step=1157.ckpt')
print(checkpoint['state_dict'].keys())

# Keys in a model saved with SMP's save_pretrained()
model = smp.from_pretrained(r'DeepLabV3Plus_Training\mysaves\mobileone_s0\model_epoch_0')
sdict = model.state_dict()

print("-----------------------")
print(sdict.keys())

I get the following output (shortened):

odict_keys(['model.encoder.stage0.rbr_conv.0.conv.weight', 'model.encoder.stage0.rbr_conv.0.bn.weight'
....
Loading weights from local directory
-----------------------
odict_keys(['encoder.stage0.rbr_conv.0.conv.weight', 'encoder.stage0.rbr_conv.0.bn.weight

So, as I understand it, all it takes is to remove the 'model.' prefix when loading a Lightning .pth/.ckpt checkpoint into an SMP model, or to add it when loading SMP's safetensors checkpoint into the Lightning module, no? Then from_pretrained() should just proceed with the same state_dict, only with the keys modified accordingly. It seems this can be implemented efficiently with a simple string operation on each key. This is as far as my understanding goes; am I correct? What do you think? @qubvel

@qubvel
Collaborator

qubvel commented Dec 4, 2024

Yeah, it seems like you just have to remove the prefix from the state dict.

It's not necessary to use the from_pretrained method; you can proceed with model.load_state_dict in your particular case. Load the checkpoint with torch.load, remove the prefix, and then pass it to the load_state_dict method. It's not one line of code, but it's not a huge overhead actually.

checkpoint = torch.load(r'DeepLabV3Plus_Training\runs\experiment_diceloss_bigconvs\mobileone_s0\version_0\checkpoints\epoch=88-step=1157.ckpt')
state_dict = checkpoint["state_dict"]
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)

It could be added to the from_pretrained method as well, either by overriding load_state_dict or by using a load pre-hook.

Let me know what you think; supporting Lightning checkpoint loading in from_pretrained might be a nice feature IMO.
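Roughly, the override variant could look like the sketch below; it is illustrative only, and none of these names exist in SMP:

    import torch.nn as nn

    class LightningPrefixMixin(nn.Module):
        """Strip a leading 'model.' from checkpoint keys before loading."""

        def load_state_dict(self, state_dict, *args, **kwargs):
            state_dict = {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }
            return super().load_state_dict(state_dict, *args, **kwargs)

    # Usage sketch: mix it into a concrete SMP architecture, e.g.
    # class CheckpointFriendlyUnet(LightningPrefixMixin, smp.Unet): pass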

@omarequalmars
Author

Sounds good; it would be very convenient for people (like me) who run a large number of experiments with PyTorch Lightning and want to automate extracting information from checkpoints. I think implementing this in from_pretrained() would be a great addition. However, to do that we would have to require the user to save all the parameters the SMP model instance takes (depth, stride, etc.) in the .yaml file, so that they can be converted to a dictionary and passed internally by from_pretrained() to create the model instance with the correct hyperparameters.

Alternatively, we can make use of the ['hyper_parameters'] key that is available in the checkpoint returned by torch.load():

model_checkpoint_path = r"noKD\runs\tu-mobilevit_xxs_DLV3+\version_0\checkpoints\epoch=159-step=2080.ckpt"

checkpoint = torch.load(model_checkpoint_path)
print(checkpoint['hyper_parameters'])

Output:
{'encoder': 'tu-mobilevit_xxs', 'depth': 5, 'encoder_output_stride': 16, 'channels': 512, 'decoder_atrous_rates': (16, 32, 128), 'lr_init': 0.001, 'weight_decay': 0, 'temperature': 2, 'Arch': 'DLV3+'}

The issue here is that users may name the hyperparameter dictionary keys arbitrarily, or track additional hyperparameters that are not relevant to the model instance itself but matter for other things (as in my case, running knowledge distillation experiments).

It would warrant a tutorial on Read the Docs, or at least example code in the repo, so that its usage is intuitive.

I tested your code a bit and ran into an 'edge' case that I did not expect:

import torch
import segmentation_models_pytorch as smp

model_checkpoint_path = r"noKD\runs\tu-mobilevit_xxs_DLV3+\version_0\checkpoints\epoch=159-step=2080.ckpt"

checkpoint = torch.load(model_checkpoint_path)
encoder = 'tu-mobilevit_xxs'
hparams = {
    'encoder': encoder,
    'depth': 5,
    'encoder_output_stride': 16,
    'channels': 512,
    'decoder_atrous_rates': (16, 32, 128),
    'lr_init': 1e-3,
    'weight_decay': 0,
    'temperature': 2,
    'Arch': "DLV3+"
}
model = smp.DeepLabV3Plus(
    encoder_name=hparams['encoder'],
    encoder_depth=hparams['depth'],
    encoder_output_stride=hparams['encoder_output_stride'],
    decoder_channels=hparams['channels'],
    decoder_atrous_rates=hparams['decoder_atrous_rates'],
    encoder_weights='imagenet',
    in_channels=2,
    classes=3,
    activation='identity',
    aux_params=None,
)

sdict = checkpoint['state_dict']
state_dict = {k.replace("model.", ""): v for k, v in sdict.items()}
print(sdict.keys())
print(state_dict.keys())

Output (shortened):

odict_keys(['model.encoder.model.stem.conv.weight', 'model.encoder.model.stem.bn.weight', ...
dict_keys(['encoder.stem.conv.weight', 'encoder.stem.bn.weight', ...
Apparently, some state dicts contain the substring 'model.' more than once within a key (here the timm-based encoder stores its own modules under 'model.' as well), so replacing every occurrence strips too much and some weights end up missing. I ran into this while trying to load one of my checkpoints in a test script:

    checkpoint = torch.load(model_checkpoint_path)
    sdict = checkpoint['state_dict']
    state_dict = {k.replace("model.", ""): v for k, v in sdict.items()}

    model.load_state_dict(state_dict, strict = True)
RuntimeError: Error(s) in loading state_dict for DeepLabV3Plus:
        Missing key(s) in state_dict: "encoder.model.stem.conv.weight", "encoder.model.stem.bn.weight", "encoder.model.stem.bn.bias", "encoder.model.stem.bn.running_mean", "encoder.model.stem.bn.running_var", "encoder.model.stages_0.0.conv1_1x1.conv.weight", "encoder.model.stages_0.0.conv1_1x1.bn.weight", "encoder.model.stages_0.0.conv1_1x1.bn.bias", "encoder.model.stages_0.0.conv1_1x1.bn.running_mean", "encoder.model.stages_0.0.conv1_1x1.bn.running_var", "encoder.model.stages_0.0.conv2_kxk.conv.weight", "encoder.model.stages_0.0.conv2_kxk.bn.weight", "encoder.model.stages_0.0.conv2_kxk.bn.bias", "encoder.model.stage

I think the solution to this is to remove only the leading "model." from each key, for example as in the snippet below.
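A minimal sketch of that fix, reusing the variables from my test code above (str.removeprefix requires Python 3.9+):

    checkpoint = torch.load(model_checkpoint_path)
    sdict = checkpoint['state_dict']
    # Strip "model." only when it is the leading prefix, so a key like
    # 'model.encoder.model.stem.conv.weight' becomes
    # 'encoder.model.stem.conv.weight' rather than 'encoder.stem.conv.weight'.
    state_dict = {k.removeprefix("model."): v for k, v in sdict.items()}
    model.load_state_dict(state_dict, strict=True)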
