Hook to transform the model before loading the weights

Open awaelchli opened this issue 3 years ago • 5 comments

🚀 Feature

Motivation

This is a request from a user on Slack. In their use case, they need to transform the model early in the trainer execution: after the checkpoint file has been loaded, but before the weights get copied into the model (when calling trainer.fit(model, ckpt_path=...)). Currently this is only done as a special case inside the QuantizationAwareTraining callback.

Pitch

Provide a hook that runs after the checkpoint has been loaded from the file, but before its state gets restored into the model; i.e., in the sequence below the hook should run roughly where self._checkpoint_connector._restore_quantization_callbacks() runs.

https://github.com/Lightning-AI/lightning/blob/291267c3bff8054ec438960857c9f2fec1d54899/src/pytorch_lightning/trainer/trainer.py#L1071-L1079

The hook should take as input the checkpoint dict, so that the user can load their metadata.


def on_resume_start(self, lightning_module, trainer, checkpoint):
    """Do something with the model before restoration of the trainer/model state"""


Alternatives

One could make the QuantizationAwareTraining._load_before_model hook public, but that would remain limited to the quantization callback.

Additional context

Slack conversation


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.

cc @borda @tchaton @justusschock @awaelchli @carmocca @ananthsub @ninginthecloud @jjenniferdai @rohitgr7 @akihironitta

awaelchli avatar Aug 30 '22 11:08 awaelchli

Might be useful with SWA too.

rohitgr7 avatar Aug 30 '22 11:08 rohitgr7

@awaelchli is there any timeline for something like this? I have a use case that applies custom transforms to the model before loading the checkpoint state, and exposing something like _load_before_model would be really nice to have.

d4l3k avatar Mar 08 '23 00:03 d4l3k

This isn't planned at the moment unless there is more demand for it. Could you explain your use case, maybe with some pseudocode, and why your transformations can't be applied in the existing hooks like LightningModule.load_state_dict() or LightningModule.on_load_checkpoint()? That would be helpful :)

awaelchli avatar Mar 08 '23 23:03 awaelchli

@awaelchli Similar to the old QuantizationAwareTraining callback, we have some code that conditionally modifies the model using torch.fx graph transforms, so we need to be able to inspect and modify the model and state_dict before the state is loaded.

This transformation callback needs to handle both transformed models (for failure recovery) and untransformed models (for fine-tuning with transformations).

load_state_dict/on_load_checkpoint is too late: if the checkpoint was saved from the transformed model, its state_dict can't be loaded into the pre-transformed model.
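
To make the mismatch concrete, here is a standalone sketch in plain PyTorch, with a trivial module rewrite standing in for the real torch.fx graph transform: once the transform changes the module tree, the checkpoint keys no longer match the untransformed model.

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

def transform(model):
    # Stand-in for the torch.fx graph transform: rewrites the module tree.
    model.linear = nn.Sequential(model.linear, nn.ReLU())
    return model

ckpt = transform(Net()).state_dict()  # keys like "linear.0.weight"
try:
    Net().load_state_dict(ckpt)       # pre-transformed model expects "linear.weight"
except RuntimeError as err:
    print(err)                        # missing/unexpected key errors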

The removal PR for the QAT callback says "You can copy the callback code and maintain it yourself", but that's impossible since the required hook doesn't exist. cc @lightningforever https://github.com/Lightning-AI/lightning/pull/16750

PyTorch itself has some nn.Module hooks that could serve this purpose as well, but they're currently not public: https://github.com/pytorch/pytorch/issues/75287
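
For reference, a sketch of the underscore-prefixed pre-hook that issue is about; since it is private, the signature may change, and the key remapping below is purely hypothetical:

import torch.nn as nn

def remap_keys(module, state_dict, prefix, local_metadata, strict,
               missing_keys, unexpected_keys, error_msgs):
    # Runs before `module.load_state_dict` consumes `state_dict`; mutating the
    # dict in place lets us adapt a checkpoint saved from a transformed model.
    for key in list(state_dict):
        if key.endswith(".fused_weight"):  # hypothetical transformed-model key
            state_dict[key.replace(".fused_weight", ".weight")] = state_dict.pop(key)

model = nn.Linear(4, 4)
model._register_load_state_dict_pre_hook(remap_keys, with_module=True)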

d4l3k avatar Mar 08 '23 23:03 d4l3k

@awaelchli +1, this would be convenient for my use case as well. For my use case, though, the hook would need to be called before configure_model (saving.py#L182).

My model is constructed using parameters estimated from the data. Initially I passed them through the model constructor, but that stopped being viable when I switched to the DDP training strategy: preparing data and estimating parameters in the constructor no longer worked, so the data preparation moved to prepare_data and the model construction moved to configure_model. I am currently saving the estimated data parameters in on_save_checkpoint, but this does not work for loading, because configure_model is called first (saving.py#L182), before on_load_checkpoint. I am looking for a better setup, and a new hook would be an ideal solution; otherwise I would need to override the loading logic.
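
A minimal sketch of this ordering problem; build_model and the "data_params" checkpoint key are hypothetical stand-ins for the real setup:

import lightning.pytorch as pl

class MyModule(pl.LightningModule):
    def configure_model(self):
        # On checkpoint load this runs *before* on_load_checkpoint
        # (saving.py#L182), so the saved data_params are not available yet.
        self.model = build_model(self.data_params)  # build_model is hypothetical

    def on_save_checkpoint(self, checkpoint):
        checkpoint["data_params"] = self.data_params

    def on_load_checkpoint(self, checkpoint):
        # Too late: configure_model has already run without the saved values.
        self.data_params = checkpoint["data_params"]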

rustamzh avatar Nov 24 '25 11:11 rustamzh