Copyright 2021 The TensorFlow Authors.
Migrate SessionRunHook to Keras callbacks
In TensorFlow 1, you customize training behavior by using tf.estimator.SessionRunHook with tf.estimator.Estimator. This guide demonstrates how to migrate from SessionRunHook to TensorFlow 2's custom callbacks with the tf.keras.callbacks.Callback API, which works with Keras Model.fit for training (as well as Model.evaluate and Model.predict). You will learn how to do this by implementing both a SessionRunHook and a Callback that measure examples per second during training.
Examples of callbacks are checkpoint saving (tf.keras.callbacks.ModelCheckpoint) and TensorBoard summary writing. Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.
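For instance, the two built-in callbacks named above, ModelCheckpoint and TensorBoard, are attached to Model.fit through its callbacks parameter. The sketch below is illustrative, not part of the original guide: the toy data, the two-epoch run, and the temporary output directory are assumptions. The file name ends in .weights.h5 because recent Keras versions require that suffix when save_weights_only=True.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy regression data (illustrative values, not from the guide).
x = np.array([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]], dtype="float32")
y = np.array([[0.3], [0.5], [0.7]], dtype="float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adagrad", loss="mse")

# Assumed output locations in a throwaway temporary directory.
work_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(work_dir, "model.weights.h5")
log_dir = os.path.join(work_dir, "logs")

callbacks = [
    # Saves the model weights at the end of every epoch.
    tf.keras.callbacks.ModelCheckpoint(ckpt_path, save_weights_only=True),
    # Writes per-epoch loss summaries that TensorBoard can display.
    tf.keras.callbacks.TensorBoard(log_dir=log_dir),
]

history = model.fit(x, y, epochs=2, callbacks=callbacks, verbose=0)
```

Keras invokes each callback's hook methods at the appropriate points in the loop, so neither built-in requires changes to the training code itself.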
Setup
Start with the imports and a simple dataset for demonstration purposes.
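A minimal setup sketch, assuming the small feature/label values below (they are illustrative and not taken from this text). It imports TensorFlow along with the tf.compat.v1 module that TensorFlow 1 code uses, and batches the toy examples into a tf.data.Dataset with one example per batch.

```python
import tensorflow as tf
import tensorflow.compat.v1 as tf1  # used by TensorFlow 1 (SessionRunHook) code

# Hypothetical toy regression data: two features and one label per example.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]

# One example per batch, repeated so training runs for a few hundred steps.
train_dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1).repeat(100)
```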
TensorFlow 1: Create a custom SessionRunHook with tf.estimator APIs
In TensorFlow 1, you can set up a custom SessionRunHook that measures examples per second during training. After creating the hook (LoggerHook), pass it to the hooks parameter of tf.estimator.Estimator.train.
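A sketch of such a LoggerHook, assuming a fixed batch size (the batch_size and log_frequency constructor parameters and the history attribute are hypothetical additions for illustration). It subclasses tf.compat.v1.train.SessionRunHook and overrides begin, before_run, and after_run; examples per second is computed as steps since the last log line × batch_size ÷ elapsed seconds.

```python
import time

import tensorflow.compat.v1 as tf1


class LoggerHook(tf1.train.SessionRunHook):
    """Logs examples per second every `log_frequency` training steps."""

    def __init__(self, batch_size=1, log_frequency=10):
        self.batch_size = batch_size        # examples per step (assumed fixed)
        self.log_frequency = log_frequency  # steps between log lines
        self.history = []                   # (step, examples_per_sec) pairs

    def begin(self):
        # Called once, before the training session is created.
        self._step = 0
        self._steps_since_log = 0
        self._start_time = time.time()

    def before_run(self, run_context):
        # Called before each session.run() call; nothing extra to fetch.
        self._step += 1
        self._steps_since_log += 1
        return None

    def after_run(self, run_context, run_values):
        # Called after each session.run() call.
        if self._steps_since_log >= self.log_frequency:
            now = time.time()
            elapsed = max(now - self._start_time, 1e-9)  # avoid division by zero
            rate = self._steps_since_log * self.batch_size / elapsed
            self.history.append((self._step, rate))
            print(f"Step {self._step}: {rate:.1f} examples/sec")
            self._start_time = now
            self._steps_since_log = 0


# Usage with a TF1 Estimator (`estimator` and `input_fn` are assumed to exist):
# estimator.train(input_fn, hooks=[LoggerHook(batch_size=1)])
```

The Estimator call is left commented out because tf.estimator is no longer shipped with recent TensorFlow releases; the hook class itself only depends on tf.compat.v1.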
TensorFlow 2: Create a custom Keras callback for Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure a custom tf.keras.callbacks.Callback, which you then pass to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)
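As a minimal sketch of the mechanism described above, the hypothetical EvalBatchCounter callback overrides on_test_begin and on_test_batch_end and is passed to Model.evaluate through the same callbacks parameter that Model.fit accepts. The toy model and data are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf


class EvalBatchCounter(tf.keras.callbacks.Callback):
    """Hypothetical callback that counts the batches seen during evaluation."""

    def on_test_begin(self, logs=None):
        # Called once at the start of Model.evaluate.
        self.batches_seen = 0

    def on_test_batch_end(self, batch, logs=None):
        # Called after each evaluation batch.
        self.batches_seen += 1


# Toy evaluation data (illustrative values).
x = np.array([[4.0, 4.5], [5.0, 5.5], [6.0, 6.5]], dtype="float32")
y = np.array([[0.8], [0.9], [1.0]], dtype="float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adagrad", loss="mse")

counter = EvalBatchCounter()
loss = model.evaluate(x, y, batch_size=1, callbacks=[counter], verbose=0)
```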
Next, write a custom tf.keras.callbacks.Callback that logs training metrics. Here it measures examples per second, which should be comparable to the metric logged by the previous SessionRunHook example.
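A sketch of that callback, assuming a fixed batch size and toy data (the class name ExamplesPerSecondCallback and its batch_size, log_frequency, and history members are hypothetical). Each SessionRunHook method maps onto a Keras callback method: begin becomes on_train_begin, before_run becomes on_train_batch_begin, and after_run becomes on_train_batch_end.

```python
import time

import numpy as np
import tensorflow as tf


class ExamplesPerSecondCallback(tf.keras.callbacks.Callback):
    """Logs examples per second every `log_frequency` training batches."""

    def __init__(self, batch_size, log_frequency=10):
        super().__init__()
        self.batch_size = batch_size        # examples per batch (assumed fixed)
        self.log_frequency = log_frequency  # batches between log lines
        self.history = []                   # (step, examples_per_sec) pairs

    def on_train_begin(self, logs=None):
        # Counterpart of SessionRunHook.begin.
        self._step = 0
        self._steps_since_log = 0
        self._start_time = time.time()

    def on_train_batch_begin(self, batch, logs=None):
        # Counterpart of SessionRunHook.before_run.
        self._step += 1
        self._steps_since_log += 1

    def on_train_batch_end(self, batch, logs=None):
        # Counterpart of SessionRunHook.after_run.
        if self._steps_since_log >= self.log_frequency:
            now = time.time()
            elapsed = max(now - self._start_time, 1e-9)  # avoid division by zero
            rate = self._steps_since_log * self.batch_size / elapsed
            self.history.append((self._step, rate))
            print(f"Step {self._step}: {rate:.1f} examples/sec")
            self._start_time = now
            self._steps_since_log = 0


# Toy data: 3 examples, batch size 1, repeated 10 times -> 30 training steps.
features = np.array([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]], dtype="float32")
labels = np.array([[0.3], [0.5], [0.7]], dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1).repeat(10)

model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.05), loss="mse")

callback = ExamplesPerSecondCallback(batch_size=1, log_frequency=10)
history = model.fit(dataset, callbacks=[callback], verbose=0)
```

Unlike the hook, the callback receives the batch index within the current epoch, so this sketch keeps its own running step counter to stay comparable with the SessionRunHook version across epochs.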
Next steps
Learn more about callbacks in:
API docs: tf.keras.callbacks.Callback
Guide: Writing your own callbacks
Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
The LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide