Learn how to load a pre-trained Keras model and resume training to further improve its performance on your specific task.
In the realm of machine learning, the ability to save and load trained models is paramount. This allows us to preserve the knowledge gained during training and reuse models without retraining from scratch. This article provides a concise guide on how to save, load, and resume training Keras models using the HDF5 format, a popular choice for storing large binary objects. We'll cover the essential code snippets and highlight crucial considerations for a smooth and effective workflow.
Save your model:
model.save('my_model.h5')
This saves the model architecture, weights, and optimizer state.
Load your model:
from keras.models import load_model
model = load_model('my_model.h5')
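Compile the model (if needed): If the model was compiled when you saved it, `load_model` returns it already compiled, with the saved optimizer state restored. Compile again only if the model was saved without its optimizer or you want to change the optimizer, loss, or metrics. Note that a newly created optimizer starts without the restored state: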
model.compile(loss='...', optimizer='...', metrics=['...'])
Continue training:
model.fit(new_data, new_labels, epochs=..., initial_epoch=previous_epochs)
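`initial_epoch` should be set to the number of epochs already completed, i.e. the epoch you left off on. `epochs` is the index of the final epoch rather than the number of additional epochs, so `initial_epoch=5, epochs=10` trains five more epochs.
Important Considerations:
- Optimizer state: Ensure it is saved and loaded correctly.
- Learning rate: Adjust it if needed, especially when training on new data.
- Overfitting: Be cautious when adding significantly different data.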
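The epoch numbering is easy to get wrong, so here is a small sketch that makes it visible. It assumes TensorFlow 2.x with its bundled Keras and uses a tiny synthetic dataset in place of real data. The `History` object returned by `fit()` records the zero-based index of every epoch it ran in its `epoch` attribute.

```python
import numpy as np
from tensorflow import keras

# Tiny synthetic dataset (a stand-in for real training data)
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.integers(0, 3, size=(32,))

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")

# First session: runs epoch indices 0 and 1 (two epochs)
first = model.fit(x, y, epochs=2, verbose=0)

# Second session: initial_epoch=2 says two epochs are already done;
# epochs=5 is the final epoch to reach, so this runs 3 more epochs
second = model.fit(x, y, epochs=5, initial_epoch=2, verbose=0)

print(first.epoch)   # [0, 1]
print(second.epoch)  # [2, 3, 4]
```

Passing `epochs=3` in the second call would instead run only one more epoch (index 2), which is a common source of "training stopped early" surprises.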
The following Python code demonstrates how to save a trained Keras model to a file and then load it back to continue training or perform evaluation. It defines a simple convolutional neural network for image classification on the MNIST dataset, trains it for five epochs, saves it, and then loads it again to resume training from where it left off. Finally, it evaluates the model's performance on the test set. This approach is useful for splitting training across multiple sessions, experimenting with different hyperparameters, or sharing trained models.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# --- 1. Define and Train Initial Model ---
# Create a simple model (replace with your own)
model = keras.Sequential(
    [
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(10, activation="softmax"),
    ]
)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# Load sample data (MNIST)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
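# Scale pixels to [0, 1] and add a channel axis, (N, 28, 28) -> (N, 28, 28, 1),
# so the data matches the model's input shape
x_train = (x_train / 255.0)[..., tf.newaxis]
x_test = (x_test / 255.0)[..., tf.newaxis]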
# Train for a few epochs
model.fit(x_train, y_train, epochs=5)
# --- 2. Save the Model ---
model.save('my_model.h5')
# --- 3. Load the Model and Resume Training (Simulating a New Session) ---
# Load the saved model
loaded_model = keras.models.load_model('my_model.h5')
# (Optional) Recompile only to change the optimizer, loss, or metrics;
# a newly created optimizer does not keep the restored optimizer state
# loaded_model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# Continue training from where you left off
# - Assume you trained for 5 epochs previously
# - Adjust learning rate if necessary
# - Be mindful of overfitting on new data
# epochs=10 is the final epoch, so this runs epochs 6 through 10 (5 more)
loaded_model.fit(x_train, y_train, epochs=10, initial_epoch=5)
# --- 4. Evaluate the Model ---
loss, accuracy = loaded_model.evaluate(x_test, y_test)
print(f"Loss: {loss}, Accuracy: {accuracy}")
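The comments in the script suggest adjusting the learning rate when resuming. One way to do that is sketched below, assuming TensorFlow 2.x, a tiny synthetic dataset, and a hypothetical temporary file path: load the model with `compile=False` and compile it with a new optimizer at a lower learning rate. The trade-off is that the fresh optimizer starts without the saved state (for Adam, its moment estimates).

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Tiny synthetic dataset (a stand-in for real training data)
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.integers(0, 3, size=(64,))

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model.fit(x, y, epochs=2, verbose=0)

# Hypothetical location for the saved file
path = os.path.join(tempfile.mkdtemp(), "my_model.h5")
model.save(path)

# Load architecture and weights only, skipping the saved training configuration
resumed = keras.models.load_model(path, compile=False)

# Compile with a fresh Adam optimizer at a lower learning rate;
# its internal state starts empty rather than continuing from session one
resumed.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
)
history = resumed.fit(x, y, epochs=4, initial_epoch=2, verbose=0)

print(history.epoch)  # [2, 3]
```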
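Explanation:
Model Definition and Initial Training: The script builds a small convolutional network (one `Conv2D` layer, max pooling, and a softmax output over the 10 digit classes), compiles it with the Adam optimizer and sparse categorical cross-entropy loss, scales the MNIST pixel values to [0, 1], and trains for 5 epochs.
Saving the Model: The `model.save('my_model.h5')` line saves the entire model (architecture, weights, optimizer state) to the 'my_model.h5' file.
Loading and Resuming Training: `keras.models.load_model()` restores the model, already compiled with its saved training configuration. The `initial_epoch=5` argument in `model.fit()` ensures training continues from the 6th epoch. Recompile with `model.compile()` only if you need to change the optimizer, loss, or metrics.
Evaluation: `evaluate()` reports the loss and accuracy of the resumed model on the held-out MNIST test set.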
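To confirm that a save/load round trip really restores the model, compare the original and the reloaded copy. This is a minimal sketch assuming TensorFlow 2.x, a tiny synthetic dataset, and a hypothetical temporary file path.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Tiny synthetic dataset (a stand-in for real training data)
rng = np.random.default_rng(0)
x = rng.random((16, 4)).astype("float32")
y = rng.integers(0, 3, size=(16,))

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model.fit(x, y, epochs=1, verbose=0)

# Hypothetical location for the saved file
path = os.path.join(tempfile.mkdtemp(), "my_model.h5")
model.save(path)
restored = keras.models.load_model(path)

# Every weight array should come back bit-for-bit identical
original_weights = model.get_weights()
restored_weights = restored.get_weights()
weights_match = len(original_weights) == len(restored_weights) and all(
    np.array_equal(a, b) for a, b in zip(original_weights, restored_weights)
)

# ...so predictions on the same inputs should agree as well
predictions_match = np.allclose(
    model.predict(x, verbose=0), restored.predict(x, verbose=0)
)
print(weights_match, predictions_match)  # True True
```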
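Key Points:
- Storage: `model.save()` does not accept a `compression` argument for HDF5 files, so if storage is a concern, compress the saved file with an external tool.
- Custom objects: If your model uses custom layers, losses, or metrics, pass them to `load_model` using the `custom_objects` argument.
- Checkpoints: Save the model automatically during long training runs (e.g., with the `keras.callbacks.ModelCheckpoint` callback).

This table summarizes the key steps and considerations for saving, loading, and resuming training of Keras models: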
| Task | Code | Explanation | Considerations |
|---|---|---|---|
| Save Model | `model.save('my_model.h5')` | Saves the model architecture, weights, and optimizer state to a file named 'my_model.h5'. | |
| Load Model | `from keras.models import load_model; model = load_model('my_model.h5')` | Loads the saved model from the 'my_model.h5' file. | |
| Compile Model (if needed) | `model.compile(loss='...', optimizer='...', metrics=['...'])` | Compiles the model with the specified loss function, optimizer, and metrics. This is necessary if the optimizer state wasn't saved or if you're changing optimizers. | |
| Resume Training | `model.fit(new_data, new_labels, epochs=..., initial_epoch=previous_epochs)` | Continues training the model from where it left off. | Optimizer state: ensure it is saved and loaded correctly. Learning rate: adjust if needed, especially with new data. Overfitting: be cautious when adding significantly different data. |
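For long runs, `keras.callbacks.ModelCheckpoint` can write a full-model checkpoint after every epoch, and a later session can resume from the newest one. The sketch below assumes TensorFlow 2.x (or a recent Keras 3 release that accepts `.h5` checkpoint paths), tiny synthetic data, and a hypothetical `model_epoch_XX.h5` naming scheme; `build_model` and `latest_checkpoint` are illustrative helpers, not Keras APIs. `ModelCheckpoint` fills `{epoch}` with the 1-based epoch number, which equals the number of completed epochs to pass as `initial_epoch`.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Tiny synthetic dataset (a stand-in for real training data)
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.integers(0, 3, size=(64,))

ckpt_dir = tempfile.mkdtemp()
# Hypothetical naming scheme; {epoch} is filled with the 1-based epoch number
ckpt_pattern = os.path.join(ckpt_dir, "model_epoch_{epoch:02d}.h5")


def build_model():
    """Illustrative helper: build and compile a fresh model."""
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
    return model


def latest_checkpoint(directory):
    """Illustrative helper: return (path, epochs_completed) for the newest checkpoint."""
    # Zero-padded epoch numbers sort correctly as strings (up to 99 epochs)
    names = sorted(
        n for n in os.listdir(directory)
        if n.startswith("model_epoch_") and n.endswith(".h5")
    )
    if not names:
        return None, 0
    newest = names[-1]
    epochs_done = int(newest[len("model_epoch_"):-len(".h5")])
    return os.path.join(directory, newest), epochs_done


# Saves the full model (not just weights) at the end of every epoch
checkpoint = keras.callbacks.ModelCheckpoint(ckpt_pattern, save_weights_only=False)

# Session 1: train for 3 epochs
build_model().fit(x, y, epochs=3, callbacks=[checkpoint], verbose=0)

# Session 2 (e.g. after a restart): resume from the newest checkpoint, or start fresh
path, epochs_done = latest_checkpoint(ckpt_dir)
model = keras.models.load_model(path) if path else build_model()
history = model.fit(
    x, y, epochs=5, initial_epoch=epochs_done, callbacks=[checkpoint], verbose=0
)
print(epochs_done, history.epoch)  # 3 [3, 4]
```

Because the checkpoint holds the compiled model, the resumed run continues with the saved optimizer state, matching the "Optimizer state" consideration in the table.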
Key takeaway: This process allows you to save your trained model and resume training later without starting from scratch, saving time and resources.
Saving, loading, and resuming training of Keras models is crucial for efficient machine learning workflows. By leveraging the HDF5 format, we can store a model's complete state, including its architecture, weights, and optimizer configuration. This enables us to reuse trained models, avoiding redundant computations and facilitating collaborative development. When resuming training, it's essential to consider the optimizer state, adjust the learning rate appropriately, and be mindful of potential overfitting, especially when introducing new data. This approach streamlines the model development process, allowing us to build upon previous successes and achieve better performance.