Path: blob/master/site/zh-cn/guide/migrate/fault_tolerance.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
容错是指定期保存参数和模型等可跟踪对象的状态的机制。这样,您便能够在训练期间出现程序/机器故障时恢复它们。
本指南首先演示了如何通过使用 tf.estimator.RunConfig
指定指标保存以在 TensorFlow 1 中使用 tf.estimator.Estimator
向训练添加容错。随后,您将学习如何通过以下两种方式在 Tensorflow 2 中实现容错训练:
如果您使用 Keras
Model.fit
API,则可以将tf.keras.callbacks.BackupAndRestore
回调传递给它。如果您使用自定义训练循环(使用
tf.GradientTape
),则可以使用tf.train.Checkpoint
和tf.train.CheckpointManager
API 任意保存检查点。
这两种方式都会备份和恢复检查点文件中的训练状态。
安装
安装 tf-nightly
,因为使用 tf.keras.callbacks.BackupAndRestore
中的 save_freq
参数设置特定步骤保存检查点的频率是从 TensorFlow 2.10 引入的:
TensorFlow 1:使用 tf.estimator.RunConfig 保存检查点
在 TensorFlow 1 中,可以配置 tf.estimator
,随后通过配置 tf.estimator.RunConfig
在每一步保存检查点。
在此示例中,首先编写一个在第五个检查点期间人为抛出错误的钩子:
接下来,配置 tf.estimator.Estimator
以保存每个检查点并使用 MNIST 数据集:
开始训练模型。您之前定义的钩子将引发人为异常。
使用最后保存的检查点重新构建 tf.estimator.Estimator
并继续训练:
TensorFlow 2:使用回调和 Model.fit 备份和恢复
在 TensorFlow 2 中,如果使用 Keras Model.fit
API 进行训练,则可以提供 tf.keras.callbacks.BackupAndRestore
回调来添加容错功能。
为了帮助演示这一点,首先定义一个 Keras Callback
类,该类会在第四个周期检查点期间人为抛出错误:
然后,定义并实例化一个简单的 Keras 模型,定义损失函数,调用 Model.compile
并设置一个 tf.keras.callbacks.BackupAndRestore
回调,它会将检查点保存在周期边界的临时目录中:
开始使用 Model.fit
训练模型。在训练期间,由于上面实例化的 tf.keras.callbacks.BackupAndRestore
将保存检查点,而 InterruptAtEpoch
类将引发人为异常来模拟第四个周期后的失败。
接下来,实例化 Keras 模型,调用 Model.compile
,并从之前保存的检查点继续使用 Model.fit
训练模型:
定义另一个 Callback
类,该类会在第 140 步期间人为抛出错误:
注:本部分使用了仅在 Tensorflow 2.10 发布后才能在 tf-nightly
中可用的功能。
要确保检查点每 30 个步骤保存一次,请将 BackupAndRestore
回调中的 save_freq
设置为 30
。InterruptAtStep
将引发一个人为的异常来模拟周期 1 和步骤 40 的失败(总步数为 140)。最后会在周期 1 和步骤 20 保存检查点。
接下来,实例化 Keras 模型,调用 Model.compile
,并从之前保存的检查点继续使用 Model.fit
训练模型。请注意,训练从周期 2 和步骤 21 开始。
TensorFlow 2:使用自定义训练循环编写手动检查点
如果您在 TensorFlow 2 中使用自定义训练循环,则可以使用 tf.train.Checkpoint
和 tf.train.CheckpointManager
API 实现容错机制。
此示例演示了如何执行以下操作:
使用
tf.train.Checkpoint
对象手动创建一个检查点,其中要保存的可跟踪对象设置为特性。使用
tf.train.CheckpointManager
管理多个检查点。
首先,定义和实例化 Keras 模型、优化器和损失函数。然后,创建一个 Checkpoint
来管理两个具有可跟踪状态的对象(模型和优化器),以及一个 CheckpointManager
来记录多个检查点并将它们保存在临时目录中。
现在,实现一个自定义训练循环,在第一个周期之后,每次新周期开始时都会加载最后一个检查点:
后续步骤
要详细了解 TensorFlow 2 中的容错和检查点,请查看以下文档:
tf.keras.callbacks.BackupAndRestore
回调 API 文档。tf.train.Checkpoint
和tf.train.CheckpointManager
API 文档。训练检查点指南,包括编写检查点部分。
此外,您可能还会发现下列与分布式训练相关的材料十分有用:
使用 Keras 进行多工作进程训练教程中的容错部分。
参数服务器训练教程中的处理任务失败部分。