Copyright 2021 The TensorFlow Authors.
Migrate LoggingTensorHook and StopAtStepHook to Keras callbacks
In TensorFlow 1, you use tf.estimator.LoggingTensorHook to monitor and log tensors, while tf.estimator.StopAtStepHook helps stop training at a specified step when training with tf.estimator.Estimator. This notebook demonstrates how to migrate from these APIs to their equivalents in TensorFlow 2 using custom Keras callbacks (tf.keras.callbacks.Callback) with Model.fit.
Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides. For migrating from SessionRunHook in TensorFlow 1 to Keras callbacks in TensorFlow 2, check out the Migrate training with assisted logic guide.
Setup
Start with imports and a simple dataset for demonstration purposes:
TensorFlow 1: Log tensors and stop training with tf.estimator APIs
In TensorFlow 1, you define various hooks to control the training behavior. Then, you pass these hooks to tf.estimator.EstimatorSpec.
In the example below:
- To monitor or log tensors (for example, model weights or losses), you use tf.estimator.LoggingTensorHook (tf.train.LoggingTensorHook is its alias).
- To stop training at a specific step, you use tf.estimator.StopAtStepHook (tf.train.StopAtStepHook is its alias).
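The TensorFlow 1 pattern described above can be sketched as follows. This assumes a TensorFlow release that still ships tf.compat.v1.estimator (recent releases no longer include it), and it uses illustrative data and a hand-built linear model rather than the original notebook's code. StopAtStepHook ends training after two steps, one LoggingTensorHook logs the weights at every step, and a second one logs the loss at most every three seconds; all three hooks are passed through the training_hooks argument of EstimatorSpec.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumes tf.compat.v1.estimator is available in this TF release

# Illustrative data (assumed values).
train_features = [[1., 1.5], [2., 2.5], [3., 3.5]]
train_labels = [[0.3], [0.5], [0.7]]

def _input_fn():
    return tf.data.Dataset.from_tensor_slices((train_features, train_labels)).batch(1)

def _model_fn(features, labels, mode):
    # A single linear layer built from raw TF1 variables.
    kernel = tf1.get_variable("kernel", shape=[2, 1])
    bias = tf1.get_variable("bias", shape=[1], initializer=tf1.zeros_initializer())
    predictions = tf1.matmul(features, kernel) + bias
    loss = tf1.losses.mean_squared_error(labels=labels, predictions=predictions)
    train_op = tf1.train.AdagradOptimizer(0.05).minimize(
        loss, global_step=tf1.train.get_global_step())

    # Stop training once the global step has advanced by two.
    stop_hook = tf1.train.StopAtStepHook(num_steps=2)
    # Log the current weight values at every step.
    weights_hook = tf1.train.LoggingTensorHook(
        tensors={"kernel": tf1.identity(kernel), "bias": tf1.identity(bias)},
        every_n_iter=1)
    # Log the loss tensor at most once every three seconds.
    loss_hook = tf1.train.LoggingTensorHook({"loss": loss}, every_n_secs=3)

    return tf1.estimator.EstimatorSpec(
        mode, loss=loss, train_op=train_op,
        training_hooks=[stop_hook, weights_hook, loss_hook])

tf1.logging.set_verbosity(tf1.logging.INFO)  # LoggingTensorHook writes INFO-level logs
estimator = tf1.estimator.Estimator(model_fn=_model_fn)  # uses a temporary model_dir
estimator.train(_input_fn)  # stops after 2 of the 3 available batches
```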
TensorFlow 2: Log tensors and stop training with custom callbacks and Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tensor monitoring and training stopping by defining custom Keras callbacks, that is, subclasses of tf.keras.callbacks.Callback. Then, you pass them to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)
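As a small illustration of the callbacks parameter, the sketch below passes a custom callback to Model.evaluate. CountEvalBatches is a hypothetical name, and the data and model are illustrative assumptions. The callback overrides on_test_batch_end, which Keras calls after each evaluation batch.

```python
import tensorflow as tf

class CountEvalBatches(tf.keras.callbacks.Callback):
    """Hypothetical callback that counts the batches processed by Model.evaluate."""

    def __init__(self):
        super().__init__()
        self.num_batches = 0

    def on_test_batch_end(self, batch, logs=None):
        self.num_batches += 1

# Illustrative data (assumed values), one example per batch.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adagrad", loss="mse")

counter = CountEvalBatches()
model.evaluate(dataset, callbacks=[counter], verbose=0)
print("evaluation batches:", counter.num_batches)
```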
In the example below:
- To recreate the functionality of StopAtStepHook, define a custom callback (named StopAtStepCallback below) that overrides the on_batch_end method to stop training after a certain number of steps.
- To recreate the LoggingTensorHook behavior, define a custom callback (LoggingTensorCallback) that records and outputs the logged tensors manually, since accessing tensors by name is not supported. You can also implement the logging frequency inside the custom callback. The example below prints the weights every two steps. Other strategies, such as logging every N seconds, are also possible.
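Logging every N seconds, mentioned above as an alternative strategy, can be sketched with a wall-clock check. TimedLoggingCallback is a hypothetical name, the use of time.monotonic is an assumption about how to measure the interval, and the data and model are illustrative. The callback records the model weights after a training batch only when at least every_n_secs seconds have passed since the previous record.

```python
import time
import tensorflow as tf

class TimedLoggingCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: records weights at most once every `every_n_secs` seconds."""

    def __init__(self, every_n_secs):
        super().__init__()
        self._every_n_secs = every_n_secs
        self._last_log_time = None
        self.logged = []  # (batch index, list of weight arrays) pairs

    def on_batch_end(self, batch, logs=None):
        now = time.monotonic()
        if self._last_log_time is None or now - self._last_log_time >= self._every_n_secs:
            self._last_log_time = now
            weights = self.model.get_weights()  # NumPy copies of the model's weights
            self.logged.append((batch, weights))
            print(f"batch {batch}: {weights}")

# Illustrative data (assumed values), one example per batch.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.05), loss="mse")

timed_cb = TimedLoggingCallback(every_n_secs=0.0)  # 0 seconds: record after every batch
model.fit(dataset, epochs=1, callbacks=[timed_cb], verbose=0)
```

Setting every_n_secs to 0 records after every batch, while a large interval records only the first batch of the run.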
When finished, pass the new callbacks—StopAtStepCallback and LoggingTensorCallback—to the callbacks parameter of Model.fit:
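A minimal sketch of the two callbacks and the Model.fit call, assuming illustrative data, a single Dense layer, and the Adagrad optimizer. This StopAtStepCallback counts training batches itself instead of reading the optimizer's iteration counter, so its count starts at zero for each callback instance. LoggingTensorCallback prints every weight variable and the batch loss every every_n_iter steps. With three batches per epoch, training stops after five steps, partway through the second epoch, even though epochs=10.

```python
import tensorflow as tf

class StopAtStepCallback(tf.keras.callbacks.Callback):
    """Stops Model.fit after `stop_step` training batches (counterpart of StopAtStepHook)."""

    def __init__(self, stop_step):
        super().__init__()
        self._stop_step = stop_step
        self.steps_seen = 0

    def on_batch_end(self, batch, logs=None):
        self.steps_seen += 1
        if self.steps_seen >= self._stop_step:
            self.model.stop_training = True  # fit exits after the current batch
            print("\nstop training now")

class LoggingTensorCallback(tf.keras.callbacks.Callback):
    """Prints weights and loss every `every_n_iter` batches (counterpart of LoggingTensorHook)."""

    def __init__(self, every_n_iter):
        super().__init__()
        self._every_n_iter = every_n_iter
        self._step = 0
        self.logged_steps = []

    def on_batch_end(self, batch, logs=None):
        self._step += 1
        if self._step % self._every_n_iter == 0:
            self.logged_steps.append(self._step)
            # Tensors cannot be fetched by name, so read the model's variables directly.
            for variable, value in zip(self.model.weights, self.model.get_weights()):
                print(f"step {self._step}: {variable.name} = {value.ravel()}")
            print(f"step {self._step}: loss = {(logs or {}).get('loss')}")

# Illustrative data (assumed values), one example per batch.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.05), loss="mse")

stop_cb = StopAtStepCallback(stop_step=5)
log_cb = LoggingTensorCallback(every_n_iter=2)
model.fit(dataset, epochs=10, callbacks=[stop_cb, log_cb], verbose=0)
```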
Next steps
Learn more about callbacks in:
- API docs: tf.keras.callbacks.Callback
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The Training with assisted logic migration guide: From SessionRunHook in TensorFlow 1 to Keras callbacks in TensorFlow 2