Path: blob/master/site/zh-cn/guide/migrate/checkpoint_saver.ipynb
39135 views
Copyright 2021 The TensorFlow Authors.
持续保存“最佳”模型或模型权重/参数有许多好处,包括能够跟踪训练进度并从不同的保存状态加载保存的模型。
在 TensorFlow 1 中,要使用 tf.estimator.Estimator API 在训练/验证期间配置检查点保存,可以在 tf.estimator.RunConfig 中指定计划或使用 tf.estimator.CheckpointSaverHook。本指南演示了如何从该工作流迁移到 TensorFlow 2 Keras API。
在 TensorFlow 2 中,可以通过多种方式配置 tf.keras.callbacks.ModelCheckpoint:
根据使用
save_best_only=True参数监视的指标保存“最佳”版本,其中monitor可以是'loss'、'val_loss'、'accuracy'或 'val_accuracy'`。以特定频率持续保存(使用
save_freq参数)。通过将
save_weights_only设置为True,仅保存权重/参数而不是整个模型。
有关详情,请参阅 tf.keras.callbacks.ModelCheckpoint API 文档和保存和加载模型教程中的训练期间保存检查点部分。在保存和加载 Keras 模型指南中的 TF 检查点格式部分中详细了解检查点格式。另外,要添加容错,可以使用 tf.keras.callbacks.BackupAndRestore 或 tf.train.Checkpoint 手动设置检查点。在容错迁移指南中了解详情。
Keras 回调是在内置 Keras Model.fit/Model.evaluate/Model.predict API 中的训练/评估/预测期间的不同点调用的对象。请在指南末尾的后续步骤部分中了解详情。
安装
从导入和用于演示目的的简单数据集开始:
TensorFlow 1:使用 tf.estimator API 保存检查点
此 TensorFlow 1 示例展示了如何配置 tf.estimator.RunConfig 以在使用 tf.estimator.Estimator API 进行训练/评估期间的每一步保存检查点:
TensorFlow 2:使用 Model.fit 的 Keras 回调保存检查点
在 TensorFlow 2 中,使用内置 Keras Model.fit(或 Model.evaluate)进行训练/评估时,可以配置 tf.keras.callbacks.ModelCheckpoint,然后将其传递给 Model.fit(或 Model.evaluate)的 callbacks 参数。(请在 API 文档和使用内置方法进行训练和评估指南中的使用回调部分中了解详情。)
在下面的示例中,您将使用 tf.keras.callbacks.ModelCheckpoint 回调将检查点存储在临时目录中:
后续步骤
在以下资源中详细了解检查点:
API 文档:
tf.keras.callbacks.ModelCheckpoint教程:保存和加载模型(训练期间保存检查点部分)
指南:保存和加载 Keras 模型(TF 检查点格式部分)
以下资源中详细了解回调:
API 文档:
tf.keras.callbacks.Callback指南:编写自己的回调
指南:使用内置方法进行训练和评估(使用回调部分)
此外,您可能还会发现下列与迁移相关的资源十分有用:
容错迁移指南:用于
Model.fit的tf.keras.callbacks.BackupAndRestore,或用于自定义训练循环的tf.train.Checkpoint和tf.train.CheckpointManagerAPI提前停止迁移指南:
tf.keras.callbacks.EarlyStopping是一个内置的提前停止回调TensorBoard 迁移指南:TensorBoard 支持跟踪和显示指标
在 TensorFlow.org 上查看
在 Google Colab 运行
在 Github 上查看源代码
下载笔记本