Path: blob/master/site/zh-cn/guide/keras/custom_callback.ipynb
25118 views
Copyright 2020 The TensorFlow Authors.
编写自己的回调函数
简介
回调是一种可以在训练、评估或推断过程中自定义 Keras 模型行为的强大工具。示例包括使用 TensorBoard 来呈现训练进度和结果的 tf.keras.callbacks.TensorBoard
,以及用来在训练期间定期保存模型的 tf.keras.callbacks.ModelCheckpoint
。
在本指南中,您将了解什么是 Keras 回调函数,它可以做什么,以及如何构建自己的回调函数。我们提供了一些简单回调函数应用的演示,以帮助您入门。
设置
Keras 回调函数概述
所有回调函数都将 keras.callbacks.Callback
类作为子类,并重写在训练、测试和预测的各个阶段调用的一组方法。回调函数对于在训练期间了解模型的内部状态和统计信息十分有用。
您可以将回调函数的列表(作为关键字参数 callbacks
)传递给以下模型方法:
keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()
回调函数方法概述
全局方法
on_(train|test|predict)_begin(self, logs=None)
在 fit
/evaluate
/predict
开始时调用。
on_(train|test|predict)_end(self, logs=None)
在 fit
/evaluate
/predict
结束时调用。
Batch-level methods for training/testing/predicting
on_(train|test|predict)_batch_begin(self, batch, logs=None)
正好在训练/测试/预测期间处理批次之前调用。
on_(train|test|predict)_batch_end(self, batch, logs=None)
在训练/测试/预测批次结束时调用。在此方法中,logs
是包含指标结果的字典。
周期级方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
在训练期间周期开始时调用。
on_epoch_end(self, epoch, logs=None)
在训练期间周期开始时调用。
基本示例
让我们来看一个具体的例子。首先,导入 Tensorflow 并定义一个简单的序列式 Keras 模型:
然后,从 Keras 数据集 API 加载 MNIST 数据进行训练和测试:
接下来,定义一个简单的自定义回调函数来记录以下内容:
fit
/evaluate
/predict
开始和结束的时间每个周期开始和结束的时间
每个训练批次开始和结束的时间
每个评估(测试)批次开始和结束的时间
每次推断(预测)批次开始和结束的时间
我们来试一下:
logs
字典的用法
logs
字典包含损失值,以及批次或周期结束时的所有指标。示例包括损失和平均绝对误差。
self.model
属性的用法
除了在调用其中一种方法时接收日志信息外,回调还可以访问与当前一轮训练/评估/推断有关的模型:self.model
。
以下是您可以在回调函数中使用 self.model
进行的一些操作:
设置
self.model.stop_training = True
以立即中断训练。转变优化器(可作为
self.model.optimizer
)的超参数,例如self.model.optimizer.learning_rate
。定期保存模型。
在每个周期结束时,在少量测试样本上记录
model.predict()
的输出,以用作训练期间的健全性检查。在每个周期结束时提取中间特征的可视化,随时间推移监视模型当前的学习内容。
其他
下面我们通过几个示例来看看它是如何工作的。
Keras 回调函数应用示例
在达到最小损失时尽早停止
第一个示例展示了如何通过设置 self.model.stop_training
(布尔)属性来创建能够在达到最小损失时停止训练的 Callback
。您还可以提供参数 patience
来指定在达到局部最小值后应该等待多少个周期然后停止。
tf.keras.callbacks.EarlyStopping
提供了一种更完整、更通用的实现。
学习率规划
在此示例中,我们展示了如何在学习过程中使用自定义回调来动态更改优化器的学习率。
有关更通用的实现,请查看 callbacks.LearningRateScheduler
。
内置 Keras 回调函数
请务必阅读 API 文档查看现有的 Keras 回调函数。应用包括记录到 CSV、保存模型、在 TensorBoard 中可视化指标等等!