Path: blob/master/site/zh-cn/guide/migrate/early_stopping.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
本笔记本演示了如何使用提前停止设置模型训练。首先,在 TensorFlow 1 中使用 tf.estimator.Estimator
和提前停止钩子,然后在 TensorFlow 2 中使用 Keras API 或自定义训练循环。 提前停止是一种正则化技术,可在验证损失达到特定阈值时停止训练。
在 TensorFlow 2 中,可以通过三种方式实现提前停止:
使用内置的 Keras 回调
tf.keras.callbacks.EarlyStopping
并将其传递给Model.fit
。定义自定义回调并将其传递给 Keras
Model.fit
。在自定义训练循环中编写自定义提前停止规则(使用
tf.GradientTape
)。
安装
TensorFlow 1:使用提前停止钩子和 tf.estimator 提前停止
首先,定义用于 MNIST 数据集加载和预处理的函数,以及与 tf.estimator.Estimator
一起使用的模型定义:
在 TensorFlow 1 中,提前停止的工作方式是使用 tf.estimator.experimental.make_early_stopping_hook
设置提前停止钩子。将钩子传递给 make_early_stopping_hook
方法作为 should_stop_fn
的参数,它可以接受不带任何参数的函数。一旦 should_stop_fn
返回 True
,训练就会停止。
下面的示例演示了如何实现将训练时间限制为最多 20 秒的提前停止技术:
TensorFlow 2:使用内置回调和 Model.fit 提前停止
准备 MNIST 数据集和一个简单的 Keras 模型:
在 TensorFlow 2 中,当您使用内置的 Keras Model.fit
(或 Model.evaluate
)时,可以通过将内置回调 tf.keras.callbacks.EarlyStopping
传递给 Model.fit
的 callbacks
参数来配置提前停止。
EarlyStopping
回调会监视用户指定的指标,并在停止改进时结束训练。(请查看使用内置方法进行训练和评估或 API 文档来了解详情。)
下面是一个提前停止回调的示例,它监视损失并在显示没有改进的周期数设置为 3
(patience
) 后停止训练:
TensorFlow 2:使用自定义回调和 Model.fit 提前停止
您也可以实现自定义的提前停止回调,此回调也可以传递给 Model.fit
(或 Model.evaluate
)的 callbacks
参数。
在此示例中,一旦 self.model.stop_training
设置为 True
,训练过程就会停止:
TensorFlow 2:使用自定义训练循环提前停止
在 TensorFlow 2 中,如果您不使用内置 Keras 方法进行训练和评估,则可以在自定义训练循环中实现提前停止。
首先,使用 Keras API 定义另一个简单的模型、优化器、损失函数和指标:
使用 tf.GradientTape 和 @tf.function
装饰器定义参数更新函数以加快速度:
接下来,编写一个自定义训练循环,可以在其中手动实现提前停止规则。
下面的示例显示了当验证损失在一定数量的周期内没有改进时如何停止训练:
后续步骤
在 API 文档中详细了解 Keras 内置提前停止回调 API。
了解如何编写自定义 Keras 回调,包括以最小损失提前停止。
在使用
EarlyStopping
回调的过拟合和欠拟合教程中探索常见的正则化技术。