#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This notebook demonstrates how you can set up model training with early stopping: first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that halts training when a monitored quantity, such as the validation loss, stops improving or reaches a certain threshold.

In TensorFlow 2, there are three ways to implement early stopping:

  • Use a built-in Keras callback—tf.keras.callbacks.EarlyStopping—and pass it to Model.fit.

  • Define a custom callback and pass it to Keras Model.fit.

  • Write a custom early stopping rule in a custom training loop (with tf.GradientTape).

Setup

import time

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions for loading and preprocessing the MNIST dataset, and a model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
      name='mnist',
      split='train',
      shuffle_files=True,
      as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
      name='mnist',
      split='test',
      shuffle_files=True,
      as_supervised=True)

  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function to make_early_stopping_hook through its should_stop_fn parameter; the function takes no arguments and is called periodically during training. Training stops once should_stop_fn returns True.

The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
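tf.estimator.experimental also ships convenience wrappers built on top of make_early_stopping_hook, such as stop_if_no_decrease_hook, which tracks a named metric instead of wall-clock time. The sketch below reuses the estimator defined above; the 1000-step threshold is an illustrative choice, not a value from this notebook:

# Illustrative: stop training when the loss has not decreased
# for 1000 consecutive global steps.
no_decrease_hook = tf1.estimator.experimental.stop_if_no_decrease_hook(
    estimator=estimator,
    metric_name='loss',
    max_steps_without_decrease=1000)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[no_decrease_hook])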

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in tf.keras.callbacks.EarlyStopping callback to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (For more information, check the Training and evaluation with the built-in methods guide or the API docs.)

Below is an example of an early stopping callback that monitors the loss and stops training after three consecutive epochs (patience=3) without improvement:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
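In practice, monitoring the validation loss usually gives a better stopping signal than the training loss. Below is a minimal variation of the callback above (the min_delta value is an illustrative choice) that also rolls the model back to its best weights when training stops:

callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',          # Watch the validation loss instead.
    min_delta=1e-3,              # Ignore improvements smaller than this.
    patience=3,
    restore_best_weights=True)   # Restore the weights from the best epoch.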

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit (or Model.evaluate).

In this example, the training process is stopped once self.model.stop_training is set to True:

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
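A custom callback can also read the per-epoch metrics from the logs dictionary and stop on a metric condition rather than on elapsed time. Below is a sketch that assumes the compiled model from above; the StopAtValAccuracy name and the 0.98 target are hypothetical. The log key is the metric name from Model.compile, prefixed with val_ for validation metrics:

class StopAtValAccuracy(tf.keras.callbacks.Callback):
  def __init__(self, target_acc):
    super().__init__()
    self.target_acc = target_acc

  def on_epoch_end(self, epoch, logs=None):
    # `logs` holds this epoch's metrics, keyed by the metric names from
    # `Model.compile` (validation metrics carry a `val_` prefix).
    val_acc = (logs or {}).get('val_sparse_categorical_accuracy')
    if val_acc is not None and val_acc >= self.target_acc:
      self.model.stop_training = True

# Stop as soon as validation accuracy reaches 98% (illustrative target).
callback = StopAtValAccuracy(0.98)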

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.

Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# The model outputs raw logits, so the loss metrics need
# `from_logits=True` to report correct values.
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(
    from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(
    from_logits=True)

Define the parameter update functions with tf.GradientTape and the @tf.function decorator for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
  print("\nStart of epoch %d" % (epoch,))
  start_time = time.time()

  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
    loss_value = train_step(x_batch_train, y_batch_train)
    if step % 200 == 0:
      print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
      print("Seen so far: %s samples" % ((step + 1) * 128))

  train_acc = train_acc_metric.result()
  train_loss = train_loss_metric.result()
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()
  print("Training acc over epoch: %.4f" % (train_acc.numpy()))

  for x_batch_val, y_batch_val in ds_test:
    test_step(x_batch_val, y_batch_val)
  val_acc = val_acc_metric.result()
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()
  print("Validation acc: %.4f" % (float(val_acc),))
  print("Time taken: %.2fs" % (time.time() - start_time))

  # The early stopping strategy: stop the training if `val_loss` does not
  # decrease over a certain number of epochs.
  wait += 1
  if val_loss < best:
    best = val_loss
    wait = 0
  if wait >= patience:
    break
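If you reuse this rule across experiments, the patience bookkeeping can be factored into a small helper class. Below is a minimal sketch (the EarlyStopper name is hypothetical, not part of any TensorFlow API) that encapsulates the wait/best logic from the loop above:

class EarlyStopper:
  """Signals when a monitored value has stopped improving (illustrative)."""

  def __init__(self, patience):
    self.patience = patience
    self.wait = 0
    self.best = float('inf')

  def should_stop(self, current):
    if current < self.best:
      # Improvement: record the new best value and reset the counter.
      self.best = current
      self.wait = 0
      return False
    self.wait += 1
    return self.wait >= self.patience

With this helper, you create stopper = EarlyStopper(patience=5) once before the loop and replace the final block of each epoch with if stopper.should_stop(float(val_loss)): break.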

Next steps