
An introduction to TensorFlow.Keras callbacks

November 10, 2020 by Chris

Training a deep learning model is both simple and complex at the same time. It's simple because libraries like TensorFlow 2.0 (tensorflow.keras, specifically) make it very easy to get started. But while creating a first model is easy, fine-tuning it - while understanding what you are doing - is a bit more complex.

For example, you will need some knowledge of the supervised learning process, gradient descent or other optimization techniques, regularization, and a lot of other contributing factors.

Tweaking and tuning a deep learning model therefore benefits from two things: insight into what is happening and automated control to avoid the need for human intervention where possible. In Keras, this can be achieved with the tensorflow.keras.callbacks API. In this article, we will look into callbacks in more detail. We will first illustrate what they are by showing where they play a role in the supervised machine learning process. Then, we cover the Callbacks API - and for each callback, illustrate what it can be used for together with a small example. Finally, we will show how you can create your own callback with the tensorflow.keras.callbacks.Callback base class.

Let's take a look :)


Callbacks and their role in the training process

In our article about the supervised machine learning process, we saw how a supervised machine learning model is trained:

  1. A machine learning model (today, often a neural network) is initialized.
  2. Samples from the training set are fed forward, through the model, resulting in a set of predictions.
  3. The predictions are compared with what is known as the ground truth (i.e. the labels corresponding to the training samples), resulting in one value - a loss value - telling us how poorly the model performs.
  4. Based on the loss value and the subsequent backwards computation of the error, the weights are changed a little bit, to make the model a bit better. Then, we're either moving back to step 2, or we stop the training process.

As we can see, steps 2-4 are iterative, meaning that the model improves in a cyclical fashion.

In machine learning terms, each such iteration over the full training set is also called an epoch. Hence, training a machine learning model involves completing at least one, but often multiple epochs. Note from the article about gradient descent based optimization that we often don't feed forward all data at once. Instead, we use what is called a minibatch approach: the full training set is fed forward in smaller chunks called minibatches. By consequence, each epoch consists of at least one, but often multiple batches of data.
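
As a small illustration - assuming a compiled model and NumPy arrays X_train and y_train, which are hypothetical here - 1,000 training samples with a minibatch size of 100 yield 10 batches per epoch:

# With 1,000 samples and batch_size=100, each epoch consists of 10 batches;
# epochs=5 means the full training set is fed forward five times.
model.fit(X_train, y_train, epochs=5, batch_size=100)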

Now, it can be the case that you want to get insights from the training process while it is running. Or you want to provide automated steering in order to avoid wasting resources. In those cases, you might want to add a callback to your Keras model.

A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, before or after a single batch, etc).

Keras Team (n.d.)

As we shall see later in this article, there are - among others - callbacks for monitoring and for stopping the training process when it no longer makes the model better. This is possible because with callbacks, we can 'capture' the training process while it is happening. They essentially 'hook' into the training process by allowing it to invoke certain callback definitions. In Keras, each callback implements at least one, but possibly multiple of the following definitions (Keras Team, n.d.):

  1. on_train_begin and on_train_end, called at the start and end of training.
  2. on_epoch_begin and on_epoch_end, called at the start and end of each training epoch.
  3. on_train_batch_begin and on_train_batch_end, called before and after processing a single batch during training.
  4. on_test_begin, on_test_end, on_test_batch_begin and on_test_batch_end, the equivalents during evaluation.
  5. on_predict_begin, on_predict_end, on_predict_batch_begin and on_predict_batch_end, the equivalents during prediction.

As we can see, through the definitions outlined above, a callback allows us to control the training process at a variety of levels.

The Keras Callbacks API

Now that we understand what callbacks are, how they can help us, and what definitions - and hence hooks - are available for 'breaking into' your training process in TensorFlow 2.x based Keras, it's time to take a look at the Keras Callbacks API. Available as tensorflow.keras.callbacks, it is a set of generally valuable callbacks that can be used in a variety of cases.

More specifically, it contains the following callbacks, and we will cover each of them next:

  1. ModelCheckpoint callback: can be used to automatically save a model after each epoch, or just the best one.
  2. TensorBoard callback: allows us to monitor the training process in realtime with TensorBoard.
  3. EarlyStopping callback: ensures that the training process stops if the loss value no longer improves.
  4. LearningRateScheduler callback: updates the learning rate before the start of an epoch, based on a scheduler function.
  5. ReduceLROnPlateau callback: reduces the learning rate if the loss value no longer improves.
  6. RemoteMonitor callback: sends TensorFlow training events to a remote monitor, such as a logging system.
  7. LambdaCallback: allows us to define simple functions that can be executed as a callback.
  8. TerminateOnNaN callback: if the loss value is Not a Number (NaN), the training process stops.
  9. CSVLogger callback: streams the outcome of an epoch to a CSV file.
  10. ProgbarLogger callback: used to determine what is printed to standard output in the Keras progress bar.

How do we add a callback to a Keras model?

Before we take a look at the individual callbacks, we must see how the tensorflow.keras.callbacks API can be used in the first place. Doing so is really simple and only changes your code in a minor way:

  1. You must add the specific callbacks to the model imports.
  2. You must initialize the callbacks you want to use, including their configuration; preferably do so in a list.
  3. You must add the callbacks to the model.fit call.

With those three simple steps, you ensure that the callbacks are hooked into the training process!

For example, if we want to use both ModelCheckpoint and EarlyStopping - as we do here - for step (1), we first add the imports:

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

Then, for step (2), we initialize the callbacks in a list:

checkpoint_path = './best_model.h5'  # where the best model will be stored
keras_callbacks = [
      EarlyStopping(monitor='val_loss', patience=5, mode='min', min_delta=0.01),
      ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True, mode='min')
]

And then, for step (3), we simply add the callbacks to model.fit:

model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)

ModelCheckpoint callback

If you want to periodically save your Keras model - or the model weights - to some file, the ModelCheckpoint callback is what you need.

Callback to save the Keras model or model weights at some frequency.

TensorFlow (n.d.)

It is available as follows:

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', options=None, **kwargs
)

With the following arguments:

  1. filepath: the path where the model (or its weights) will be saved.
  2. monitor: the quantity to monitor, e.g. validation loss.
  3. verbose: verbosity mode, 0 or 1.
  4. save_best_only: if True, the model is only saved when the monitored quantity improves.
  5. save_weights_only: if True, only the model weights are saved instead of the full model.
  6. mode: one of auto, min or max, defining whether an improvement means a decrease or an increase of the monitored quantity.
  7. save_freq: 'epoch' to save after every epoch, or an integer number of batches.
  8. options: an optional tf.train.CheckpointOptions or tf.saved_model.SaveOptions object, depending on save_weights_only.

Using ModelCheckpoint is easy - and here is an example based on a generator:

import os

checkpoint_path = f'{os.path.dirname(os.path.realpath(__file__))}/covid-convnet.h5'
keras_callbacks = [
      ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True, mode='min')
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
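
Because save_best_only=True writes the full model to checkpoint_path whenever val_loss improves, you can later restore the best model - a minimal sketch:

from tensorflow.keras.models import load_model

# Restore the best model saved by ModelCheckpoint
best_model = load_model(checkpoint_path)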

TensorBoard callback

Did you know that you can visualize the training process realtime with TensorBoard?

With the TensorBoard callback, you can link TensorBoard with your Keras model.

Enable visualizations for TensorBoard.

TensorFlow (n.d.)

The callback logs a range of items from the training process into your TensorBoard log location, including metric summary plots, a visualization of the training graph, weight histograms and sampled profiling data.

It is implemented as follows:

tf.keras.callbacks.TensorBoard(
    log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
    update_freq='epoch', profile_batch=2, embeddings_freq=0,
    embeddings_metadata=None, **kwargs
)

Here is an example of using the TensorBoard callback within your Keras model:

keras_callbacks = [
      TensorBoard(log_dir="./logs")
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
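
With the logs being written to ./logs, you can then start TensorBoard from a terminal and point it at that directory (assuming TensorBoard is installed alongside TensorFlow):

tensorboard --logdir=./logs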

EarlyStopping callback

Optimizing your neural network involves applying gradient descent or another optimizer to a loss value generated by feeding forward batches of training samples, generating predictions that are compared with the corresponding training labels.

During this process, you want to find a model that performs well in terms of predictions (i.e., it is not underfit) but that is also not too rigid with respect to the dataset it is trained on (i.e., it is not overfit either). The EarlyStopping callback can help you here, by stopping the training process once a monitored metric no longer improves.

Stop training when a monitored metric has stopped improving.

TensorFlow (n.d.)

It is implemented as follows:

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
    baseline=None, restore_best_weights=False
)

Here is an example of using EarlyStopping with Keras:

keras_callbacks = [
      EarlyStopping(monitor='val_loss', min_delta=0.001, restore_best_weights=True)
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
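
Since training may stop well before the configured 50 epochs, you may want to know when the callback intervened. As a small sketch - the EarlyStopping instance exposes a stopped_epoch attribute, which is 0 if training was never stopped early:

# The epoch at which training was stopped early (0 if it ran to completion)
print(keras_callbacks[0].stopped_epoch)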

LearningRateScheduler callback

During the optimization process, a so-called weight update is computed. However, if we compare the optimization process with rolling a ball down a mountain (the mountain reflecting the loss landscape), we want to smooth the ride, ensuring that our ball does not bounce out of control. That is why a learning rate is applied: it specifies what fraction of the computed weight update is actually used by the optimizer.

Because the learning rate should preferably be relatively large during the early iterations and smaller in the later stages, we must adapt it during the training process. This is called learning rate decay, and it is exactly what a learning rate scheduler is useful for. The LearningRateScheduler callback implements this functionality.

At the beginning of every epoch, this callback gets the updated learning rate value from schedule function provided at __init__, with the current epoch and current learning rate, and applies the updated learning rate on the optimizer.

TensorFlow (n.d.)

Its implementation is really simple:

tf.keras.callbacks.LearningRateScheduler(
    schedule, verbose=0
)

Here is an example of using the LearningRateScheduler with Keras:

def scheduler(epoch, learning_rate):
  # Keep the initial learning rate for the first 15 epochs,
  # then decay it by 1% per epoch.
  if epoch < 15:
    return learning_rate
  else:
    return learning_rate * 0.99

keras_callbacks = [
      LearningRateScheduler(scheduler)
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
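
The schedule function can take any shape you like. As a variation, here is a small sketch of an exponential decay schedule, similar to the example in the TensorFlow documentation (the cut-off epoch of 10 and the decay constant are illustrative assumptions):

import math

def exp_decay_scheduler(epoch, learning_rate):
  # Keep the initial learning rate for the first ten epochs,
  # then decay it exponentially afterwards.
  if epoch < 10:
    return learning_rate
  return learning_rate * math.exp(-0.1)

keras_callbacks = [
      LearningRateScheduler(exp_decay_scheduler)
]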

ReduceLROnPlateau callback

During the optimization process - i.e., rolling the ball downhill - it can be the case that you encounter so-called loss plateaus. In those areas, the gradient of the loss function is close to, but not exactly, zero - indicating that you are in the vicinity of a loss minimum. That is, close to where you want to be (unless you are dealing with a local minimum, of course).

Keeping your learning rate equal when close to a plateau means that your model will likely not improve any further. This happens because the steps the optimizer takes - their size dictated by the current learning rate - are too big, so the model keeps oscillating around the loss minimum.

With the ReduceLROnPlateau callback, the optimization process can be instructed to reduce the learning rate (and hence the step) when a plateau is encountered.

Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This callback monitors a quantity and if no improvement is seen for a 'patience' number of epochs, the learning rate is reduced.

TensorFlow (n.d.)

The callback is implemented as follows:

tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto',
    min_delta=0.0001, cooldown=0, min_lr=0, **kwargs
)

An example of using the ReduceLROnPlateau callback with Keras:

keras_callbacks = [
      ReduceLROnPlateau(monitor='val_loss', factor=0.25, patience=5, cooldown=5, min_lr=1e-9)
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
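
If you want to verify that the learning rate was indeed reduced, note that ReduceLROnPlateau writes the effective learning rate into the epoch logs, so it ends up in the History object returned by model.fit - a small sketch, assuming the return value is stored:

history = model.fit(train_generator,
                    epochs=50,
                    verbose=1,
                    callbacks=keras_callbacks,
                    validation_data=val_generator)

# The per-epoch learning rate is recorded under the 'lr' key
print(history.history['lr'])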

RemoteMonitor callback

Above, we saw that training logs can be distributed to TensorBoard for visualization and logging purposes. However, it can be the case that you have your own logging and visualization system - whether that's a cloud-based system or a locally installed Grafana or Elastic Stack setup.

In those cases, you might wish to send the training logs there instead. The RemoteMonitor callback can help you do this.

Callback used to stream events to a server.

TensorFlow (n.d.)

It is implemented as follows:

tf.keras.callbacks.RemoteMonitor(
    root='http://localhost:9000', path='/publish/epoch/end/', field='data',
    headers=None, send_as_json=False
)

An example of using the RemoteMonitor callback with Keras:

keras_callbacks = [
      RemoteMonitor(root='https://some-domain.com', path='/statistics/keras')
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
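
At the receiving end, the callback performs an HTTP POST in which, with send_as_json=False (the default), the form field data contains the JSON-encoded epoch logs. As a minimal sketch of what a receiver could look like - using Flask here, which is an assumption on our side - consider:

from flask import Flask, request

app = Flask(__name__)

# Hypothetical endpoint matching the path configured above
@app.route('/statistics/keras', methods=['POST'])
def receive_keras_logs():
    # The 'data' form field holds the JSON-encoded epoch logs
    print(request.form['data'])
    return 'ok'

if __name__ == '__main__':
    app.run(port=9000)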

LambdaCallback

Say that you want a certain function to fire after every batch or every epoch - a simple function, nothing special. However, it's not provided in the collection of callbacks presented with the tensorflow.keras.callbacks API. In this case, you might want to use the LambdaCallback.

Callback for creating simple, custom callbacks on-the-fly. This callback is constructed with anonymous functions that will be called at the appropriate time.

TensorFlow (n.d.)

It can thus be used to provide anonymous functions (i.e., lambda functions without a name) to the training process. The callback looks as follows:

tf.keras.callbacks.LambdaCallback(
    on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None,
    on_train_begin=None, on_train_end=None, **kwargs
)

Here, the on_epoch_begin, on_epoch_end, on_batch_begin, on_batch_end, on_train_begin and on_train_end event based arguments can be filled with Python functions, such as lambdas. They are executed at the right point in time.

An example of a LambdaCallback added to your Keras model:

keras_callbacks = [
      LambdaCallback(on_batch_end=lambda batch, log_data: print(batch))
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
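
Similarly, an epoch-level lambda receives the epoch index and a logs dictionary, so you could - as a small sketch - print the loss after every epoch:

keras_callbacks = [
      LambdaCallback(on_epoch_end=lambda epoch, logs: print(f"Epoch {epoch}: loss {logs['loss']}"))
]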

TerminateOnNaN callback

In some cases (e.g. when you did not apply min-max normalization to your input data), the loss value can behave strangely - shooting towards infinity or becoming Not a Number (NaN). In those cases, you don't want to pursue further training. The TerminateOnNaN callback can help here.

Callback that terminates training when a NaN loss is encountered.

TensorFlow (n.d.)

It is implemented as follows:

tf.keras.callbacks.TerminateOnNaN()

An example of using the TerminateOnNaN callback with your Keras model:

keras_callbacks = [
      TerminateOnNaN()
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)

CSVLogger callback

CSV files can be very useful when you need to exchange data. If you want to flush your training logs into a CSV file, the CSVLogger callback can be useful to you.

Callback that streams epoch results to a CSV file.

TensorFlow (n.d.)

It is implemented as follows:

tf.keras.callbacks.CSVLogger(
    filename, separator=',', append=False
)

This is an example of using the CSVLogger callback with Keras:

keras_callbacks = [
      CSVLogger('./logs.csv', separator=';', append=True)
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)
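
Because the output is a plain CSV file, it is easy to inspect afterwards - for example with pandas, assuming it is installed (note the matching separator):

import pandas as pd

# Load the per-epoch training logs written by CSVLogger
training_logs = pd.read_csv('./logs.csv', sep=';')
print(training_logs.head())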

ProgbarLogger callback

When you are training a Keras model with verbose=1, you will see a progress bar in your terminal. With the ProgbarLogger callback, you can change what is displayed there.

Callback that prints metrics to stdout.

TensorFlow (n.d.)

It is implemented as follows:

tf.keras.callbacks.ProgbarLogger(
    count_mode='samples', stateful_metrics=None
)

Here is an example of using the ProgbarLogger callback with Keras.

keras_callbacks = [
      ProgbarLogger(count_mode='samples')
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)

Experimental: BackupAndRestore callback

When you are training a neural network, especially in a distributed setting, it would be problematic if your training process suddenly stopped - e.g. due to machine failure. Every epoch completed so far would be lost. With the experimental BackupAndRestore callback, you can instruct Keras to create temporary checkpoint files after each epoch, from which training can be restored later.

BackupAndRestore callback is intended to recover from interruptions that happened in the middle of a model.fit execution by backing up the training states in a temporary checkpoint file (based on TF CheckpointManager) at the end of each epoch.

TensorFlow (n.d.)

It is implemented as follows:

tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir
)

Here, the backup_dir argument indicates the folder where checkpoints should be created.

Here is an example of using the BackupAndRestore callback with Keras.

keras_callbacks = [
       BackupAndRestore('./checkpoints')
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)

Applied by default: History and BaseLogger callbacks

There are two callbacks that are part of the tensorflow.keras.callbacks API but that we can cover less extensively, for the simple reason that they are already applied to each Keras model under the hood.

They are the History and BaseLogger callbacks.
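
For example, because the History callback is applied automatically, model.fit already returns a History object whose history attribute maps each metric to its per-epoch values - no configuration required:

history = model.fit(train_generator,
                    epochs=50,
                    verbose=1,
                    validation_data=val_generator)

# History is attached under the hood; no callback needs to be added
print(history.history.keys())  # e.g. dict_keys(['loss', 'val_loss'])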

Creating your own callback with the Base Callback

Sometimes, neither the default nor the lambda callbacks can provide the functionality you need. In those cases, you can create your own callback by subclassing the base callback class tensorflow.keras.callbacks.Callback. Creating one is very simple: you define a class, implement the relevant definitions (you can choose from on_epoch_begin, on_epoch_end, on_batch_begin, on_batch_end, on_train_begin, on_train_end, and so on), and then add the callback to your callbacks list. There you go!

import tensorflow

class OwnCallback(tensorflow.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        print('Training is now beginning!')

keras_callbacks = [
       OwnCallback()
]
model.fit(train_generator,
          epochs=50,
          verbose=1,
          callbacks=keras_callbacks,
          validation_data=val_generator)

Summary

In this article, we looked at the concept of a callback for hooking into the supervised machine learning training process. Sometimes, you want to receive additional information while you are training a model. In other cases, you want to actively steer the process into a desired direction. Both cases are possible by means of a callback.

Beyond the conceptual introduction to callbacks, we also looked at how Keras implements them - by means of the tensorflow.keras.callbacks API. We briefly looked at each individual callback provided by Keras, ranging from automated changes to hyperparameters to logging to TensorBoard, a file, or a remote monitor. We also looked at creating your own callback, whether that's with a LambdaCallback for simple custom callbacks or with the base callback class for more complex ones.

I hope that you have learned something from today's article! If you did, please feel free to leave a comment in the comments section below 💬 Please do the same if you have any questions, remarks or suggestions for improvement. Thank you for reading MachineCurve today and happy engineering! 😎

References

Keras Team. (n.d.). Keras documentation: Callbacks API. Keras: the Python deep learning API. https://keras.io/api/callbacks/

Keras Team. (2020, April 15). Keras documentation: Writing your own callbacks. Keras: the Python deep learning API. https://keras.io/guides/writing_your_own_callbacks/#a-basic-example

TensorFlow. (n.d.). tf.keras.callbacks.ModelCheckpoint. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint

TensorFlow. (n.d.). tf.keras.callbacks.TensorBoard. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TensorBoard

TensorFlow. (n.d.). tf.keras.callbacks.EarlyStopping. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping

TensorFlow. (n.d.). tf.keras.callbacks.LearningRateScheduler. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/LearningRateScheduler

TensorFlow. (n.d.). tf.keras.callbacks.ReduceLROnPlateau. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ReduceLROnPlateau

TensorFlow. (n.d.). tf.keras.callbacks.RemoteMonitor. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/RemoteMonitor

TensorFlow. (n.d.). tf.keras.callbacks.LambdaCallback. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/LambdaCallback

TensorFlow. (n.d.). tf.keras.callbacks.TerminateOnNaN. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TerminateOnNaN

TensorFlow. (n.d.). tf.keras.callbacks.BaseLogger. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/BaseLogger

TensorFlow. (n.d.). tf.keras.callbacks.CSVLogger. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/CSVLogger

TensorFlow. (n.d.). tf.keras.callbacks.History. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/History

TensorFlow. (n.d.). tf.keras.callbacks.ProgbarLogger. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ProgbarLogger

TensorFlow. (n.d.). tf.keras.callbacks.experimental.BackupAndRestore. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore

TensorFlow. (n.d.). tf.keras.callbacks.Callback. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
