Path: blob/master/site/zh-cn/guide/checkpoint.ipynb
25115 views
Copyright 2018 The TensorFlow Authors.
训练检查点
“保存 TensorFlow 模型”这一短语通常表示保存以下两种元素之一:
检查点,或
SavedModel。
检查点可以捕获模型使用的所有参数(tf.Variable
对象)的确切值。检查点不包含对模型所定义计算的任何描述,因此通常仅在将使用保存参数值的源代码可用时才有用。
另一方面,除了参数值(检查点)之外,SavedModel 格式还包括对模型所定义计算的序列化描述。这种格式的模型独立于创建模型的源代码。因此,它们适合通过 TensorFlow Serving、TensorFlow Lite、TensorFlow.js 或者使用其他编程语言(C、C++、Java、Go、Rust、C# 等 TensorFlow API)编写的程序进行部署。
本文介绍用于编写和读取检查点的 API。
设置
编写检查点
TensorFlow 模型的持久状态存储在 tf.Variable
对象中。这些对象可以直接构造,但通常会通过像 tf.keras.layers
或 tf.keras.Model
这样的高级 API 创建。
管理变量的最简单方法是将它们附加到 Python 对象,然后引用这些对象。
tf.train.Checkpoint
、tf.keras.layers.Layer
和 tf.keras.Model
的子类会自动跟踪分配给其特性的变量。下面的示例构造了一个简单的线性模型,然后编写检查点,其中包含该模型所有变量的值。
您可以使用 Model.save_weights
轻松保存模型检查点。
手动创建检查点
设置
为了帮助演示 tf.train.Checkpoint
的所有功能, 下面定义了一个玩具 (toy) 数据集和优化步骤:
创建检查点对象
使用 tf.train.Checkpoint
对象手动创建一个检查点,其中要检查的对象设置为对象的特性。
tf.train.CheckpointManager
也有助于管理多个检查点。
训练模型并为模型设置检查点
以下训练循环可创建模型和优化器的实例,然后将它们收集到 tf.train.Checkpoint
对象中。它在每批数据上循环调用训练步骤,并定期将检查点写入磁盘。
恢复和继续训练
在第一个训练周期结束后,您可以传递一个新的模型和管理器,但在您中断的地方继续训练:
tf.train.CheckpointManager
对象会删除旧的检查点。上面配置为仅保留最近的三个检查点。
这些路径(如 './tf_ckpts/ckpt-10'
)不是磁盘上的文件,而是一个 index
文件和一个或多个包含变量值的数据文件的前缀。这些前缀被分组到一个单独的 checkpoint
文件 ('./tf_ckpts/checkpoint'
) 中,其中 CheckpointManager
保存其状态。
加载机制
TensorFlow 通过从加载的对象开始遍历带命名边的有向计算图来将变量与检查点值匹配。边名称通常来自对象中的特性名称,例如 self.l1 = tf.keras.layers.Dense(5)
中的 "l1"
。tf.train.Checkpoint
使用其关键字参数名称,如 tf.train.Checkpoint(step=...)
中的 "step"
。
上面示例中的依赖图如下所示:
优化器为红色,常规变量为蓝色,优化器插槽变量为橙色。其他节点(例如,代表 tf.train.Checkpoint
的节点)为黑色。
插槽变量是优化器状态的一部分,不过是为特定变量而创建。例如,上面的 'm'
边缘对应于动量,Adam 优化器会针对每个变量跟踪该动量。只有在同时保存变量和优化器时,才会将插槽变量保存到检查点中,并因此保存虚线边缘。
在 tf.train.Checkpoint
对象上调用 restore
会排队处理请求的恢复,一旦有来自 Checkpoint
对象的匹配路径,就会恢复变量值。例如,您可以通过重建一个穿过网络和层到达它的路径来仅从上面定义的模型加载偏差。
这些新对象的依赖关系计算图是您上面所编写较大检查点的一个小得多的子计算图。它仅包括偏差和 tf.train.Checkpoint
用于对检查点进行编号的保存计数器。
restore
返回一个具有可选断言的状态对象。在新的 Checkpoint
中创建的所有对象都已恢复,因此 status.assert_existing_objects_matched
通过。
检查点中有许多不匹配的对象,包括层的内核和优化器的变量。status.assert_consumed()
仅在检查点和程序完全匹配时通过,并在此处抛出异常。
延迟恢复
当输入形状可用时,TensorFlow 中的 Layer
对象可能会将变量创建延迟到变量的首次调用。例如,Dense
层内核的形状取决于该层的输入和输出形状,因此,作为构造函数参数所需的输出形状没有足够的信息来单独创建变量。由于调用 Layer
还会读取变量的值,必须在变量的创建与其首次使用之间进行恢复。
为支持这种习惯用法,tf.train.Checkpoint
会推迟尚不具有匹配变量的恢复。
手动检查检查点
tf.train.load_checkpoint
返回一个提供对检查点内容进行较低级别访问权限的 CheckpointReader
。它包含从每个变量的键到检查点中每个变量的形状和 dtype 的映射。如上面显示的计算图中所示,变量的键是它的对象路径。
注:检查点没有更高级别的结构。它只知道变量的路径和值,而没有 models
、layers
或它们如何连接的概念。
因此,如果您对 net.l1.kernel
的值感兴趣,可以使用以下代码获取该值:
此外,它还提供了一个 get_tensor
方法,允许您检查变量的值:
对象跟踪
检查点通过“跟踪”一个特性中的任何变量或可跟踪对象集来保存和回复 tf.Variable
对象的值。执行保存时,将从所有可访问的跟踪对象递归收集变量。
对于像 self.l1 = tf.keras.layers.Dense(5)
一样的直接特性赋值,将列表和字典分配给特性会跟踪其内容。
您可能会注意到列表和字典的包装器对象。这些包装器是可设置检查点版本的基础数据结构。就像基于特性的加载一样,这些包装器会在将变量添加到容器后立即恢复它的值。
可跟踪对象包括 tf.train.Checkpoint
、tf.Module
及其子类(例如 keras.layers.Layer
和 keras.Model
),并识别 Python 容器:
dict
(和collections.OrderedDict
)list
tuple
(和collections.namedtuple
、typing.NamedTuple
)
其他容器类型不受支持,包括:
collections.defaultdict
set
所有其他 Python 对象都会被忽略,包括:
int
string
float
总结
TensorFlow 对象提供了一种简单的自动机制来保存和恢复它们所使用变量的值。