[Learn about machine learning from the Keras] — 7. Model Predict operation process

Czxdas
Sep 21, 2023


This article follows the general process a model goes through when it makes predictions, and how each step proceeds.
Once the model is trained, you can call predict on the model instance to run it on new input data.

from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential

# Load MNIST and flatten each 28x28 image into a 784-element vector scaled to [0, 1].
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255

# Two-layer classifier: 512 ReLU units, then a 10-way softmax (one output per digit).
model = Sequential([
    layers.Dense(512, activation="relu"),
    layers.Dense(10, activation="softmax"),
])

model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.fit(train_images, train_labels, epochs=5, batch_size=128)

# Predict the first 10 test digits; each row of `predictions` holds 10 class probabilities.
test_digits = test_images[0:10]
predictions = model.predict(test_digits)
print(predictions[0])

(1)
In this example, execution first enters the body of keras.engine.training.Model.predict. That function passes the arguments and the input data to the keras.engine.data_adapter.DataHandler constructor. Training does the same thing: fit also packages its arguments into a data_adapter.DataHandler instance. As a result, prediction runs in batches just as training does. The DataHandler created during fit receives both epochs and batch_size, while the one created in predict receives batch_size and always runs a single epoch.
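The batching idea can be made concrete with a short example. The minimal pure-Python sketch below shows what a DataHandler-like object does for predict. It fixes epochs at 1 and slices the input into batches of batch_size. The class `MiniDataHandler` is hypothetical and much simpler than Keras' real DataHandler. Only its method names mirror the enumerate_epochs and steps calls used in the predict loop.

```python
import math


class MiniDataHandler:
    """Hypothetical, simplified stand-in for keras.engine.data_adapter.DataHandler."""

    def __init__(self, x, batch_size=32, epochs=1):
        self.x = x
        self.batch_size = batch_size
        self.epochs = epochs
        # Number of batches needed to cover every sample once.
        self.inferred_steps = math.ceil(len(x) / batch_size)

    def _batches(self):
        # Yield consecutive slices of at most batch_size samples.
        for start in range(0, len(self.x), self.batch_size):
            yield self.x[start:start + self.batch_size]

    def enumerate_epochs(self):
        # A fresh iterator over the batches for each epoch.
        for epoch in range(self.epochs):
            yield epoch, iter(self._batches())

    def steps(self):
        # Step indices within one epoch.
        return range(self.inferred_steps)


# 10 samples with batch_size=4 -> batches of 4, 4 and 2 samples.
handler = MiniDataHandler(list(range(10)), batch_size=4)
for epoch, iterator in handler.enumerate_epochs():
    for step in handler.steps():
        print(epoch, step, next(iterator))
```

Keras' predict defaults to a batch_size of 32 when none is given. So the ten test digits in the example above would fit into a single batch, which is one step.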

(2)
Callbacks can also be passed in. If the callbacks argument is not already a keras.callbacks.CallbackList instance, the callbacks it contains are wrapped in a new keras.callbacks.CallbackList. This class serves as the container for all callbacks. During the wrapping, default callbacks are added just as they are during training. The resulting CallbackList always contains a keras.callbacks.History instance. It also contains a keras.callbacks.ProgbarLogger instance unless verbose is set to 0.
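Here is a rough sketch of the wrapping step, using hypothetical pure-Python stand-ins rather than the real keras.callbacks classes. `configure_callbacks` is an invented helper name. The hook names (on_predict_begin, on_predict_batch_begin, on_predict_batch_end, on_predict_end) match those on Keras' Callback. The container forwards each hook to every callback it holds. It adds History, and also ProgbarLogger when verbose is not 0, unless the user has already supplied them.

```python
class Callback:
    """Hypothetical base class: every hook does nothing by default."""

    def on_predict_begin(self, logs=None):
        pass

    def on_predict_batch_begin(self, batch, logs=None):
        pass

    def on_predict_batch_end(self, batch, logs=None):
        pass

    def on_predict_end(self, logs=None):
        pass


class History(Callback):
    """Stand-in for keras.callbacks.History."""


class ProgbarLogger(Callback):
    """Stand-in for keras.callbacks.ProgbarLogger."""


class CallbackList:
    """Hypothetical container that forwards each hook to every callback it holds."""

    def __init__(self, callbacks=None, add_history=False, add_progbar=False):
        self.callbacks = list(callbacks or [])
        # Add the default callbacks only if the user did not supply them already.
        if add_history and not any(isinstance(c, History) for c in self.callbacks):
            self.callbacks.append(History())
        if add_progbar and not any(isinstance(c, ProgbarLogger) for c in self.callbacks):
            self.callbacks.append(ProgbarLogger())

    def _dispatch(self, hook, *args):
        for callback in self.callbacks:
            getattr(callback, hook)(*args)

    def on_predict_begin(self, logs=None):
        self._dispatch("on_predict_begin", logs)

    def on_predict_batch_begin(self, batch, logs=None):
        self._dispatch("on_predict_batch_begin", batch, logs)

    def on_predict_batch_end(self, batch, logs=None):
        self._dispatch("on_predict_batch_end", batch, logs)

    def on_predict_end(self, logs=None):
        self._dispatch("on_predict_end", logs)


def configure_callbacks(callbacks, verbose=1):
    # Reuse an existing CallbackList unchanged; wrap anything else.
    if isinstance(callbacks, CallbackList):
        return callbacks
    return CallbackList(callbacks, add_history=True, add_progbar=verbose != 0)


class BatchRecorder(Callback):
    """Example user callback that records which batches finished."""

    def __init__(self):
        self.finished = []

    def on_predict_batch_end(self, batch, logs=None):
        self.finished.append(batch)


recorder = BatchRecorder()
wrapped = configure_callbacks([recorder], verbose=0)
print([type(c).__name__ for c in wrapped.callbacks])  # ['BatchRecorder', 'History']
```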

(3)
Next, keras.engine.training.Model.make_predict_function is called to build the predict_function used at each prediction step. It relies on a closure. The function is created once and stored on the model as self.predict_function, so the same function object is reused for every batch instead of being rebuilt each time. Unless the model runs eagerly, this function is also compiled with tf.function.
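The closure-and-cache pattern can be sketched in plain Python. `TinyModel` and its doubling predict_step are hypothetical. The sketch only shows the structure: a predict function is built once, it captures `self` so it can call predict_step later, and it is cached and returned on later calls. The real Keras version also distributes the step and wraps it in tf.function, which is omitted here.

```python
class TinyModel:
    """Hypothetical model showing a predict function that is built once and cached."""

    def __init__(self):
        self.predict_function = None
        self.build_count = 0  # how many times the closure was created

    def predict_step(self, data):
        # Stand-in for the real forward pass: double every value.
        return [2 * v for v in data]

    def make_predict_function(self):
        # Reuse the cached function if it already exists.
        if self.predict_function is not None:
            return self.predict_function

        def predict_function(iterator):
            # The closure captures `self`, so it can reach predict_step later.
            data = next(iterator)
            return self.predict_step(data)

        self.build_count += 1
        self.predict_function = predict_function
        return predict_function


model = TinyModel()
fn = model.make_predict_function()
batches = iter([[1, 2], [3]])
print(fn(batches), fn(batches))  # [2, 4] [6]
model.make_predict_function()
print(model.build_count)  # 1
```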

(4)
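The execution loop looks roughly like this (simplified from the body of Model.predict):

callbacks.on_predict_begin()
batch_outputs = None
for _, iterator in data_handler.enumerate_epochs():  # epochs is fixed to 1
    for step in data_handler.steps():
        callbacks.on_predict_batch_begin(step)

        tmp_batch_outputs = self.predict_function(iterator)
        batch_outputs = tmp_batch_outputs
        # (each batch's outputs are also appended to the overall outputs)

        end_step = step + data_handler.step_increment
        callbacks.on_predict_batch_end(
            end_step, {"outputs": batch_outputs}
        )
callbacks.on_predict_end()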
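A runnable simulation of this loop helps show the order in which the hooks fire and how the per-batch outputs are gathered. Everything in it is a hypothetical stand-in. `LoggingCallbacks` records events, and `run_predict_loop` plays the role of the loop above with one epoch. The default `step_increment=0` assumes one batch per step, so end_step equals step.

```python
import math


class LoggingCallbacks:
    """Hypothetical callback container that records the order of hook calls."""

    def __init__(self):
        self.events = []

    def on_predict_begin(self):
        self.events.append("begin")

    def on_predict_batch_begin(self, step):
        self.events.append(f"batch_begin {step}")

    def on_predict_batch_end(self, end_step, logs):
        self.events.append(f"batch_end {end_step} rows={len(logs['outputs'])}")

    def on_predict_end(self):
        self.events.append("end")


def run_predict_loop(x, batch_size, predict_step, callbacks, step_increment=0):
    """Simplified predict loop: a single epoch, one batch per step."""
    batches = iter([x[i:i + batch_size] for i in range(0, len(x), batch_size)])
    num_steps = math.ceil(len(x) / batch_size)
    outputs = []
    callbacks.on_predict_begin()
    for step in range(num_steps):
        callbacks.on_predict_batch_begin(step)
        batch_outputs = predict_step(next(batches))
        outputs.extend(batch_outputs)  # accumulate results across batches
        end_step = step + step_increment
        callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
    callbacks.on_predict_end()
    return outputs


log = LoggingCallbacks()
result = run_predict_loop([1, 2, 3, 4, 5], 2, lambda b: [v * 10 for v in b], log)
print(result)  # [10, 20, 30, 40, 50]
print(log.events)
```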

(5)
The predict function called inside the loop eventually reaches keras.engine.training.Model.predict_step:

predict_step first calls keras.engine.data_adapter.unpack_x_y_sample_weight. This function unpacks the incoming batch (a tuple, a list, or a single object) into an (x, y, sample_weight) tuple and returns it. For prediction, only x is used.
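The following pure-Python sketch reflects our understanding of the unpacking behavior in Keras 2.x. Treat it as an approximation, not the library's source. The function names `unpack_x_y_sample_weight_sketch` and `predict_step_sketch` are hypothetical. A list is treated like a tuple, a bare object becomes x alone, and tuples of 1 to 3 elements fill x, y and sample_weight in order.

```python
def unpack_x_y_sample_weight_sketch(data):
    """Hypothetical re-implementation: split a batch into (x, y, sample_weight)."""
    if isinstance(data, list):
        data = tuple(data)  # lists are treated like tuples
    if not isinstance(data, tuple):
        return (data, None, None)  # a bare input: no labels, no weights
    if len(data) == 1:
        return (data[0], None, None)
    if len(data) == 2:
        return (data[0], data[1], None)
    if len(data) == 3:
        return (data[0], data[1], data[2])
    raise ValueError(f"Expected 1 to 3 elements, got {len(data)}")


def predict_step_sketch(model_fn, data):
    # predict_step only needs x; labels and weights are ignored.
    x, _, _ = unpack_x_y_sample_weight_sketch(data)
    return model_fn(x)


print(unpack_x_y_sample_weight_sketch(["pixels", "label"]))  # ('pixels', 'label', None)
print(predict_step_sketch(len, ("pixels",)))  # 6
```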

predict_step then calls the model on x (self(x, training=False)), which goes through the __call__ logic that keras.engine.training.Model inherits from keras.engine.base_layer.Layer. After some argument conversion and checks, the model is built if it has not been built yet. Then keras.engine.sequential.Sequential.call, the model's own call function, runs.
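The "build if not built yet" check can be illustrated with a hypothetical pure-Python layer that creates its weights only on its first call. At that point the input width becomes known. The name `LazyDense`, the seeded uniform initialization, and the zero bias are illustrative assumptions, not Keras' actual initializers. In the article's example the model has already been built during fit, so at predict time this branch would most likely be skipped.

```python
import random


class LazyDense:
    """Hypothetical Dense-like layer that creates its weights on first call."""

    def __init__(self, units):
        self.units = units
        self.built = False
        self.kernel = None
        self.bias = None

    def build(self, input_dim):
        # Weight shapes depend on the input width, known only at the first call.
        rng = random.Random(0)
        self.kernel = [[rng.uniform(-0.05, 0.05) for _ in range(self.units)]
                       for _ in range(input_dim)]
        self.bias = [0.0] * self.units
        self.built = True

    def __call__(self, rows):
        # Build once, then always delegate to call().
        if not self.built:
            self.build(len(rows[0]))
        return self.call(rows)

    def call(self, rows):
        # One output row per input row: row @ kernel + bias.
        return [[sum(r[i] * self.kernel[i][j] for i in range(len(r))) + self.bias[j]
                 for j in range(self.units)] for r in rows]


layer = LazyDense(3)
print(layer.built)  # False
layer([[1.0, 2.0]])
print(layer.built, len(layer.kernel), len(layer.kernel[0]))  # True 2 3
```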

Sequential.call first runs keras.engine.sequential.Sequential._build_graph_network_for_inferred_shape, which records the shape of the input tensor. It then moves on to keras.engine.functional._run_internal_graph, which looks up the layer nodes built for the model and runs each layer's call method in order. In this example, that means the call methods of the two keras.layers.core.dense.Dense layers run one after the other. Each Dense layer multiplies its input by its own kernel (a matrix product), adds its bias, and applies its activation function. The result becomes the input of the next layer, and the output of the last layer is returned as the prediction.
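The per-layer computation can be written out in pure Python. Each Dense layer computes activation(inputs @ kernel + bias), and the layers are chained in order. The weights below are made-up toy values: 2 inputs, 2 ReLU units, and 3 softmax classes. In the article's model, the kernels would instead be 784×512 and 512×10. The helper names are hypothetical.

```python
import math


def matmul(a, b):
    # (n x k) @ (k x m) -> (n x m)
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(rows):
    return [[max(0.0, v) for v in row] for row in rows]


def softmax(rows):
    out = []
    for row in rows:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def dense(inputs, kernel, bias, activation):
    # A Dense layer: activation(inputs @ kernel + bias).
    z = matmul(inputs, kernel)
    z = [[v + b for v, b in zip(row, bias)] for row in z]
    return activation(z)


def sequential_forward(inputs, layer_params):
    # Each layer's output becomes the next layer's input.
    outputs = inputs
    for kernel, bias, activation in layer_params:
        outputs = dense(outputs, kernel, bias, activation)
    return outputs


# Hypothetical toy weights: 2 inputs -> 2 hidden units (ReLU) -> 3 classes (softmax).
layer_params = [
    ([[1.0, -1.0], [0.5, 1.0]], [0.0, 0.0], relu),
    ([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]], [0.0, 0.0, 0.0], softmax),
]
probs = sequential_forward([[2.0, 1.0]], layer_params)
print(probs)  # one row of 3 probabilities that sum to 1
```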

That is the general flow a model follows when running predict.
