Path: blob/master/site/zh-cn/guide/migrate/checkpoint_saver.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
持续保存“最佳”模型或模型权重/参数有许多好处,包括能够跟踪训练进度并从不同的保存状态加载保存的模型。
在 TensorFlow 1 中,要使用 tf.estimator.Estimator
API 在训练/验证期间配置检查点保存,可以在 tf.estimator.RunConfig
中指定计划或使用 tf.estimator.CheckpointSaverHook
。本指南演示了如何从该工作流迁移到 TensorFlow 2 Keras API。
在 TensorFlow 2 中,可以通过多种方式配置 tf.keras.callbacks.ModelCheckpoint
:
根据使用
save_best_only=True
参数监视的指标保存“最佳”版本,其中monitor
可以是'loss'
、'val_loss'
、'accuracy'
或 'val_accuracy'`。以特定频率持续保存(使用
save_freq
参数)。通过将
save_weights_only
设置为True
,仅保存权重/参数而不是整个模型。
有关详情,请参阅 tf.keras.callbacks.ModelCheckpoint
API 文档和保存和加载模型教程中的训练期间保存检查点部分。在保存和加载 Keras 模型指南中的 TF 检查点格式部分中详细了解检查点格式。另外,要添加容错,可以使用 tf.keras.callbacks.BackupAndRestore
或 tf.train.Checkpoint
手动设置检查点。在容错迁移指南中了解详情。
Keras 回调是在内置 Keras Model.fit
/Model.evaluate
/Model.predict
API 中的训练/评估/预测期间的不同点调用的对象。请在指南末尾的后续步骤部分中了解详情。
安装
从导入和用于演示目的的简单数据集开始:
TensorFlow 1:使用 tf.estimator API 保存检查点
此 TensorFlow 1 示例展示了如何配置 tf.estimator.RunConfig
以在使用 tf.estimator.Estimator
API 进行训练/评估期间的每一步保存检查点:
TensorFlow 2:使用 Model.fit 的 Keras 回调保存检查点
在 TensorFlow 2 中,使用内置 Keras Model.fit
(或 Model.evaluate
)进行训练/评估时,可以配置 tf.keras.callbacks.ModelCheckpoint
,然后将其传递给 Model.fit
(或 Model.evaluate
)的 callbacks
参数。(请在 API 文档和使用内置方法进行训练和评估指南中的使用回调部分中了解详情。)
在下面的示例中,您将使用 tf.keras.callbacks.ModelCheckpoint
回调将检查点存储在临时目录中:
后续步骤
在以下资源中详细了解检查点:
API 文档:
tf.keras.callbacks.ModelCheckpoint
教程:保存和加载模型(训练期间保存检查点部分)
指南:保存和加载 Keras 模型(TF 检查点格式部分)
以下资源中详细了解回调:
API 文档:
tf.keras.callbacks.Callback
指南:编写自己的回调
指南:使用内置方法进行训练和评估(使用回调部分)
此外,您可能还会发现下列与迁移相关的资源十分有用:
容错迁移指南:用于
Model.fit
的tf.keras.callbacks.BackupAndRestore
,或用于自定义训练循环的tf.train.Checkpoint
和tf.train.CheckpointManager
API提前停止迁移指南:
tf.keras.callbacks.EarlyStopping
是一个内置的提前停止回调TensorBoard 迁移指南:TensorBoard 支持跟踪和显示指标