GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/en-snapshot/guide/migrate/sessionrunhook_callback.ipynb
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In TensorFlow 1, to customize the behavior of training, you use tf.estimator.SessionRunHook with tf.estimator.Estimator. This guide demonstrates how to migrate from SessionRunHook to TensorFlow 2's custom callbacks with the tf.keras.callbacks.Callback API, which works with Keras Model.fit for training (as well as Model.evaluate and Model.predict). You will learn how to do this by implementing the same task, measuring examples per second during training, first as a SessionRunHook and then as a Callback.

Examples of callbacks are checkpoint saving (tf.keras.callbacks.ModelCheckpoint) and TensorBoard summary writing. Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.
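To make "called at different points during training" concrete, here is a minimal sketch that records the order in which Model.fit invokes a few callback methods. The class name HookOrderRecorder, the tiny arrays, and the choice of the "sgd" optimizer are assumptions made for illustration; only the tf.keras.callbacks.Callback method names come from the Keras API.

```python
import numpy as np
import tensorflow as tf


class HookOrderRecorder(tf.keras.callbacks.Callback):
  """Records which callback methods fire, in order (illustrative helper)."""

  def __init__(self):
    super().__init__()
    self.events = []

  def on_train_begin(self, logs=None):
    self.events.append("train_begin")

  def on_epoch_begin(self, epoch, logs=None):
    self.events.append("epoch_begin")

  def on_train_batch_end(self, batch, logs=None):
    self.events.append("train_batch_end")

  def on_epoch_end(self, epoch, logs=None):
    self.events.append("epoch_end")

  def on_train_end(self, logs=None):
    self.events.append("train_end")


# Three examples with batch_size=1 give three training batches per epoch.
x = np.array([[1., 1.5], [2., 2.5], [3., 3.5]], dtype="float32")
y = np.array([[0.3], [0.5], [0.7]], dtype="float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = HookOrderRecorder()
model.fit(x, y, batch_size=1, epochs=1, verbose=0, callbacks=[recorder])
print(recorder.events)
```

Model.evaluate and Model.predict trigger the corresponding on_test_* and on_predict_* methods instead.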

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow as tf
import tensorflow.compat.v1 as tf1

import time
from datetime import datetime
from absl import flags

features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow 1: Create a custom SessionRunHook with tf.estimator APIs

The following TensorFlow 1 examples show how to set up a custom SessionRunHook that measures examples per second during training. After creating the hook (LoggerHook), pass it to the hooks parameter of tf.estimator.Estimator.train.

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (features, labels)).batch(1).repeat(100)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
class LoggerHook(tf1.train.SessionRunHook):
  """Logs the time, step number, and examples per second."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      # The batch size is 1, so steps per second equals examples per second.
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
estimator.train(_input_fn, hooks=[LoggerHook()])

TensorFlow 2: Create a custom Keras callback for Model.fit

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure a custom tf.keras.callbacks.Callback, which you then pass to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)
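As a small sketch of the Model.evaluate case mentioned above, the following passes a callback through the callbacks parameter and counts evaluation batches with on_test_batch_end. The BatchCounter class, the inline arrays (which mirror the evaluation data from the setup), and the "sgd" optimizer are assumptions made for illustration, not part of the original guide.

```python
import numpy as np
import tensorflow as tf


class BatchCounter(tf.keras.callbacks.Callback):
  """Counts evaluation batches (hypothetical helper for illustration)."""

  def on_test_begin(self, logs=None):
    self.batches = 0

  def on_test_batch_end(self, batch, logs=None):
    self.batches += 1


eval_x = np.array([[4., 4.5], [5., 5.5], [6., 6.5]], dtype="float32")
eval_y = np.array([[0.8], [0.9], [1.]], dtype="float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

counter = BatchCounter()
# During evaluation Keras calls the on_test_* methods, not on_train_*.
model.evaluate(eval_x, eval_y, batch_size=1, verbose=0, callbacks=[counter])
print("Evaluation batches:", counter.batches)
```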

In the example below, you will write a custom tf.keras.callbacks.Callback that measures examples per second during training. Its output should be comparable to that of the previous SessionRunHook example.

class CustomCallback(tf.keras.callbacks.Callback):

  def on_train_begin(self, logs = None):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def on_train_batch_begin(self, batch, logs = None):
    self._step += 1

  def on_train_batch_end(self, batch, logs = None):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      # The batch size is 1, so steps per second equals examples per second.
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

callback = CustomCallback()

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(1).repeat(100)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, "mse")

# Begin training.
result = model.fit(dataset, callbacks=[callback], verbose = 0)
# Provide the results of training metrics.
result.history
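Both examples above divide log_frequency by the elapsed time, which gives steps per second. That matches examples per second only because the datasets use a batch size of 1. The very first log at step 0 also covers a single step rather than a full window of ten. The following framework-independent sketch shows one way to account for both. The ThroughputMeter class, its parameters, and the injectable clock are assumptions introduced here for illustration and are not part of TensorFlow or the original guide.

```python
import time


class ThroughputMeter:
  """Reports examples per second over windows of `log_frequency` steps."""

  def __init__(self, batch_size, log_frequency=10, clock=time.perf_counter):
    self.batch_size = batch_size
    self.log_frequency = log_frequency
    self._clock = clock  # Injectable so the meter can be tested with a fake clock.
    self._steps_in_window = 0
    self._window_start = clock()

  def step(self):
    """Records one finished step.

    Returns examples per second when a full window closes, otherwise None.
    """
    self._steps_in_window += 1
    if self._steps_in_window < self.log_frequency:
      return None
    now = self._clock()
    duration = now - self._window_start
    examples = self._steps_in_window * self.batch_size
    # Guard against a zero-length window on very coarse clocks.
    rate = examples / duration if duration > 0 else float("inf")
    self._steps_in_window = 0
    self._window_start = now
    return rate


# Usage with a fake clock: each window of 5 steps takes 2 "seconds".
ticks = iter([0.0, 2.0, 4.0])
meter = ThroughputMeter(batch_size=4, log_frequency=5, clock=lambda: next(ticks))
rates = [meter.step() for _ in range(10)]
print([r for r in rates if r is not None])
```

In a Keras callback, one option would be to create the meter in on_train_begin and call meter.step() from on_train_batch_end, printing the value whenever it is not None.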

Next steps

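Learn more about callbacks in the tf.keras.callbacks.Callback API docs, and in the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.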

You may also find the following migration-related resources useful: