[Learn about machine learning from the Keras] 19. Use callbacks.ModelCheckpoint to find the best model

Czxdas
3 min read · Sep 22, 2023

First, let’s look at how the metrics reported by model.fit change depending on how the model is compiled.

1. Example 1: default metrics

from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential

# Load MNIST and flatten each 28x28 image into a 784-length vector scaled to [0, 1]
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255

# A simple two-layer fully connected classifier
model = Sequential([
    layers.Dense(512, activation="relu"),
    layers.Dense(10, activation="softmax"),
])

# No metrics specified: only the loss is tracked
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy")

historyTrain = model.fit(train_images, train_labels, epochs=1, batch_size=128)
print(historyTrain.history.keys())

After model.fit completes, check historyTrain.history.keys(): by default it contains only “loss”.

2. Example 2: continuing from Example 1, specify metrics in model.compile

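# Recompile the same model, this time also tracking accuracy
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

historyTrain = model.fit(train_images, train_labels, epochs=1, batch_size=128)
print(historyTrain.history.keys())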

After model.fit completes, historyTrain.history.keys() now contains an additional metric, “accuracy”, because of the setting metrics=["accuracy"].

3. Example 3: continuing from Example 2, add the parameter validation_split=0.2 to model.fit

historyTrain = model.fit(train_images, train_labels, epochs=1, batch_size=128, validation_split=0.2)
print(historyTrain.history.keys())
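With validation_split=0.2, Keras holds out the last 20% of the training arrays (taken before any shuffling) as validation data, and that held-out part is what val_loss and val_accuracy are computed on. The sketch below mimics that split with NumPy so you can see which samples end up where. The helper manual_validation_split is hypothetical, not a Keras function, and the index array is a stand-in for the real MNIST images:

```python
import math

import numpy as np


def manual_validation_split(x, y, validation_split=0.2):
    """Hypothetical helper: mimic how validation_split carves out validation data.

    The last `validation_split` fraction of the samples (taken before any
    shuffling) becomes the validation set; the rest is used for training.
    """
    split_at = int(math.floor(len(x) * (1.0 - validation_split)))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


# Stand-in data: one index per MNIST training sample, so we can see which samples land where.
indices = np.arange(60000)
labels = indices % 10

(train_idx, train_lab), (val_idx, val_lab) = manual_validation_split(indices, labels, 0.2)
print(len(train_idx), len(val_idx))  # 48000 12000
print(val_idx[0], val_idx[-1])       # 48000 59999
```

So out of the 60,000 MNIST training images, 48,000 are used for training and the last 12,000 for validation.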

The three examples above show how these metrics are produced. This section is about using callbacks.ModelCheckpoint to find the best model.

Example 3 produces four metrics in total: loss, accuracy, val_loss and val_accuracy. With the callbacks.ModelCheckpoint class, we can choose which of these metrics defines the “best” result, and save the trained model whenever that metric reaches its best value.

So add callbacks.ModelCheckpoint as follows:

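from tensorflow.keras import callbacks

# Keep the model from the epoch with the highest validation accuracy.
# save_best_only=True is needed for monitor/mode to matter (otherwise every epoch overwrites the file).
Best_ValAcc_Model = callbacks.ModelCheckpoint(filepath="ModelValAacc", monitor="val_accuracy",
                                              mode="max", save_best_only=True)
# Keep the model from the epoch with the lowest validation loss.
Best_ValLoss_Model = callbacks.ModelCheckpoint(filepath="ModelValLoss", monitor="val_loss",
                                               mode="min", save_best_only=True)

historyTrain = model.fit(train_images, train_labels, epochs=50, batch_size=128, validation_split=0.2,
                         callbacks=[Best_ValAcc_Model, Best_ValLoss_Model])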

import matplotlib.pyplot as plt

# Training vs. validation loss per epoch
plt.xlabel('epoch', fontsize=12)
plt.ylabel('loss_value', fontsize=12)
plt.plot(historyTrain.history['val_loss'], color='red', label='val_loss')
plt.plot(historyTrain.history['loss'], color='blue', label='loss')
plt.legend()
plt.show()

# Training vs. validation accuracy per epoch
plt.xlabel('epoch', fontsize=12)
plt.ylabel('accuracy_value', fontsize=12)
plt.plot(historyTrain.history['val_accuracy'], color='red', label='val_accuracy')
plt.plot(historyTrain.history['accuracy'], color='blue', label='accuracy')
plt.legend()
plt.show()

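The first ModelCheckpoint monitors “val_accuracy” (mode="max") and is assigned to Best_ValAcc_Model; the second monitors “val_loss” (mode="min") and is assigned to Best_ValLoss_Model. Both objects are then passed to the callbacks argument of model.fit. After training, two folders, ModelValAacc and ModelValLoss, are created in the program directory.

These two folders hold the model saved at the epoch with the highest val_accuracy and the model saved at the epoch with the lowest val_loss, respectively. Note that monitor and mode only take effect together with save_best_only=True; without it, ModelCheckpoint saves at the end of every epoch, so each folder would simply hold the last epoch. (Saving to a folder assumes TensorFlow 2.x with its built-in Keras 2, where a filepath without an extension is written in SavedModel format; Keras 3 requires the filepath to end in .keras.)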

In the author’s runs, the best val_accuracy and val_loss are generally reached at an epoch between 10 and 20, even though training runs for 50 epochs; after that point the validation metrics stop improving.
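To check which epoch a checkpoint actually came from, you can scan historyTrain.history yourself. The sketch below assumes history is a plain dict of per-epoch lists (which is what historyTrain.history is) and picks the best epoch the way, to my understanding, ModelCheckpoint(save_best_only=True) does: only a strictly better value replaces the current best. The function name best_epoch is hypothetical, and the numbers are made up for illustration:

```python
def best_epoch(history, monitor="val_loss", mode="min"):
    """Return (epoch, value) for the best value of `monitor` in a Keras-style history dict.

    `history` maps metric names to per-epoch lists, like historyTrain.history.
    Epochs are counted from 1, as Keras prints them. Only a strictly better
    value counts as an improvement, so a tie keeps the earlier epoch.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    values = history[monitor]
    if not values:
        raise ValueError(f"no values recorded for {monitor!r}")

    best_index, best_value = 0, values[0]
    for i, value in enumerate(values[1:], start=1):
        improved = value < best_value if mode == "min" else value > best_value
        if improved:
            best_index, best_value = i, value
    return best_index + 1, best_value


# Made-up per-epoch values, for illustration only (not results from this article).
fake_history = {
    "val_loss": [0.30, 0.20, 0.15, 0.15, 0.18],
    "val_accuracy": [0.91, 0.94, 0.95, 0.96, 0.96],
}
print(best_epoch(fake_history, "val_loss", "min"))      # (3, 0.15)
print(best_epoch(fake_history, "val_accuracy", "max"))  # (4, 0.96)
```

After the 50-epoch run, best_epoch(historyTrain.history, "val_accuracy", "max") would point at the epoch stored in ModelValAacc, and best_epoch(historyTrain.history, "val_loss", "min") at the one in ModelValLoss.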

Reloading the model and making predictions

How do we retrieve the best model? Load the saved checkpoint back into the model:

model.load_weights("./ModelValAacc")

This reads the folder saved during training. Once it loads successfully, you can go on to call the model’s predict.
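model.predict returns, for each input image, the 10 softmax probabilities produced by the last Dense layer, so turning them into digit labels is an argmax over the last axis. A minimal NumPy sketch, where probabilities_to_digits and accuracy are hypothetical helper names and the probability rows are made up:

```python
import numpy as np


def probabilities_to_digits(probs):
    """Convert softmax output of shape (n_samples, 10) into predicted digit labels."""
    return np.argmax(probs, axis=1)


def accuracy(predicted_labels, true_labels):
    """Fraction of predictions that match the true labels."""
    return float(np.mean(predicted_labels == true_labels))


# Made-up softmax rows for three images, for illustration only.
probs = np.array([
    [0.01, 0.01, 0.91, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01],  # most likely 2
    [0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.03, 0.03],  # most likely 0
    [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.50, 0.10, 0.05],  # most likely 7
])
digits = probabilities_to_digits(probs)
print(digits)  # [2 0 7]
print(accuracy(digits, np.array([2, 0, 1])))  # 2 of the 3 predictions match
```

With the reloaded model, the same helpers would be applied as probabilities_to_digits(model.predict(test_images)), and passing that result together with test_labels to accuracy gives the accuracy on the 10,000 test images.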
