Path: blob/master/site/zh-cn/guide/migrate/migrating_checkpoints.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
迁移模型检查点
注:使用 tf.compat.v1.Saver
保存的检查点通常称为 TF1 或基于名称的检查点。使用 tf.train.Checkpoint
保存的检查点称为 TF2 或基于对象的检查点。
概述
本指南假定您有一个使用 tf.compat.v1.Saver
保存和加载检查点的模型,并且想要使用 TF2 tf.train.Checkpoint
API 迁移代码,或者使用 TF2 模型中既有的检查点。
下面是可能会遇到的一些常见情形:
情形 1
以前的训练运行中存在现有的 TF1 检查点,需要加载或转换为 TF2。
要在 TF2 中加载 TF1 检查点,请参阅在 TF2 中加载 TF1 检查点代码段。
要将检查点转换为 TF2,请参阅检查点转换。
情形 2
您正在以一种存在更改变量名称和路径的风险的方式调整您的模型(例如,从 get_variable
增量迁移到显式 tf.Variable
创建时),并且希望在此过程中保持现有检查点的保存/加载。
请参阅如何在模型迁移期间保持检查点兼容性部分
情形 3
您正在将训练代码和检查点迁移到 TF2,但您的推断流水线目前仍需要 TF1 检查点(为了生产稳定性)。
选项 1
训练时同时保存 TF1 和 TF2 检查点。
选项 2
将 TF2 检查点转换为 TF1。
请参阅检查点转换
下面的示例显示了 TF1/TF2 中保存和加载检查点的所有组合,因此可以灵活地确定如何迁移模型。
安装
从 TF1 到 TF2 的变化
如果您对 TF1 和 TF2 之间发生了哪些变化以及我们所说的“基于名称”(TF1) 与“基于对象”(TF2) 的检查点的含义感到好奇,请阅读此部分。
这两种类型的检查点实际上以相同的格式保存,本质上是一个键值表。不同之处在于键的生成方式。
基于名称的检查点中的键是变量的名称。基于对象的检查点中的键指向从根对象到变量的路径(下面的示例将有助于更好地理解这段话的含义)。
首先,保存一些检查点:
如果您查看 tf2-ckpt
中的键,它们全部指向每个变量的对象路径。例如,变量 a
是 variables
列表中的第一个元素,因此它的键变为 variables/0/...
(请尽管忽略 .ATTRIBUTES/VARIABLE_VALUE 常量)。
仔细检查下面的 Checkpoint
对象:
尝试使用下面的代码段,看看检查点键如何随对象结构变化:
为什么 TF2 使用这种机制?
TF2 中没有更多的全局计算图,因此变量名称是不可靠的,并且程序之间可能存在不一致。TF2 鼓励使用面向对象的建模方法,其中变量归层所有,层归模型所有:
当变量名称匹配时
长标题:如何在变量名称匹配时重用检查点。
短答案:可以使用 tf1.train.Saver
或 tf.train.Checkpoint
直接加载既有的检查点。
如果您使用的是 tf.compat.v1.keras.utils.track_tf1_style_variables
,那么它将确保模型变量名称与以前相同。您还可以手动确保变量名称匹配。
当迁移模型中的变量名称匹配时,您可以直接使用 tf.train.Checkpoint
或 tf.compat.v1.train.Saver
加载检查点。这两个 API 都与 Eager 和计算图模式兼容,因此您可以在迁移的任何阶段使用它们。
注:您可以使用 tf.train.Checkpoint
加载 TF1 检查点,但如果没有复杂的名称匹配,则不能使用 tf.compat.v1.Saver
加载 TF2 检查点。
下面是对不同模型使用相同检查点的示例。首先,使用 tf1.train.Saver
保存一个 TF1 检查点:
下面的示例使用 tf.compat.v1.Saver
在 Eager 模式下加载检查点:
下一个代码段使用 TF2 API tf.train.Checkpoint
加载检查点:
TF2 中的变量名称
变量仍然具有您可以设置的
name
参数。Keras 模型还采用
name
参数,并将其设置为变量的前缀。v1.name_scope
函数可用于设置变量名前缀,这与tf.variable_scope
截然不同。它只影响名称,而不跟踪变量和重用。
tf.compat.v1.keras.utils.track_tf1_style_variables
装饰器是一个填充码,它通过保持 tf.variable_scope
和 tf.compat.v1.get_variable
的命名和重用语义不变,帮助您维护变量名称和 TF1 检查点兼容性。如需了解详情,请参阅模型映射指南。
注 1:如果您使用填充码,请使用 TF2 API 加载您的检查点(即使使用预训练的 TF1 检查点)。
请参阅检查点 Keras 部分。
注 2:从 get_variable
迁移到 tf.Variable
时:
如果您的填充码装饰层或模块包含一些使用 tf.Variable
而不是 tf.compat.v1.get_variable
的变量(或 Keras 层/模型),并作为属性附加/以面向对象的方式进行跟踪,则与 Eager Execution 期间相比,它们在 TF1.x 计算图/会话中可能具有不同的变量命名语义。
简而言之,在 TF2 中运行时,这些名称可能不是预期的名称。
警告:变量在 Eager Execution 中可能有重复的名称,如果基于名称的检查点中的多个变量需要映射到相同的名称,这可能会导致问题。可以使用 tf.name_scope
和层构造函数或 tf.Variable
name
参数显式调整层和变量名称,从而调整变量名称并确保没有重复。
维护分配映射
分配映射通常用于在 TF1 模型之间传递权重,如果变量名称发生变化,也可以在模型迁移期间使用。
您可以将这些映射与 tf.compat.v1.train.init_from_checkpoint
、tf.compat.v1.train.Saver
和 tf.train.load_checkpoint
一起使用,以将权重加载到变量或范围名称可能已更改的模型中。
本部分中的示例将使用之前保存的检查点:
使用 init_from_checkpoint
加载
tf1.train.init_from_checkpoint
必须在计算图/会话中调用,因为它将值置于变量初始值设定项中,而不是创建分配运算。
您可以使用 assignment_map
参数来配置变量的加载方式。从文档中:
分配映射支持以下语法:
'checkpoint_scope_name/': 'scope_name/'
- 将从checkpoint_scope_name
加载当前scope_name
中的所有变量,并具有匹配的张量名称。'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'
- 将从checkpoint_scope_name/some_other_variable
初始化scope_name/variable_name
变量。'scope_variable_name': variable
- 将使用来自检查点的张量 'scope_variable_name' 初始化给定的tf.Variable
对象。'scope_variable_name': list(variable)
- 将使用来自检查点的张量 'scope_variable_name' 初始化分区变量列表。'/': 'scope_name/'
- 将从检查点的根目录(例如,无范围)加载当前scope_name
中的所有变量。
使用 tf1.train.Saver
加载
与 init_from_checkpoint
不同,tf.compat.v1.train.Saver
同时支持在计算图模式和 Eager 模式下运行。var_list
参数可以接受字典,但它必须将变量名称映射到 tf.Variable
对象。
使用 tf.train.load_checkpoint
加载
如果您需要精确控制变量值,则此选项适合您。同样,这适用于计算图和 Eager 模式。
维护 TF2 检查点对象
如果在迁移过程中变量和范围名称可能会发生很大变化,请使用 tf.train.Checkpoint
和 TF2 检查点。TF2 使用对象结构而不是变量名(有关详情,请参阅从 TF1 到 TF2 的变化)。
简而言之,在创建 tf.train.Checkpoint
来保存或恢复检查点时,请确保它使用相同的顺序(对于列表)和键(对于 Checkpoint
初始值设定项的字典和关键字参数)。检查点兼容性的一些示例:
下面的代码示例显示了如何使用“相同”的 tf.train.Checkpoint
来加载具有不同名称的变量。首先,保存一个 TF2 检查点:
即使变量/范围名称发生变化,也可以继续使用 tf.train.Checkpoint
:
在 Eager 模式下:
Estimator 中的 TF2 检查点
上面的部分介绍了如何在迁移模型时保持检查点兼容性。这些概念也适用于 Estimator 模型,尽管保存/加载检查点的方式略有不同。当迁移 Estimator 模型以使用 TF2 API 时,您可能希望在模型仍使用 Estimator 时从 TF1 切换到 TF2 检查点。本部分介绍如何实现此目的。
tf.estimator.Estimator
和 MonitoredSession
有一种称为 scaffold
的保存机制,即 tf.compat.v1.train.Scaffold
对象。Scaffold
可以包含 tf1.train.Saver
或 tf.train.Checkpoint
,它使 Estimator
和 MonitoredSession
能够保存 TF1 或 TF2 样式的检查点。
在从 est-tf1
热启动之后,v
的最终值应当是 16
,随后再训练 5 步。训练步的值不会从 warm_start
检查点结转。
检查点 Keras
使用 Keras 构建的模型仍然使用 tf1.train.Saver
和 tf.train.Checkpoint
来加载既有的权重。当您的模型完全迁移后,请切换为使用 model.save_weights
和 model.load_weights
,尤其是当您在训练时使用 ModelCheckpoint
回调时。
关于检查点和 Keras,您需要了解以下信息:
初始化与构建
Keras 模型和层必须经过两个步骤才能完全创建。首先是 Python 对象的初始化:layer = tf.keras.layers.Dense(x)
。其次是构建步骤,此过程实际会创建大部分权重:layer.build(input_shape)
。此外,您还可以通过调用或运行单个 train
、eval
或 predict
步骤(仅限第一次)来构建模型。
如果您发现 model.load_weights(path).assert_consumed()
引发错误,则很可能是模型/层尚未构建。
Keras 使用 TF2 检查点
tf.train.Checkpoint(model).write
等效于 model.save_weights
。与 tf.train.Checkpoint(model).read
和 model.load_weights
相同。请注意,Checkpoint(model) != Checkpoint(model=model)
。
TF2 检查点可与 Keras 的 build()
步骤一起使用
tf.train.Checkpoint.restore
有一种称为延迟恢复的机制,它允许在尚未创建变量时使用 tf.Module
和 Keras 对象存储变量值。这允许已初始化的模型加载权重并在之后构建。
由于存在这种机制,我们强烈建议您将 TF2 检查点加载 API 与 Keras 模型一起使用(即使在将既有的 TF1 检查点恢复到模型映射填充码中时)。有关详情,请参阅检查点指南。
代码段
下面的代码段显示了检查点保存 API 中的 TF1/TF2 版本兼容性。
在 TF1 中保存 TF2 检查点
在 TF1 中加载 TF2 检查点
将 TF1 检查点转换为 TF2
转换代码段 Save a TF1 checkpoint in TF2
中保存的检查点:
将 TF2 检查点转换为 TF1
转换代码段 Save a TF2 checkpoint in TF1
中保存的检查点:
相关指南
模型映射指南和
tf.compat.v1.keras.utils.track_tf1_style_variables