Copyright 2021 The TensorFlow Authors.
Migrate early stopping
This notebook demonstrates how to set up model training with early stopping: first in TensorFlow 1, with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2, with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training once a chosen condition is met, for example when the validation loss reaches a certain threshold or stops improving.
In TensorFlow 2, there are three ways to implement early stopping:
Use a built-in Keras callback, tf.keras.callbacks.EarlyStopping, and pass it to Model.fit.
Define a custom callback and pass it to Keras Model.fit.
Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
Setup
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining functions for loading and preprocessing the MNIST dataset, and a model definition to use with tf.estimator.Estimator:
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function to make_early_stopping_hook as its should_stop_fn parameter; this function takes no arguments. Training stops once should_stop_fn returns True.
The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:
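A minimal sketch of such a time-limited stopping rule, written without TensorFlow so that it runs on its own. The names make_time_limit_fn and MAX_TRAIN_SECONDS are hypothetical, and the commented lines only indicate how the function might be handed to make_early_stopping_hook, assuming an estimator and input_fn already exist.

```python
import time

MAX_TRAIN_SECONDS = 20  # time budget from the text above


def make_time_limit_fn(max_seconds, clock=time.time):
    """Return a zero-argument function that turns True after max_seconds."""
    start = clock()

    def should_stop_fn():
        # No arguments: the hook calls this periodically during training.
        return clock() - start > max_seconds

    return should_stop_fn


should_stop_fn = make_time_limit_fn(MAX_TRAIN_SECONDS)

# Assumed usage with an existing `estimator` and `input_fn` (not defined here):
# hook = tf.estimator.experimental.make_early_stopping_hook(
#     estimator, should_stop_fn, run_every_secs=1, run_every_steps=None)
# estimator.train(input_fn, hooks=[hook])
```

Note that the timer in this sketch starts when make_time_limit_fn is called, so the function should be created right before training begins.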
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing a built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.
The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (See the Training and evaluation with the built-in methods guide or the API docs for more information.)
Below is an example of an early stopping callback that monitors the loss and stops training once it has shown no improvement for 3 consecutive epochs (patience=3):
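A minimal sketch of this setup, assuming a tiny synthetic regression dataset and a one-layer model in place of the MNIST pipeline (both are stand-ins chosen for brevity, not the notebook's own data or model):

```python
import numpy as np
import tensorflow as tf

# Tiny synthetic regression data stands in for MNIST (an assumption for brevity).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

# Stop once the training loss has not improved for 3 epochs in a row.
early_stop = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=3)

history = model.fit(x, y, epochs=20, batch_size=16,
                    callbacks=[early_stop], verbose=0)
epochs_run = len(history.history["loss"])
print("epochs run:", epochs_run)
```

The epochs argument acts as an upper bound here: training ends earlier only if the monitored loss stalls for 3 epochs.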
TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit (or Model.evaluate).
In this example, training stops once self.model.stop_training is set to True:
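One way such a callback could look, as a sketch rather than the notebook's own code. The class name StopAtLossThreshold and its threshold rule are hypothetical, and the data is synthetic; the part taken from the text is that setting self.model.stop_training to True ends training.

```python
import numpy as np
import tensorflow as tf


class StopAtLossThreshold(tf.keras.callbacks.Callback):
    """Hypothetical callback: stop when the epoch loss drops below a threshold."""

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < self.threshold:
            # Keras checks this flag after each epoch and exits the fit loop.
            self.model.stop_training = True


x = np.ones((32, 2), dtype="float32")
y = np.zeros((32, 1), dtype="float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# A deliberately huge threshold so the rule fires after the first epoch.
history = model.fit(x, y, epochs=5, verbose=0,
                    callbacks=[StopAtLossThreshold(threshold=1e6)])
epochs_run = len(history.history["loss"])
```

With the oversized threshold used here, only one of the five requested epochs should run; a realistic threshold would be chosen from the expected loss range.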
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.
Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
Define the parameter update functions with tf.GradientTape, and decorate them with @tf.function for a speedup:
Next, write a custom training loop, where you can implement your early stopping rule manually.
The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:
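The stopping rule itself can be sketched in plain Python. In this illustration, a list of simulated values stands in for the per-epoch validation loss, and the names run_with_patience, wait, and best_loss are hypothetical; a real loop would compute the loss from the validation step instead.

```python
def run_with_patience(val_losses, patience):
    """Return (epochs_run, best_loss) for a patience-based stopping rule.

    `val_losses` stands in for the per-epoch validation loss that a real
    loop would compute after each pass over the validation set.
    """
    best_loss = float("inf")
    wait = 0  # epochs since the last improvement
    epochs_run = 0
    for loss in val_losses:
        epochs_run += 1
        # ... a real loop would run the training and validation steps here ...
        if loss < best_loss:
            best_loss = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                break  # no improvement for `patience` epochs in a row
    return epochs_run, best_loss


simulated = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.40]
print(run_with_patience(simulated, patience=3))  # (6, 0.65)
```

In the simulated run, the loss improves through epoch 3, then fails to improve for three epochs, so the loop stops after epoch 6 and never sees the later value of 0.40. That is the trade-off a patience setting controls.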
Next steps
Learn more about the Keras built-in early stopping callback API in the API docs.
Learn to write custom Keras callbacks, including early stopping at a minimum loss.
Learn about Training and evaluation with the Keras built-in methods.
Explore common regularization techniques in the Overfit and underfit tutorial, which uses the EarlyStopping callback.