šŸ¶
Tensorflow

Saving and Restoring ML Models After Training

By Ondřej DolanskĆ½ on 12/04/2024

Learn how to save and restore your trained machine learning models for later use, ensuring reproducibility and efficient development.



Introduction

In the realm of machine learning, the ability to save and restore models is paramount. This process facilitates the resumption of training, deployment of models into production environments, and seamless sharing of research and development efforts. This article provides a concise guide on how to save and restore models using two popular deep learning libraries: PyTorch and TensorFlow.

Step-by-Step Guide

Saving and restoring a model in machine learning is crucial for resuming training, deploying models, and sharing work. Here's a breakdown:

1. Saving a Model

  • PyTorch

    torch.save(model.state_dict(), 'model_weights.pth')

    This saves the model's learned parameters (weights and biases). To save the optimizer state as well:

    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pth')
  • TensorFlow

    saver = tf.train.Saver()
    saver.save(sess, 'my_model') 

    This uses the TensorFlow 1.x session-based API; it saves the model graph and variable values to checkpoint files. In TensorFlow 2.x, the Keras model.save method or tf.train.Checkpoint replaces this workflow (a TF 2.x sketch follows step 2 below).

2. Restoring a Model

  • PyTorch

    model = MyModel() 
    model.load_state_dict(torch.load('model_weights.pth'))
    model.eval()  # Set to evaluation mode if not training further

    To restore from a checkpoint with optimizer state:

    checkpoint = torch.load('checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  • TensorFlow

    saver = tf.train.Saver()
    saver.restore(sess, 'my_model')
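
The tf.train.Saver API above belongs to TensorFlow 1.x, where a Session holds the variable values. For TensorFlow 2.x, a minimal sketch of the same save-and-restore workflow is shown below; the model architecture and file names are placeholders for illustration, not part of the original examples.

import tensorflow as tf

# Placeholder Keras model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')

# Save the entire model (architecture, weights, and optimizer state);
# the required file extension depends on your Keras version
model.save('my_model.keras')

# Restore it later without redefining the architecture
restored = tf.keras.models.load_model('my_model.keras')

# Alternative: object-based checkpointing of weights and optimizer state only
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
checkpoint.save('tf_ckpt/ckpt')                           # writes tf_ckpt/ckpt-1.*
checkpoint.restore(tf.train.latest_checkpoint('tf_ckpt'))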

Key Points

  • Saving Frequency: Save models periodically during training to avoid losing progress (see the checkpointing sketch after this list).
  • Optimizer State: Save the optimizer state if you want to resume training from where you left off.
  • Evaluation Mode: Set the model to evaluation mode (model.eval() in PyTorch) after loading if you're not continuing training.
  • File Formats: PyTorch commonly uses .pth or .pt, while TensorFlow uses checkpoint files, the SavedModel directory format, or .h5/.keras files, depending on the saving method.
  • Environment Consistency: Ensure the environment where you restore the model has the same dependencies (framework versions in particular) as the one where it was saved.
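
As a concrete illustration of the first two points, here is a minimal sketch of periodic checkpointing inside a PyTorch training loop; the model, synthetic data, and epoch count are assumptions made for this example.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                      # placeholder model
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.MSELoss()

for epoch in range(10):                       # assumed number of epochs
    inputs = torch.randn(32, 10)              # dummy batch
    targets = torch.randn(32, 1)

    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Save model and optimizer state after every epoch so training can resume
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, f'checkpoint_epoch_{epoch}.pth')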

Code Example

The following complete examples demonstrate saving and restoring models in PyTorch and TensorFlow, incorporating the key points above. The PyTorch example saves the model weights on their own and as a checkpoint that also contains the optimizer state, then shows how to load each for further training or inference. The TensorFlow example defines a model with Keras and saves and reloads the entire model in one step.

PyTorch Example

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# Initialize model, optimizer, and sample data
model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters())
data = torch.randn(10)

# ... Training loop ...

# Saving just the model weights (state_dict)
torch.save(model.state_dict(), 'model_weights.pth')

# Saving a checkpoint with optimizer state
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')

# ... Later ...

# Restoring just the model weights
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('model_weights.pth'))
loaded_model.eval()  # Set to evaluation mode

# Restoring from a checkpoint
loaded_model = SimpleModel()
loaded_optimizer = torch.optim.Adam(loaded_model.parameters())
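# If the checkpoint was saved on a GPU and is being loaded on a CPU-only machine,
# pass map_location='cpu' to torch.load below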
checkpoint = torch.load('checkpoint.pth')
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# ... Continue training or use the loaded model for inference ...

TensorFlow Example

import tensorflow as tf

# Define a simple model
model = tf.keras.models.Sequential([
  tf.keras.layers.Dense(10, activation='relu', input_shape=(10,)),
  tf.keras.layers.Dense(1)
])

# Compile the model
model.compile(optimizer='adam', loss='mse')

# ... Training loop ...

# Saving the model
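# (On TF 2.x with older Keras versions this writes the SavedModel directory format;
#  newer Keras versions require an explicit extension such as 'my_model.keras' or 'my_model.h5')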
model.save('my_model')

# ... Later ...

# Restoring the model
loaded_model = tf.keras.models.load_model('my_model')

# ... Use the loaded model for inference or further training ...

Explanation:

  • PyTorch: We define a simple model and optimizer, then demonstrate saving both the model weights on their own and a checkpoint containing the optimizer state, and show how to load each.
  • TensorFlow: We define and compile a simple model using Keras and demonstrate saving and loading the entire model using model.save and tf.keras.models.load_model.

Remember to adapt these examples to your specific models, optimizers, and file paths.

Additional Notes

  • Model Complexity and Storage: Saving the entire model (like in TensorFlow's model.save) is convenient but can result in larger file sizes, especially for complex architectures. Saving only the state dictionary (PyTorch) offers more control over what's stored and can be more efficient.

  • Version Control: Treat your saved models like code! Use version control systems (like Git) to track changes to your models, especially if you're experimenting with different architectures or hyperparameters. This ensures reproducibility and allows you to revert to previous versions if needed.

  • Model Checkpointing: Instead of saving only at the end of training, implement model checkpointing to save the model at regular intervals (e.g., after every epoch or a certain number of steps). This is crucial for long training runs, as it prevents data loss due to unexpected interruptions (see the callback sketch after these notes).

  • Early Stopping and Best Model Saving: Integrate early stopping techniques to prevent overfitting. When early stopping is triggered, save the model with the best performance on your validation set. This ensures you're saving the model with the highest generalization capability (combined with checkpointing in the sketch after these notes).

  • Cloud Storage: For collaborative projects or if you need easy access to your models from different machines, consider storing your models on cloud storage services like AWS S3, Google Cloud Storage, or Azure Blob Storage.

  • Model Deployment: When deploying models, consider the format required by your deployment environment. You might need to convert your saved model to a specific format (e.g., ONNX, TensorFlow Lite) for optimized inference on different platforms; a minimal ONNX export sketch follows these notes.

  • Security Considerations: Be mindful of security risks when sharing or downloading models. Ensure the source is trustworthy and scan downloaded models for potential vulnerabilities.

  • Documentation: Always document your saving and loading procedures clearly. Specify the framework used, file formats, dependencies, and any specific steps required to load and use the model correctly.
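
As a sketch of the checkpointing and best-model notes above, the Keras callbacks below save the best-performing model and stop training once the validation loss stops improving. The model, synthetic data, and file name are assumptions for illustration, not part of the article's main example.

import numpy as np
import tensorflow as tf

# Placeholder model and synthetic data
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')

x = np.random.randn(256, 10).astype('float32')
y = np.random.randn(256, 1).astype('float32')

callbacks = [
    # Save the best model seen so far, judged by validation loss
    tf.keras.callbacks.ModelCheckpoint(
        'best_model.keras', monitor='val_loss', save_best_only=True),
    # Stop early if validation loss has not improved for 5 epochs,
    # and roll back to the best weights
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=5, restore_best_weights=True),
]

model.fit(x, y, validation_split=0.2, epochs=50, callbacks=callbacks)

For the deployment note, a minimal sketch of exporting a PyTorch model to ONNX might look like the following; the model and file name are placeholders, and TensorFlow models are typically converted with separate tools such as the TensorFlow Lite converter.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)          # placeholder model
model.eval()

# An example input fixes the expected input shape for the exported graph
dummy_input = torch.randn(1, 10)
torch.onnx.export(model, dummy_input, 'model.onnx')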

Summary

| Feature | PyTorch | TensorFlow |
| --- | --- | --- |
| Saving model parameters | `torch.save(model.state_dict(), 'model_weights.pth')` | `saver = tf.train.Saver(); saver.save(sess, 'my_model')` (TF 1.x) or `model.save('my_model')` (Keras) |
| Saving optimizer state | `torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pth')` | Saved as part of the model |
| Restoring model parameters | `model = MyModel(); model.load_state_dict(torch.load('model_weights.pth'))` | `saver.restore(sess, 'my_model')` (TF 1.x) or `tf.keras.models.load_model('my_model')` (Keras) |
| Restoring optimizer state | `checkpoint = torch.load('checkpoint.pth'); optimizer.load_state_dict(checkpoint['optimizer_state_dict'])` | Loaded as part of the model |
| File formats | `.pth`, `.pt` | Varies (checkpoint files, SavedModel directory, `.h5`, `.keras`) |
| Notes | Set to evaluation mode after loading if not training further: `model.eval()` | Saves the entire model graph and parameters |

General Best Practices

  • Save frequently: Regularly save your model during training to avoid losing progress.
  • Save optimizer state: This allows you to resume training from where you left off.
  • Environment consistency: Ensure the environment where you restore the model has the same dependencies as the one where it was saved (see the sketch below).
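
To make the environment-consistency point concrete, one option (a sketch, not something prescribed by the frameworks) is to record version information and a dependency snapshot alongside the checkpoint:

import subprocess
import sys
import torch

model = torch.nn.Linear(10, 1)  # placeholder model

# Store framework and Python versions next to the weights so the loading
# environment can be matched later
torch.save({
    'model_state_dict': model.state_dict(),
    'torch_version': torch.__version__,
    'python_version': sys.version,
}, 'checkpoint_with_env.pth')

# Optionally snapshot the full dependency list of the current environment
deps = subprocess.run([sys.executable, '-m', 'pip', 'freeze'],
                      capture_output=True, text=True).stdout
with open('requirements_snapshot.txt', 'w') as f:
    f.write(deps)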

Conclusion

By mastering the techniques of saving and restoring models, machine learning practitioners can streamline their workflows, enhance collaboration, and ensure the longevity and reproducibility of their work. Whether you're using PyTorch, TensorFlow, or another framework, understanding these principles is essential for anyone serious about building, deploying, and sharing machine learning models effectively.  The provided code examples offer a practical starting point for implementing these techniques in your own projects. Remember to adapt them to your specific models and tasks, and always prioritize best practices for saving frequency, environment consistency, and documentation. By doing so, you can ensure that your hard work is preserved and that your models are always ready for action. 

