GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/en-snapshot/guide/migrate/fault_tolerance.ipynb
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Fault tolerance refers to a mechanism of periodically saving the states of trackable objects, such as parameters and models. This enables you to recover them in the event of a program/machine failure during training.

This guide first demonstrates how to add fault tolerance to training with tf.estimator.Estimator in TensorFlow 1 by configuring checkpoint saving with tf.estimator.RunConfig. Then, you will learn how to implement fault tolerance for training in TensorFlow 2 in two ways:

  • If you use the Keras Model.fit API, you can pass the tf.keras.callbacks.BackupAndRestore callback to it.

  • If you use a custom training loop (with tf.GradientTape), you can save checkpoints at any point you choose with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

Both of these methods will back up and restore the training states in checkpoint files.
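Both approaches follow the same pattern once the framework details are stripped away:

- write the training state to disk every so often;
- at startup, load the most recent saved state if one exists.

The plain-Python sketch below illustrates that pattern under simplifying assumptions. The "training state" is just a step counter plus a running sum standing in for model parameters, and the checkpoint is a JSON file. The `train` function and its arguments are hypothetical and are not part of any TensorFlow API.

```python
import json
import os
import tempfile


def train(state_path, total_steps, save_every, fail_at=None):
    """Toy training loop that checkpoints its state to a JSON file.

    The state is a step counter plus a running sum that stands in for model
    parameters. If `fail_at` is given, a RuntimeError simulates a crash right
    after that step runs and before it can be checkpointed.
    """
    # Resume from the last checkpoint if one exists; otherwise start fresh.
    if os.path.exists(state_path):
        with open(state_path) as f:
            state = json.load(f)
    else:
        state = {"step": 0, "param": 0}

    while state["step"] < total_steps:
        state["step"] += 1
        state["param"] += state["step"]  # Stand-in for a parameter update.
        if fail_at is not None and state["step"] == fail_at:
            raise RuntimeError("Interruption")
        if state["step"] % save_every == 0:
            # Write to a temporary file and then rename it, so that (on
            # filesystems with atomic renames) a crash mid-write cannot
            # leave a half-written checkpoint behind.
            tmp_path = state_path + ".tmp"
            with open(tmp_path, "w") as f:
                json.dump(state, f)
            os.replace(tmp_path, state_path)
    return state


state_path = os.path.join(tempfile.mkdtemp(), "state.json")
try:
    train(state_path, total_steps=10, save_every=3, fail_at=8)
except RuntimeError as e:
    print(f"{type(e).__name__}:{e}")  # RuntimeError:Interruption

with open(state_path) as f:
    print("Last checkpoint:", json.load(f))  # {'step': 6, 'param': 21}

final_state = train(state_path, total_steps=10, save_every=3)
print("Final state:", final_state)  # {'step': 10, 'param': 55}
```

The restarted run reaches the same final state as an uninterrupted one. However, steps 7 and 8 are computed twice, which is the price of saving only every few steps. More frequent saves lose less work after a failure but spend more time writing to disk.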

Setup

Install tf-nightly, because the save_freq argument of tf.keras.callbacks.BackupAndRestore, which lets you save checkpoints at a particular step frequency, was introduced in TensorFlow 2.10:

!pip install tf-nightly
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: Save checkpoints with tf.estimator.RunConfig

In TensorFlow 1, you can make a tf.estimator.Estimator save a checkpoint at every step by setting save_checkpoints_steps=1 in its tf.estimator.RunConfig.

In this example, start by writing a hook that artificially throws an error partway through training, once its zero-based step counter reaches 5 (that is, after the sixth training step):

class InterruptHook(tf1.train.SessionRunHook):
  # A hook for artificially interrupting training.
  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step == 5:
      raise RuntimeError('Interruption')

Next, configure tf.estimator.Estimator to save a checkpoint at every step and to train on the MNIST dataset:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
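The arithmetic below shows how these settings interact. It uses plain Python. The variable names mirror the arguments above, and 60,000 is the size of the MNIST training split.

It computes two things:

- how many batches train_input_fn could produce;
- how many of them the train call with max_steps=10 uses at most.

Because save_checkpoints_steps=1, each step that does run is followed by a checkpoint.

```python
import math

num_train_examples = 60_000  # Size of the MNIST training split.
batch_size = 50
num_epochs = 10
max_steps = 10  # Value passed to classifier.train.

steps_per_epoch = math.ceil(num_train_examples / batch_size)
steps_available = steps_per_epoch * num_epochs
steps_to_run = min(max_steps, steps_available)

print(steps_per_epoch, steps_available, steps_to_run)  # 1200 12000 10
```

The input function could feed 12,000 steps, so in this example training is bounded by max_steps rather than by num_epochs.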

Begin training the model. An artificial exception will be raised by the hook you defined earlier.

try:
  classifier.train(input_fn=train_input_fn,
                   hooks=[InterruptHook()],
                   max_steps=10)
except Exception as e:
  print(f'{type(e).__name__}:{e}')

Rebuild the tf.estimator.Estimator using the last saved checkpoint and continue training:

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)
classifier.train(input_fn=train_input_fn,
                 max_steps=10)
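The second train call can reuse max_steps=10 because, in Estimator.train, max_steps is an absolute total. Training continues from the global step stored in the latest checkpoint in model_dir and stops once that step reaches max_steps. By contrast, steps counts additional steps from wherever training resumed.

The helper below is a hypothetical plain-Python model of that rule. It ignores the case where input_fn runs out of data first. The restored global step of 6 is an illustrative value, not the output of the cells above.

```python
def remaining_train_steps(restored_global_step, steps=None, max_steps=None):
    """Hypothetical model of how many steps an Estimator.train call runs.

    `steps` counts additional steps from the restored global step, while
    `max_steps` is an absolute total. Returns None when neither is given,
    meaning training would run until the input function is exhausted.
    """
    if steps is not None and max_steps is not None:
        raise ValueError("Pass at most one of steps and max_steps.")
    if steps is not None:
        return steps
    if max_steps is not None:
        return max(0, max_steps - restored_global_step)
    return None


restored_global_step = 6  # Illustrative value only.
print(remaining_train_steps(restored_global_step, max_steps=10))  # 4
print(remaining_train_steps(restored_global_step, steps=10))      # 10
print(remaining_train_steps(10, max_steps=10))                    # 0
```

With max_steps, rerunning the same call after a failure only finishes the remaining work. With steps, every rerun would add a fresh batch of steps on top of the restored ones.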

TensorFlow 2: Back up and restore with a callback and Model.fit

In TensorFlow 2, if you use the Keras Model.fit API for training, you can provide the tf.keras.callbacks.BackupAndRestore callback to add the fault tolerance functionality.

To demonstrate this, start by defining a Keras Callback class that artificially throws an error at the end of the fourth epoch (zero-based epoch index 3):

class InterruptAtEpoch(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_epoch=3):
    super().__init__()
    self.interrupting_epoch = interrupting_epoch

  def on_epoch_end(self, epoch, logs=None):
    if epoch == self.interrupting_epoch:
      raise RuntimeError('Interruption')

Then, define and instantiate a simple Keras model, define the loss function, call Model.compile, and set up a tf.keras.callbacks.BackupAndRestore callback that will save the checkpoints in a temporary directory at epoch boundaries:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])

log_dir = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir)

Start training the model with Model.fit. During training, the tf.keras.callbacks.BackupAndRestore callback instantiated above saves checkpoints. Meanwhile, the InterruptAtEpoch callback raises an artificial exception to simulate a failure after the fourth epoch.

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtEpoch()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
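The order of the callbacks in that list matters, because Keras calls callbacks in the order they are listed. At the end of epoch index 3, backup_restore_callback writes its backup before InterruptAtEpoch raises. A restarted run that uses the same backup directory should therefore resume at epoch index 4, which the progress output (counting epochs from 1) shows as "Epoch 5/10".

The sketch below models this with hypothetical plain-Python stand-ins. `Backup`, `Interrupt`, and `fit` are not Keras classes or functions, and the sketch assumes BackupAndRestore's default epoch-level saving.

```python
class Backup:
    """Hypothetical stand-in for BackupAndRestore saving once per epoch."""

    def __init__(self):
        self.saved_epoch = None  # Last fully completed epoch (0-based).

    def on_epoch_end(self, epoch):
        self.saved_epoch = epoch


class Interrupt:
    """Hypothetical stand-in for InterruptAtEpoch."""

    def __init__(self, interrupting_epoch=3):
        self.interrupting_epoch = interrupting_epoch

    def on_epoch_end(self, epoch):
        if epoch == self.interrupting_epoch:
            raise RuntimeError("Interruption")


def fit(epochs, callbacks, initial_epoch=0):
    """Hypothetical training driver that only dispatches epoch-end hooks."""
    for epoch in range(initial_epoch, epochs):
        # (One epoch of training would happen here.)
        for cb in callbacks:  # Dispatched in list order.
            cb.on_epoch_end(epoch)


backup = Backup()
try:
    fit(10, [backup, Interrupt()])
except RuntimeError as e:
    print(f"{type(e).__name__}:{e}")  # RuntimeError:Interruption

resume_epoch = backup.saved_epoch + 1
print("Resume at 0-based epoch", resume_epoch)  # Resume at 0-based epoch 4
```

Swapping the two callbacks would raise the exception before the epoch-3 backup is written, so the restarted run would repeat that epoch.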

Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from a previously saved checkpoint:

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])

Define another Callback class that artificially throws an error during the 140th step:

class InterruptAtStep(tf.keras.callbacks.Callback):
  # A callback for artificially interrupting training.
  def __init__(self, interrupting_step=140):
    super().__init__()
    self.total_step_count = 0
    self.interrupting_step = interrupting_step

  def on_batch_begin(self, batch, logs=None):
    # Count steps across epochs; `batch` restarts at 0 every epoch.
    self.total_step_count += 1

  def on_batch_end(self, batch, logs=None):
    if self.total_step_count == self.interrupting_step:
      print("\nInterrupting at step count", self.total_step_count)
      raise RuntimeError('Interruption')

Note: This section uses features that are only available in tf-nightly until TensorFlow 2.10 is released.

To save checkpoints every 30 steps, set save_freq in the BackupAndRestore callback to 30. The InterruptAtStep callback raises an artificial exception to simulate a failure at epoch 1, step 40 (total step count 140), counting epochs from 0. The last checkpoint is therefore saved at epoch 1, step 20 (total step count 120).

log_dir_2 = tempfile.mkdtemp()

backup_restore_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=log_dir_2,
    save_freq=30
)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'])

try:
  model.fit(x=x_train,
            y=y_train,
            epochs=10,
            steps_per_epoch=100,
            validation_data=(x_test, y_test),
            callbacks=[backup_restore_callback, InterruptAtStep()])
except Exception as e:
  print(f'{type(e).__name__}:{e}')
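The step numbers for this run follow from simple arithmetic. InterruptAtStep counts steps across epochs, while Keras restarts its per-epoch batch index at 0 every epoch. Converting between the two therefore needs steps_per_epoch.

The sketch below is a simplified plain-Python model. It assumes a backup is written after every save_freq batches, counted across epoch boundaries. The helper `to_epoch_and_step` is hypothetical.

```python
steps_per_epoch = 100
save_freq = 30
interrupting_step = 140  # 1-based total step count used by InterruptAtStep.

# Backups written at or before the interrupting step. backup_restore_callback
# is listed before InterruptAtStep, so a save due at that step would still run.
save_points = list(range(save_freq, interrupting_step + 1, save_freq))
last_save = save_points[-1]


def to_epoch_and_step(total_step):
    """Split a 1-based total step count into (0-based epoch, 1-based step)."""
    epoch, step_index = divmod(total_step - 1, steps_per_epoch)
    return epoch, step_index + 1


print(save_points)                           # [30, 60, 90, 120]
print(to_epoch_and_step(interrupting_step))  # (1, 40)
print(to_epoch_and_step(last_save))          # (1, 20)
print(to_epoch_and_step(last_save + 1))      # (1, 21)
```

Zero-based epoch 1 appears as "Epoch 2/10" in the progress output. The 20 steps from 121 through 140 are the work that has to be redone after the restart.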

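Next, instantiate the Keras model, call Model.compile, and continue training the model with Model.fit from the previously saved checkpoint. Notice that training resumes from epoch 2 and step 21. The progress output counts epochs from 1, so this is the same epoch in which the interruption occurred.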

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          steps_per_epoch=100,
          validation_data=(x_test, y_test),
          callbacks=[backup_restore_callback])

TensorFlow 2: Write manual checkpoints with a custom training loop

If you use a custom training loop in TensorFlow 2, you can implement a fault tolerance mechanism with the tf.train.Checkpoint and tf.train.CheckpointManager APIs.

This example demonstrates how to:

  • Use a tf.train.Checkpoint object to manually create a checkpoint, where the trackable objects you want to save are set as attributes.

  • Use a tf.train.CheckpointManager to manage multiple checkpoints.

Start by defining and instantiating the Keras model, the optimizer, and the loss function. Then, create a Checkpoint that manages two objects with trackable states (the model and the optimizer), as well as a CheckpointManager for logging and keeping several checkpoints in a temporary directory.

model = create_model()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

log_dir = tempfile.mkdtemp()
epochs = 5
steps_per_epoch = 5

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, log_dir, max_to_keep=2)
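The training loop in this section calls checkpoint_manager.save() after every step. Five epochs of five steps therefore produce 25 saves, but max_to_keep=2 keeps only the two most recent.

CheckpointManager names each checkpoint with a prefix (ckpt by default) followed by the checkpoint's save counter. The plain-Python model below shows which checkpoints would remain in log_dir at the end. It stands in for the manager's bookkeeping and does not call TensorFlow.

```python
from collections import deque

epochs = 5
steps_per_epoch = 5
max_to_keep = 2

# Simplified model of CheckpointManager bookkeeping: each save gets the next
# number, and the oldest checkpoint is dropped once more than max_to_keep exist.
kept = deque()
save_counter = 0
for _ in range(epochs * steps_per_epoch):
    save_counter += 1
    kept.append(f"ckpt-{save_counter}")
    if len(kept) > max_to_keep:
        kept.popleft()  # The real manager deletes this checkpoint's files.

print(list(kept))  # ['ckpt-24', 'ckpt-25']
print(kept[-1])    # ckpt-25
```

The newest entry is the one that checkpoint_manager.latest_checkpoint points to (as a path inside log_dir). A small max_to_keep bounds disk usage while still leaving an older checkpoint to fall back on.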

Now, implement a custom training loop that saves a checkpoint after every training step. At the start of every epoch after the first, the loop restores the model and optimizer from the latest checkpoint:

for epoch in range(epochs):
  if epoch > 0:
    # Restore the model and optimizer from the most recent checkpoint.
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
  print(f"\nStart of epoch {epoch}")

  for step in range(steps_per_epoch):
    with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    save_path = checkpoint_manager.save()
    print(f"Checkpoint saved to {save_path}")
    print(f"Training loss at step {step}: {loss_value}")

Next steps

To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:

  • The tf.keras.callbacks.BackupAndRestore callback API docs.

  • The tf.train.Checkpoint and tf.train.CheckpointManager API docs.

  • The Training checkpoints guide, including the Writing checkpoints section.

You may also find the following material related to distributed training useful: