Writing your own callbacks
Authors: Rick Chao, Francois Chollet
Date created: 2019/03/20
Last modified: 2023/06/25
Description: Complete guide to writing new Keras callbacks.
Introduction
A callback is a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include keras.callbacks.TensorBoard to visualize training progress and results with TensorBoard, or keras.callbacks.ModelCheckpoint to periodically save your model during training.
In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. We provide a few demos of simple callback applications to get you started.
Setup
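A minimal setup, assuming Keras 3 (the standalone keras package, any backend) and NumPy are installed:

```python
# Assumes Keras 3 (standalone `keras` package) and NumPy are installed.
import numpy as np
import keras

print("Keras version:", keras.__version__)
```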
Keras callbacks overview
All callbacks subclass the keras.callbacks.Callback class and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful for inspecting the internal states and statistics of the model during training.
You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:
keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()
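A minimal sketch of both steps: subclass keras.callbacks.Callback, override only the hooks you need, and pass an instance through the callbacks keyword of all three methods. The class name CallCounter and the random toy data are hypothetical, chosen only for illustration.

```python
import numpy as np
import keras


class CallCounter(keras.callbacks.Callback):
    """Hypothetical callback that counts how often each global *_begin hook fires."""

    def __init__(self):
        super().__init__()
        self.counts = {"train": 0, "test": 0, "predict": 0}

    def on_train_begin(self, logs=None):
        self.counts["train"] += 1

    def on_test_begin(self, logs=None):
        self.counts["test"] += 1

    def on_predict_begin(self, logs=None):
        self.counts["predict"] += 1


# Toy regression data: 64 samples with 8 features each.
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")

counter = CallCounter()
# The same `callbacks` keyword is accepted by fit, evaluate, and predict.
model.fit(x, y, epochs=2, batch_size=16, callbacks=[counter], verbose=0)
model.evaluate(x, y, batch_size=16, callbacks=[counter], verbose=0)
model.predict(x, batch_size=16, callbacks=[counter], verbose=0)
print(counter.counts)
```

Because fit is called without validation data here, on_test_begin fires only once, from evaluate.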
An overview of callback methods
Global methods
on_(train|test|predict)_begin(self, logs=None)
Called at the beginning of fit/evaluate/predict.
on_(train|test|predict)_end(self, logs=None)
Called at the end of fit/evaluate/predict.
Batch-level methods for training/testing/predicting
on_(train|test|predict)_batch_begin(self, batch, logs=None)
Called right before processing a batch during training/testing/predicting.
on_(train|test|predict)_batch_end(self, batch, logs=None)
Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.
Epoch-level methods (training only)
on_epoch_begin(self, epoch, logs=None)
Called at the beginning of an epoch during training.
on_epoch_end(self, epoch, logs=None)
Called at the end of an epoch during training.
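To see how the global, epoch-level, and batch-level hooks nest during fit, here is a sketch of a hypothetical HookRecorder callback that only records the order of calls; the data and model are illustrative assumptions.

```python
import numpy as np
import keras


class HookRecorder(keras.callbacks.Callback):
    """Hypothetical callback that records the order in which fit() calls its hooks."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin {epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_begin {batch}")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end {batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end {epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = HookRecorder()
# 32 samples with batch_size=16 gives 2 batches per epoch.
model.fit(x, y, epochs=2, batch_size=16, callbacks=[recorder], verbose=0)
print(recorder.events)
```

The batch index passed to the batch hooks counts batches within the current epoch, while the epoch hooks wrap each full pass over the data.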
A basic example
Let's take a look at a concrete example. To get started, let's import Keras and define a simple Sequential model:
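A sketch of such a model, assuming a single Dense unit trained as a regressor with mean squared error and monitored with mean absolute error; the helper name get_model and the RMSprop learning rate of 0.1 are illustrative choices.

```python
import numpy as np
import keras


def get_model():
    """Build and compile a one-unit regression model to attach callbacks to."""
    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model


model = get_model()
# Weights are created lazily on the first call; 784 = 28 * 28 flattened pixels.
preds = model.predict(np.zeros((2, 784), dtype="float32"), verbose=0)
print(preds.shape)  # (2, 1)
```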
Then, load the MNIST data for training and testing from the Keras datasets API:
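keras.datasets.mnist.load_data() returns ((x_train, y_train), (x_test, y_test)), with each image a 28x28 uint8 array. A sketch of the preprocessing, assuming the images are flattened to 784 features and scaled to [0, 1]; the helper preprocess and the 1,000-sample cap (to keep demos fast) are hypothetical. The download is left commented out so the snippet also runs offline on random arrays of the same shape.

```python
import numpy as np


def preprocess(images, labels, limit=1000):
    """Flatten 28x28 images to 784 floats in [0, 1] and keep the first `limit` samples."""
    x = images.reshape(-1, 784).astype("float32") / 255.0
    return x[:limit], labels[:limit]


# With network access, use the real data instead:
# import keras
# (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(1200, 28, 28)).astype(np.uint8)
y_train = rng.integers(0, 10, size=(1200,)).astype(np.uint8)

x_train, y_train = preprocess(x_train, y_train)
print(x_train.shape, x_train.dtype, y_train.shape)  # (1000, 784) float32 (1000,)
```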
Now, define a simple custom callback that logs:
When fit/evaluate/predict starts & ends
When each epoch starts & ends
When each training batch starts & ends
When each evaluation (test) batch starts & ends
When each inference (prediction) batch starts & ends
Let's try it out:
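A sketch of such a callback together with a short run, assuming random stand-in data shaped like flattened MNIST. The messages list and the _log helper are hypothetical additions that keep a copy of every printed line; validation_split is used so the test hooks also fire inside fit.

```python
import numpy as np
import keras


class CustomCallback(keras.callbacks.Callback):
    """Prints (and stores) a message every time one of its hooks is called."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def _log(self, message, logs):
        # `logs` may be None or a dict; only its keys are reported here.
        text = f"{message}; got log keys: {sorted((logs or {}).keys())}"
        self.messages.append(text)
        print(text)

    # Global hooks for fit / evaluate / predict.
    def on_train_begin(self, logs=None):
        self._log("Starting training", logs)

    def on_train_end(self, logs=None):
        self._log("Stop training", logs)

    def on_test_begin(self, logs=None):
        self._log("Start testing", logs)

    def on_test_end(self, logs=None):
        self._log("Stop testing", logs)

    def on_predict_begin(self, logs=None):
        self._log("Start predicting", logs)

    def on_predict_end(self, logs=None):
        self._log("Stop predicting", logs)

    # Epoch-level hooks (training only).
    def on_epoch_begin(self, epoch, logs=None):
        self._log(f"Start epoch {epoch} of training", logs)

    def on_epoch_end(self, epoch, logs=None):
        self._log(f"End epoch {epoch} of training", logs)

    # Batch-level hooks.
    def on_train_batch_begin(self, batch, logs=None):
        self._log(f"...Training: start of batch {batch}", logs)

    def on_train_batch_end(self, batch, logs=None):
        self._log(f"...Training: end of batch {batch}", logs)

    def on_test_batch_begin(self, batch, logs=None):
        self._log(f"...Evaluating: start of batch {batch}", logs)

    def on_test_batch_end(self, batch, logs=None):
        self._log(f"...Evaluating: end of batch {batch}", logs)

    def on_predict_batch_begin(self, batch, logs=None):
        self._log(f"...Predicting: start of batch {batch}", logs)

    def on_predict_batch_end(self, batch, logs=None):
        self._log(f"...Predicting: end of batch {batch}", logs)


# Random stand-in data shaped like flattened MNIST (64 samples, 784 features).
x = np.random.rand(64, 784).astype("float32")
y = np.random.randint(0, 10, size=(64,)).astype("float32")

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(
    optimizer="rmsprop", loss="mean_squared_error", metrics=["mean_absolute_error"]
)

callback = CustomCallback()
model.fit(x, y, batch_size=32, epochs=1, verbose=0,
          validation_split=0.5, callbacks=[callback])
model.evaluate(x, y, batch_size=32, verbose=0, callbacks=[callback])
model.predict(x, batch_size=32, verbose=0, callbacks=[callback])
```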
Usage of logs dict
The logs dict contains the loss value and all the metrics at the end of a batch or epoch. For the model in this guide, these are the loss and the mean absolute error.
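A sketch of a callback that reads those two entries, assuming the model was compiled with metrics=["mean_absolute_error"] so that logs uses that key; the class name and the stored epoch_logs list are illustrative. The batch-level loss is the running average over the batches seen so far in the current epoch.

```python
import numpy as np
import keras


class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    """Reads the loss and mean absolute error from `logs` after batches and epochs."""

    def __init__(self):
        super().__init__()
        self.epoch_logs = []  # one copy of `logs` per training epoch

    def on_train_batch_end(self, batch, logs=None):
        print(f"Up to batch {batch}, the average loss is {float(logs['loss']):7.2f}.")

    def on_test_batch_end(self, batch, logs=None):
        print(f"Up to batch {batch}, the average loss is {float(logs['loss']):7.2f}.")

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_logs.append(dict(logs))
        print(
            f"Epoch {epoch}: average loss {float(logs['loss']):7.2f}, "
            f"mean absolute error {float(logs['mean_absolute_error']):7.2f}."
        )


x = np.random.rand(128, 8).astype("float32")
y = np.random.rand(128, 1).astype("float32")
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss="mean_squared_error",
    metrics=["mean_absolute_error"],
)

cb = LossAndErrorPrintingCallback()
model.fit(x, y, batch_size=64, epochs=2, verbose=0, callbacks=[cb])
model.evaluate(x, y, batch_size=64, verbose=0, callbacks=[cb])
```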
Usage of self.model attribute
In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self.model.
Here are a few of the things you can do with self.model in a callback:
Set self.model.stop_training = True to immediately interrupt training.
Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
Save the model at periodic intervals.
Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training.
Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
etc.
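A sketch covering two of the items above through self.model: recording model.predict() on a few fixed samples at the end of each epoch, and reading the optimizer's current learning rate. The class name PredictionSnapshot, its attributes, and the toy data are hypothetical.

```python
import numpy as np
import keras


class PredictionSnapshot(keras.callbacks.Callback):
    """Hypothetical callback: after each epoch, predict on a few fixed samples
    and read the optimizer's learning rate, both via self.model."""

    def __init__(self, samples):
        super().__init__()
        self.samples = samples
        self.snapshots = []
        self.learning_rates = []

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.samples, verbose=0)
        self.snapshots.append(preds)
        # The learning rate is stored as a variable; convert it to a Python float.
        lr = float(np.array(self.model.optimizer.learning_rate))
        self.learning_rates.append(lr)
        print(f"Epoch {epoch}: lr={lr:.3g}, predictions={preds.ravel().round(2)}")


x = np.random.rand(64, 8).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=0.01), loss="mse")

snapshot = PredictionSnapshot(samples=x[:3])
model.fit(x, y, batch_size=16, epochs=3, verbose=0, callbacks=[snapshot])
```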
Let's see this in action in a couple of examples.
Examples of Keras callback applications
Early stopping at minimum loss
This first example shows the creation of a Callback that stops training when the loss reaches its minimum, by setting the attribute self.model.stop_training (boolean). Optionally, you can provide an argument patience to specify how many epochs to wait before stopping after reaching a local minimum.
keras.callbacks.EarlyStopping provides a more complete and general implementation.
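A sketch of such a callback, assuming it monitors the training loss in logs["loss"] and restores the weights from the best epoch when it stops; the class name EarlyStoppingAtMinLoss and the toy data are illustrative.

```python
import numpy as np
import keras


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the training loss stops improving.

    patience: number of epochs to wait after the last improvement before
    stopping; the weights from the best epoch are then restored.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        self.best_weights = None  # weights of the epoch with the lowest loss

    def on_train_begin(self, logs=None):
        self.wait = 0  # epochs since the last improvement
        self.stopped_epoch = 0  # epoch at which training stopped (0 = never)
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1}: early stopping")


x = np.random.rand(256, 8).astype("float32")
y = x.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
              loss="mean_squared_error")

early_stop = EarlyStoppingAtMinLoss(patience=2)
history = model.fit(x, y, batch_size=64, epochs=30, verbose=0, callbacks=[early_stop])
print(f"Ran {len(history.history['loss'])} epochs; best loss {early_stop.best:.4f}")
```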
Learning rate scheduling
In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.
See keras.callbacks.LearningRateScheduler for a more general implementation.
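A sketch of a custom scheduler, assuming the schedule is a function of (epoch index, current learning rate) that returns the new rate as a float and that the optimizer exposes a settable learning_rate attribute. The LR_SCHEDULE table and the lr_log attribute are hypothetical.

```python
import numpy as np
import keras


class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Sets the optimizer's learning rate at the start of each epoch.

    schedule: function (epoch_index, current_lr) -> new_lr, all floats.
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule
        self.lr_log = []  # learning rate used for each epoch

    def on_epoch_begin(self, epoch, logs=None):
        optimizer = self.model.optimizer
        if not hasattr(optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Read the current rate (stored as a variable) as a Python float.
        current_lr = float(np.array(optimizer.learning_rate))
        scheduled_lr = self.schedule(epoch, current_lr)
        optimizer.learning_rate = scheduled_lr
        self.lr_log.append(scheduled_lr)
        print(f"Epoch {epoch}: learning rate is {scheduled_lr:.4f}.")


# Hypothetical schedule: (epoch at which to switch, new learning rate).
LR_SCHEDULE = [(2, 0.05), (4, 0.01)]


def lr_schedule(epoch, lr):
    """Return the scheduled rate when `epoch` matches an entry, else keep `lr`."""
    for start_epoch, rate in LR_SCHEDULE:
        if epoch == start_epoch:
            return rate
    return lr


x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
              loss="mean_squared_error")

scheduler = CustomLearningRateScheduler(lr_schedule)
model.fit(x, y, batch_size=32, epochs=6, verbose=0, callbacks=[scheduler])
```

The schedule receives the rate currently held by the optimizer, so epochs without an entry simply keep the previous value.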
Built-in Keras callbacks
Be sure to check out the existing Keras callbacks by reading the API docs. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!