"""1Title: Writing your own callbacks2Authors: Rick Chao, Francois Chollet3Date created: 2019/03/204Last modified: 2023/06/255Description: Complete guide to writing new Keras callbacks.6Accelerator: GPU7"""89"""10## Introduction1112A callback is a powerful tool to customize the behavior of a Keras model during13training, evaluation, or inference. Examples include `keras.callbacks.TensorBoard`14to visualize training progress and results with TensorBoard, or15`keras.callbacks.ModelCheckpoint` to periodically save your model during training.1617In this guide, you will learn what a Keras callback is, what it can do, and how you can18build your own. We provide a few demos of simple callback applications to get you19started.20"""2122"""23## Setup24"""2526import numpy as np27import keras2829"""30## Keras callbacks overview3132All callbacks subclass the `keras.callbacks.Callback` class, and33override a set of methods called at various stages of training, testing, and34predicting. Callbacks are useful to get a view on internal states and statistics of35the model during training.3637You can pass a list of callbacks (as the keyword argument `callbacks`) to the following38model methods:3940- `keras.Model.fit()`41- `keras.Model.evaluate()`42- `keras.Model.predict()`43"""4445"""46## An overview of callback methods4748### Global methods4950#### `on_(train|test|predict)_begin(self, logs=None)`5152Called at the beginning of `fit`/`evaluate`/`predict`.5354#### `on_(train|test|predict)_end(self, logs=None)`5556Called at the end of `fit`/`evaluate`/`predict`.5758### Batch-level methods for training/testing/predicting5960#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`6162Called right before processing a batch during training/testing/predicting.6364#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`6566Called at the end of training/testing/predicting a batch. Within this method, `logs` is67a dict containing the metrics results.6869### Epoch-level methods (training only)7071#### `on_epoch_begin(self, epoch, logs=None)`7273Called at the beginning of an epoch during training.7475#### `on_epoch_end(self, epoch, logs=None)`7677Called at the end of an epoch during training.78"""7980"""81## A basic example8283Let's take a look at a concrete example. 
"""
## A basic example

Let's take a look at a concrete example. To get started, let's
define a simple Sequential Keras model:
"""


# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model


"""
Then, load the MNIST data for training and testing from the Keras datasets API:
"""

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]

"""
Now, define a simple custom callback that logs:

- When `fit`/`evaluate`/`predict` starts & ends
- When each epoch starts & ends
- When each training batch starts & ends
- When each evaluation (test) batch starts & ends
- When each inference (prediction) batch starts & ends
"""


class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))


"""
Let's try it out:
"""

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
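"""
The demo above shows when each hook fires. As a small practical use of the
batch-level hooks, here is a minimal sketch of a callback that times each training
batch; the `BatchTimer` name and the `time` import are arbitrary additions for this
illustration.
"""

import time


class BatchTimer(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Collect per-batch durations over the whole training run.
        self.batch_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self._batch_start = time.time()

    def on_train_batch_end(self, batch, logs=None):
        self.batch_times.append(time.time() - self._batch_start)

    def on_train_end(self, logs=None):
        print(f"Mean batch time: {np.mean(self.batch_times):.4f}s")


model = get_model()
model.fit(
    x_train, y_train, batch_size=128, epochs=1, verbose=0, callbacks=[BatchTimer()]
)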
"""
### Usage of `logs` dict

The `logs` dict contains the loss value, and all the metrics at the end of a batch or
epoch. Examples include the loss and mean absolute error.
"""


class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print(
            "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
        )

    def on_test_batch_end(self, batch, logs=None):
        print(
            "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
        )

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

"""
## Usage of `self.model` attribute

In addition to receiving log information when one of their methods is called,
callbacks have access to the model associated with the current round of
training/evaluation/inference: `self.model`.

Here are a few of the things you can do with `self.model` in a callback:

- Set `self.model.stop_training = True` to immediately interrupt training.
- Mutate hyperparameters of the optimizer (available as `self.model.optimizer`),
such as `self.model.optimizer.learning_rate`.
- Save the model at periodic intervals (see the sketch right after this list).
- Record the output of `model.predict()` on a few test samples at the end of each
epoch, to use as a sanity check during training.
- Extract visualizations of intermediate features at the end of each epoch, to monitor
what the model is learning over time.
- etc.
"""
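"""
As a minimal sketch of the periodic-saving idea from the list above: the
`PeriodicSaver` name, the saving period, and the filename pattern are all arbitrary
choices for this example.
"""


class PeriodicSaver(keras.callbacks.Callback):
    def __init__(self, period=2):
        super().__init__()
        self.period = period

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.period == 0:
            # Whole-model saving expects the `.keras` file extension.
            self.model.save(f"model_at_epoch_{epoch + 1}.keras")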
"""
Let's see this in action in a couple of more complete examples.
"""

"""
## Examples of Keras callback applications

### Early stopping at minimum loss

This first example shows the creation of a `Callback` that stops training when the
loss has reached its minimum, by setting the attribute `self.model.stop_training`
(boolean). Optionally, you can provide an argument `patience` to specify how many
epochs we should wait before stopping after having reached a local minimum.

`keras.callbacks.EarlyStopping` provides a more complete and general implementation.
"""


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
            number of epochs with no improvement, training stops.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs waited while the loss is no longer at a minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1}: early stopping")


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)

"""
### Learning rate scheduling

In this example, we show how a custom Callback can be used to dynamically change the
learning rate of the optimizer during the course of training.

See `keras.callbacks.LearningRateScheduler` for a more general implementation.
"""


class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

    Arguments:
        schedule: a function that takes an epoch index
            (integer, indexed from 0) and current learning rate
            as inputs and returns a new learning rate as output (float).
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from the model's optimizer.
        lr = self.model.optimizer.learning_rate
        # Call the schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back on the optimizer before this epoch starts.
        self.model.optimizer.learning_rate = scheduled_lr
        print(f"\nEpoch {epoch}: Learning rate is {float(np.array(scheduled_lr))}.")


LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)

"""
### Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by
reading the [API docs](https://keras.io/api/callbacks/).
Applications include logging to CSV, saving
the model, visualizing metrics in TensorBoard, and a lot more!
"""
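"""
As a closing sketch, here is one plausible way to combine several built-in callbacks
in a single `fit()` call. The monitored quantity, patience, and file names below are
placeholder choices for this example.
"""

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=5,
    verbose=0,
    callbacks=[
        # Stop training when the loss has not improved for 2 consecutive epochs.
        keras.callbacks.EarlyStopping(monitor="loss", patience=2),
        # Save the best model (lowest loss) seen so far.
        keras.callbacks.ModelCheckpoint(
            filepath="best_model.keras", monitor="loss", save_best_only=True
        ),
        # Append per-epoch metrics to a CSV file.
        keras.callbacks.CSVLogger("training_log.csv"),
    ],
)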