[Learn about machine learning from the Keras] — 16.Customized Callback and how to apply it

Czxdas
Sep 22, 2023


In the chapter about model fitting, we saw that Callback objects are placed into a container before training and then invoked inside the training loop. Here we write a custom Callback class to see how that works.

from tensorflow.keras import callbacks as callbacks_module

class MyCallbackWithBatchHooks(callbacks_module.Callback):
    def __init__(self):
        super().__init__()  # let the parent class set up its own attributes
        # Count how many times each batch-end hook has been called.
        self.train_batches = 0
        self.test_batches = 0
        self.predict_batches = 0

    def on_train_batch_end(self, batch, logs=None):
        self.train_batches += 1
        print('MyCallbackWithBatchHooks',' on_train_batch_end')

    def on_test_batch_end(self, batch, logs=None):
        self.test_batches += 1
        print('MyCallbackWithBatchHooks',' on_test_batch_end')

    def on_predict_batch_end(self, batch, logs=None):
        self.predict_batches += 1
        print('MyCallbackWithBatchHooks',' on_predict_batch_end')

class MyCallbackWithBatchHooks2(callbacks_module.Callback):
    def __init__(self):
        super().__init__()
        self.train_batches = 0
        self.test_batches = 0
        self.predict_batches = 0

    def on_train_batch_end(self, batch, logs=None):
        self.train_batches += 1
        print('MyCallbackWithBatchHooks2',' on_train_batch_end')

    def on_test_batch_end(self, batch, logs=None):
        self.test_batches += 1
        print('MyCallbackWithBatchHooks2',' on_test_batch_end')

    def on_predict_batch_end(self, batch, logs=None):
        self.predict_batches += 1
        print('MyCallbackWithBatchHooks2',' on_predict_batch_end')

my_cb = MyCallbackWithBatchHooks()
my_cb2 = MyCallbackWithBatchHooks2()
# Register my_cb2 first, then my_cb; the container keeps this order.
cb_list = callbacks_module.CallbackList([my_cb2,my_cb], verbose=0)
# Calling a hook on the container calls it on every registered callback.
cb_list.on_train_batch_end(0)
cb_list.on_test_batch_end(0)

print(my_cb2.train_batches, my_cb.train_batches)      # 1 1
print(my_cb2.test_batches, my_cb.test_batches)        # 1 1
print(my_cb2.predict_batches , my_cb.predict_batches) # 0 0 (on_predict_batch_end was never called)

The example first creates instances of MyCallbackWithBatchHooks and MyCallbackWithBatchHooks2, both of which inherit from keras.callbacks.Callback, and passes them to the callbacks' container class, keras.callbacks.CallbackList. Creating the container runs keras.callbacks.CallbackList.__init__, which initializes the relevant attributes from the arguments passed in.

Once the container has collected the specified callback objects, each hook it exposes calls the method of the same name on every callback, one by one. If a callback class does not define that method, the call resolves to the parent class keras.callbacks.Callback, whose version usually only defines the interface and has no body that needs to be executed.
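The fallback to the parent class can be illustrated without Keras at all. The following is a minimal sketch in plain Python. BaseCallback and CountingCallback are hypothetical stand-ins for keras.callbacks.Callback and a user subclass, not the library's actual code, and the sketch assumes only that the parent defines each hook as a method that does nothing.

```python
class BaseCallback:
    """Hypothetical stand-in for keras.callbacks.Callback: hooks are no-ops."""

    def on_train_batch_end(self, batch, logs=None):
        pass  # interface only, nothing to execute

    def on_test_batch_end(self, batch, logs=None):
        pass  # interface only, nothing to execute


class CountingCallback(BaseCallback):
    """Overrides only the training hook; the test hook is inherited."""

    def __init__(self):
        self.train_batches = 0

    def on_train_batch_end(self, batch, logs=None):
        self.train_batches += 1


cb = CountingCallback()
cb.on_train_batch_end(0)  # runs the override in CountingCallback
cb.on_test_batch_end(0)   # resolves to BaseCallback's no-op, no error

print(cb.train_batches)  # 1

# The subclass did not define on_test_batch_end, so it is the parent's method.
inherited = CountingCallback.on_test_batch_end is BaseCallback.on_test_batch_end
print(inherited)  # True
```

Because every hook exists on the parent, a container can call any hook on any callback without first checking whether the subclass implemented it.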

In other words, the callback container provides a set of hooks through which callback objects run their own corresponding methods. These hooks are called at different points during training, evaluation, and prediction.

These interfaces are listed below:

on_batch_begin
on_batch_end
on_epoch_begin
on_epoch_end
on_train_batch_begin
on_train_batch_end
on_test_batch_begin
on_test_batch_end
on_predict_batch_begin
on_predict_batch_end
on_train_begin
on_train_end
on_test_begin
on_test_end
on_predict_begin
on_predict_end
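To make the timing points concrete, here is a rough sketch of where the training-related hooks fire. toy_fit and TraceCallback are hypothetical names invented for this illustration. The loop is a simplified assumption about the order of calls during training, with no model, no validation, and no container, so it should not be read as Keras's real fit implementation.

```python
class TraceCallback:
    """Hypothetical callback that records the name of every hook it receives."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("on_train_begin")

    def on_train_end(self, logs=None):
        self.events.append("on_train_end")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append("on_epoch_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append("on_epoch_end")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append("on_train_batch_begin")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append("on_train_batch_end")


def toy_fit(callback, epochs, steps_per_epoch):
    """Simplified training loop that contains only the hook call points."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(steps_per_epoch):
            callback.on_train_batch_begin(batch)
            # A real loop would run one training step here.
            callback.on_train_batch_end(batch)
        callback.on_epoch_end(epoch)
    callback.on_train_end()


trace = TraceCallback()
toy_fit(trace, epochs=1, steps_per_epoch=2)
for name in trace.events:
    print(name)
```

With one epoch of two batches, the trace starts with on_train_begin and on_epoch_begin, then shows the batch begin/end pair twice, and finishes with on_epoch_end and on_train_end.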

Through these hooks, the container runs all registered callback objects as a group.
Looking at the example again, when cb_list.on_train_batch_end is called, the container goes through its registered callbacks, my_cb2 and my_cb, and executes the on_train_batch_end method of each.

Pay special attention to the order in which callbacks are registered. In this example my_cb2 is placed before my_cb, so my_cb2's hook runs first: the execution order is the same as the registration order.
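The dispatch and its ordering can be sketched with a tiny container written from scratch. MiniCallbackList and NamedCallback are hypothetical classes made up for this illustration. They mimic only the "loop over the registered callbacks in order" behavior described above and leave out everything else that keras.callbacks.CallbackList does.

```python
class MiniCallbackList:
    """Hypothetical, simplified container; not Keras's CallbackList."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)  # registration order is preserved

    def on_train_batch_end(self, batch, logs=None):
        # Forward the hook to every registered callback, first to last.
        for cb in self.callbacks:
            cb.on_train_batch_end(batch, logs)


class NamedCallback:
    """Hypothetical callback that appends its name to a shared call log."""

    def __init__(self, name, call_log):
        self.name = name
        self.call_log = call_log

    def on_train_batch_end(self, batch, logs=None):
        self.call_log.append(self.name)


call_log = []
my_cb = NamedCallback("my_cb", call_log)
my_cb2 = NamedCallback("my_cb2", call_log)

# Same registration order as in the article: my_cb2 first, then my_cb.
MiniCallbackList([my_cb2, my_cb]).on_train_batch_end(0)
# Swapping the registration order swaps the execution order.
MiniCallbackList([my_cb, my_cb2]).on_train_batch_end(0)

print(call_log)  # ['my_cb2', 'my_cb', 'my_cb', 'my_cb2']
```

If one callback depends on something another callback updates in the same hook, the one that does the update has to be registered first.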

That covers the general operation of the callback container.
