Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/guide/migrate/migrating_checkpoints.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

迁移模型检查点

注:使用 tf.compat.v1.Saver 保存的检查点通常称为 TF1 或基于名称的检查点。使用 tf.train.Checkpoint 保存的检查点称为 TF2 或基于对象的检查点。

概述

本指南假定您有一个使用 tf.compat.v1.Saver 保存和加载检查点的模型,并且想要使用 TF2 tf.train.Checkpoint API 迁移代码,或者使用 TF2 模型中既有的检查点。

下面是可能会遇到的一些常见情形:

情形 1

以前的训练运行中存在现有的 TF1 检查点,需要加载或转换为 TF2。

情形 2

您正在以一种存在更改变量名称和路径的风险的方式调整您的模型(例如,从 get_variable 增量迁移到显式 tf.Variable 创建时),并且希望在此过程中保持现有检查点的保存/加载。

请参阅如何在模型迁移期间保持检查点兼容性部分

情形 3

您正在将训练代码和检查点迁移到 TF2,但您的推断流水线目前仍需要 TF1 检查点(为了生产稳定性)。

选项 1

训练时同时保存 TF1 和 TF2 检查点。

选项 2

将 TF2 检查点转换为 TF1。


下面的示例显示了 TF1/TF2 中保存和加载检查点的所有组合,因此可以灵活地确定如何迁移模型。

安装

import tensorflow as tf import tensorflow.compat.v1 as tf1 def print_checkpoint(save_path): reader = tf.train.load_checkpoint(save_path) shapes = reader.get_variable_to_shape_map() dtypes = reader.get_variable_to_dtype_map() print(f"Checkpoint at '{save_path}':") for key in shapes: print(f" (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, " f"value={reader.get_tensor(key)})")

从 TF1 到 TF2 的变化

如果您对 TF1 和 TF2 之间发生了哪些变化以及我们所说的“基于名称”(TF1) 与“基于对象”(TF2) 的检查点的含义感到好奇,请阅读此部分。

这两种类型的检查点实际上以相同的格式保存,本质上是一个键值表。不同之处在于键的生成方式。

基于名称的检查点中的键是变量的名称。基于对象的检查点中的键指向从根对象到变量的路径(下面的示例将有助于更好地理解这段话的含义)。

首先,保存一些检查点:

with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) with tf1.Session() as sess: saver = tf1.train.Saver() sess.run(a.assign(1)) sess.run(b.assign(2)) sess.run(c.assign(3)) saver.save(sess, 'tf1-ckpt') print_checkpoint('tf1-ckpt')
a = tf.Variable(5.0, name='a') b = tf.Variable(6.0, name='b') with tf.name_scope('scoped'): c = tf.Variable(7.0, name='c') ckpt = tf.train.Checkpoint(variables=[a, b, c]) save_path_v2 = ckpt.save('tf2-ckpt') print_checkpoint(save_path_v2)

如果您查看 tf2-ckpt 中的键,它们全部指向每个变量的对象路径。例如,变量 avariables 列表中的第一个元素,因此它的键变为 variables/0/...(请尽管忽略 .ATTRIBUTES/VARIABLE_VALUE 常量)。

仔细检查下面的 Checkpoint 对象:

a = tf.Variable(0.) b = tf.Variable(0.) c = tf.Variable(0.) root = ckpt = tf.train.Checkpoint(variables=[a, b, c]) print("root type =", type(root).__name__) print("root.variables =", root.variables) print("root.variables[0] =", root.variables[0])

尝试使用下面的代码段,看看检查点键如何随对象结构变化:

module = tf.Module() module.d = tf.Variable(0.) test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, c=c, module=module) test_ckpt_path = test_ckpt.save('root-tf2-ckpt') print_checkpoint(test_ckpt_path)

为什么 TF2 使用这种机制?

TF2 中没有更多的全局计算图,因此变量名称是不可靠的,并且程序之间可能存在不一致。TF2 鼓励使用面向对象的建模方法,其中变量归层所有,层归模型所有:

variable = tf.Variable(...) layer.variable_name = variable model.layer_name = layer

如何在模型迁移期间保持检查点兼容性

迁移过程中的一个重要步骤是确保所有变量都被初始化为正确的值,这反过来又允许您验证运算/函数是否正在执行正确的计算。为此,您必须考虑迁移各个阶段中模型之间的检查点兼容性。本质上,本部分回答了一个问题,即如何在更改模型时继续使用相同的检查点

为了提高灵活性,下面是维护检查点兼容性的三种方法:

  1. 模型的变量名称和之前相同

  2. 模型有不同的变量名称,并维护一个将检查点中的变量名称映射到新名称的分配映射

  3. 模型有不同的变量名称,并维护了一个存储所有变量的 TF2 检查点对象

当变量名称匹配时

长标题:如何在变量名称匹配时重用检查点。

短答案:可以使用 tf1.train.Savertf.train.Checkpoint 直接加载既有的检查点。


如果您使用的是 tf.compat.v1.keras.utils.track_tf1_style_variables,那么它将确保模型变量名称与以前相同。您还可以手动确保变量名称匹配。

当迁移模型中的变量名称匹配时,您可以直接使用 tf.train.Checkpointtf.compat.v1.train.Saver 加载检查点。这两个 API 都与 Eager 和计算图模式兼容,因此您可以在迁移的任何阶段使用它们。

注:您可以使用 tf.train.Checkpoint 加载 TF1 检查点,但如果没有复杂的名称匹配,则不能使用 tf.compat.v1.Saver 加载 TF2 检查点。

下面是对不同模型使用相同检查点的示例。首先,使用 tf1.train.Saver 保存一个 TF1 检查点:

with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) with tf1.Session() as sess: saver = tf1.train.Saver() sess.run(a.assign(1)) sess.run(b.assign(2)) sess.run(c.assign(3)) save_path = saver.save(sess, 'tf1-ckpt') print_checkpoint(save_path)

下面的示例使用 tf.compat.v1.Saver 在 Eager 模式下加载检查点:

a = tf.Variable(0.0, name='a') b = tf.Variable(0.0, name='b') with tf.name_scope('scoped'): c = tf.Variable(0.0, name='c') # With the removal of collections in TF2, you must pass in the list of variables # to the Saver object: saver = tf1.train.Saver(var_list=[a, b, c]) saver.restore(sess=None, save_path=save_path) print(f"loaded values of [a, b, c]: [{a.numpy()}, {b.numpy()}, {c.numpy()}]") # Saving also works in eager (sess must be None). path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager') print_checkpoint(path)

下一个代码段使用 TF2 API tf.train.Checkpoint 加载检查点:

a = tf.Variable(0.0, name='a') b = tf.Variable(0.0, name='b') with tf.name_scope('scoped'): c = tf.Variable(0.0, name='c') # Without the name_scope, name="scoped/c" works too: c_2 = tf.Variable(0.0, name='scoped/c') print("Variable names: ") print(f" a.name = {a.name}") print(f" b.name = {b.name}") print(f" c.name = {c.name}") print(f" c_2.name = {c_2.name}") # Restore the values with tf.train.Checkpoint ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2]) ckpt.restore(save_path) print(f"loaded values of [a, b, c, c_2]: [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")

TF2 中的变量名称

  • 变量仍然具有您可以设置的 name 参数。

  • Keras 模型还采用 name 参数,并将其设置为变量的前缀。

  • v1.name_scope 函数可用于设置变量名前缀,这与 tf.variable_scope 截然不同。它只影响名称,而不跟踪变量和重用。

tf.compat.v1.keras.utils.track_tf1_style_variables 装饰器是一个填充码,它通过保持 tf.variable_scopetf.compat.v1.get_variable 的命名和重用语义不变,帮助您维护变量名称和 TF1 检查点兼容性。如需了解详情,请参阅模型映射指南

注 1:如果您使用填充码,请使用 TF2 API 加载您的检查点(即使使用预训练的 TF1 检查点)。

请参阅检查点 Keras 部分。

注 2:从 get_variable 迁移到 tf.Variable 时:

如果您的填充码装饰层或模块包含一些使用 tf.Variable 而不是 tf.compat.v1.get_variable 的变量(或 Keras 层/模型),并作为属性附加/以面向对象的方式进行跟踪,则与 Eager Execution 期间相比,它们在 TF1.x 计算图/会话中可能具有不同的变量命名语义。

简而言之,在 TF2 中运行时,这些名称可能不是预期的名称

警告:变量在 Eager Execution 中可能有重复的名称,如果基于名称的检查点中的多个变量需要映射到相同的名称,这可能会导致问题。可以使用 tf.name_scope 和层构造函数或 tf.Variable name 参数显式调整层和变量名称,从而调整变量名称并确保没有重复。

维护分配映射

分配映射通常用于在 TF1 模型之间传递权重,如果变量名称发生变化,也可以在模型迁移期间使用。

您可以将这些映射与 tf.compat.v1.train.init_from_checkpointtf.compat.v1.train.Savertf.train.load_checkpoint 一起使用,以将权重加载到变量或范围名称可能已更改的模型中。

本部分中的示例将使用之前保存的检查点:

print_checkpoint('tf1-ckpt')

使用 init_from_checkpoint 加载

tf1.train.init_from_checkpoint 必须在计算图/会话中调用,因为它将值置于变量初始值设定项中,而不是创建分配运算。

您可以使用 assignment_map 参数来配置变量的加载方式。从文档中:

分配映射支持以下语法:

  • 'checkpoint_scope_name/': 'scope_name/' - 将从 checkpoint_scope_name 加载当前 scope_name 中的所有变量,并具有匹配的张量名称。

  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - 将从 checkpoint_scope_name/some_other_variable 初始化 scope_name/variable_name 变量。

  • 'scope_variable_name': variable - 将使用来自检查点的张量 'scope_variable_name' 初始化给定的 tf.Variable 对象。

  • 'scope_variable_name': list(variable) - 将使用来自检查点的张量 'scope_variable_name' 初始化分区变量列表。

  • '/': 'scope_name/' - 将从检查点的根目录(例如,无范围)加载当前 scope_name 中的所有变量。

# Restoring with tf1.train.init_from_checkpoint: # A new model with a different scope for the variables. with tf.Graph().as_default() as g: with tf1.variable_scope('new_scope'): a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) with tf1.Session() as sess: # The assignment map will remap all variables in the checkpoint to the # new scope: tf1.train.init_from_checkpoint( 'tf1-ckpt', assignment_map={'/': 'new_scope/'}) # `init_from_checkpoint` adds the initializers to these variables. # Use `sess.run` to run these initializers. sess.run(tf1.global_variables_initializer()) print("Restored [a, b, c]: ", sess.run([a, b, c]))

使用 tf1.train.Saver 加载

init_from_checkpoint 不同,tf.compat.v1.train.Saver 同时支持在计算图模式和 Eager 模式下运行。var_list 参数可以接受字典,但它必须将变量名称映射到 tf.Variable 对象。

# Restoring with tf1.train.Saver (works in both graph and eager): # A new model with a different scope for the variables. with tf1.variable_scope('new_scope'): a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) # Initialize the saver with a dictionary with the original variable names: saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c}) saver.restore(sess=None, save_path='tf1-ckpt') print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

使用 tf.train.load_checkpoint 加载

如果您需要精确控制变量值,则此选项适合您。同样,这适用于计算图和 Eager 模式。

# Restoring with tf.train.load_checkpoint (works in both graph and eager): # A new model with a different scope for the variables. with tf.Graph().as_default() as g: with tf1.variable_scope('new_scope'): a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) with tf1.Session() as sess: # It may be easier writing a loop if your model has a lot of variables. reader = tf.train.load_checkpoint('tf1-ckpt') sess.run(a.assign(reader.get_tensor('a'))) sess.run(b.assign(reader.get_tensor('b'))) sess.run(c.assign(reader.get_tensor('scoped/c'))) print("Restored [a, b, c]: ", sess.run([a, b, c]))

维护 TF2 检查点对象

如果在迁移过程中变量和范围名称可能会发生很大变化,请使用 tf.train.Checkpoint 和 TF2 检查点。TF2 使用对象结构而不是变量名(有关详情,请参阅从 TF1 到 TF2 的变化)。

简而言之,在创建 tf.train.Checkpoint 来保存或恢复检查点时,请确保它使用相同的顺序(对于列表)和(对于 Checkpoint 初始值设定项的字典和关键字参数)。检查点兼容性的一些示例:

ckpt = tf.train.Checkpoint(foo=[var_a, var_b]) # compatible with ckpt tf.train.Checkpoint(foo=[var_a, var_b]) # not compatible with ckpt tf.train.Checkpoint(foo=[var_b, var_a]) tf.train.Checkpoint(bar=[var_a, var_b])

下面的代码示例显示了如何使用“相同”的 tf.train.Checkpoint 来加载具有不同名称的变量。首先,保存一个 TF2 检查点:

with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(1)) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(2)) with tf1.variable_scope('scoped'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(3)) with tf1.Session() as sess: sess.run(tf1.global_variables_initializer()) print("[a, b, c]: ", sess.run([a, b, c])) # Save a TF2 checkpoint ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c]) tf2_ckpt_path = ckpt.save('tf2-ckpt') print_checkpoint(tf2_ckpt_path)

即使变量/范围名称发生变化,也可以继续使用 tf.train.Checkpoint

with tf.Graph().as_default() as g: a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) with tf1.variable_scope('different_scope'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.zeros_initializer()) with tf1.Session() as sess: sess.run(tf1.global_variables_initializer()) print("Initialized [a, b, c]: ", sess.run([a, b, c])) ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c]) # `assert_consumed` validates that all checkpoint objects are restored from # the checkpoint. `run_restore_ops` is required when running in a TF1 # session. ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops() # Removing `assert_consumed` is fine if you want to skip the validation. # ckpt.restore(tf2_ckpt_path).run_restore_ops() print("Restored [a, b, c]: ", sess.run([a, b, c]))

在 Eager 模式下:

a = tf.Variable(0.) b = tf.Variable(0.) c = tf.Variable(0.) print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()]) # The keys "scoped" and "unscoped" are no longer relevant, but are used to # maintain compatibility with the saved checkpoints. ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c]) ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops() print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

Estimator 中的 TF2 检查点

上面的部分介绍了如何在迁移模型时保持检查点兼容性。这些概念也适用于 Estimator 模型,尽管保存/加载检查点的方式略有不同。当迁移 Estimator 模型以使用 TF2 API 时,您可能希望在模型仍使用 Estimator 时从 TF1 切换到 TF2 检查点。本部分介绍如何实现此目的。

tf.estimator.EstimatorMonitoredSession 有一种称为 scaffold 的保存机制,即 tf.compat.v1.train.Scaffold 对象。Scaffold 可以包含 tf1.train.Savertf.train.Checkpoint,它使 EstimatorMonitoredSession 能够保存 TF1 或 TF2 样式的检查点。

# A model_fn that saves a TF1 checkpoint def model_fn_tf1_ckpt(features, labels, mode): # This model adds 2 to the variable `v` in every train step. train_step = tf1.train.get_or_create_global_step() v = tf1.get_variable('var', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) return tf.estimator.EstimatorSpec( mode, predictions=v, train_op=tf.group(v.assign_add(2), train_step.assign_add(1)), loss=tf.constant(1.), scaffold=None ) !rm -rf est-tf1 est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1') def train_fn(): return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6])) est.train(train_fn, steps=1) latest_checkpoint = tf.train.latest_checkpoint('est-tf1') print_checkpoint(latest_checkpoint)
# A model_fn that saves a TF2 checkpoint def model_fn_tf2_ckpt(features, labels, mode): # This model adds 2 to the variable `v` in every train step. train_step = tf1.train.get_or_create_global_step() v = tf1.get_variable('var', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step) return tf.estimator.EstimatorSpec( mode, predictions=v, train_op=tf.group(v.assign_add(2), train_step.assign_add(1)), loss=tf.constant(1.), scaffold=tf1.train.Scaffold(saver=ckpt) ) !rm -rf est-tf2 est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2', warm_start_from='est-tf1') def train_fn(): return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6])) est.train(train_fn, steps=1) latest_checkpoint = tf.train.latest_checkpoint('est-tf2') print_checkpoint(latest_checkpoint) assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4

在从 est-tf1 热启动之后,v 的最终值应当是 16,随后再训练 5 步。训练步的值不会从 warm_start 检查点结转。

检查点 Keras

使用 Keras 构建的模型仍然使用 tf1.train.Savertf.train.Checkpoint 来加载既有的权重。当您的模型完全迁移后,请切换为使用 model.save_weightsmodel.load_weights,尤其是当您在训练时使用 ModelCheckpoint 回调时。

关于检查点和 Keras,您需要了解以下信息:

初始化与构建

Keras 模型和层必须经过两个步骤才能完全创建。首先是 Python 对象的初始化layer = tf.keras.layers.Dense(x)。其次是构建步骤,此过程实际会创建大部分权重:layer.build(input_shape)。此外,您还可以通过调用或运行单个 trainevalpredict 步骤(仅限第一次)来构建模型。

如果您发现 model.load_weights(path).assert_consumed() 引发错误,则很可能是模型/层尚未构建。

Keras 使用 TF2 检查点

tf.train.Checkpoint(model).write 等效于 model.save_weights。与 tf.train.Checkpoint(model).readmodel.load_weights 相同。请注意,Checkpoint(model) != Checkpoint(model=model)

TF2 检查点可与 Keras 的 build() 步骤一起使用

tf.train.Checkpoint.restore 有一种称为延迟恢复的机制,它允许在尚未创建变量时使用 tf.Module 和 Keras 对象存储变量值。这允许已初始化的模型加载权重并在之后构建

m = YourKerasModel() status = m.load_weights(path) # This call builds the model. The variables are created with the restored # values. m.predict(inputs) status.assert_consumed()

由于存在这种机制,我们强烈建议您将 TF2 检查点加载 API 与 Keras 模型一起使用(即使在将既有的 TF1 检查点恢复到模型映射填充码中时)。有关详情,请参阅检查点指南

代码段

下面的代码段显示了检查点保存 API 中的 TF1/TF2 版本兼容性。

在 TF2 中保存 TF1 检查点

a = tf.Variable(1.0, name='a') b = tf.Variable(2.0, name='b') with tf.name_scope('scoped'): c = tf.Variable(3.0, name='c') saver = tf1.train.Saver(var_list=[a, b, c]) path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager') print_checkpoint(path)

在 TF2 中加载 TF1 检查点

a = tf.Variable(0., name='a') b = tf.Variable(0., name='b') with tf.name_scope('scoped'): c = tf.Variable(0., name='c') print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()]) saver = tf1.train.Saver(var_list=[a, b, c]) saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager') print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

在 TF1 中保存 TF2 检查点

with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(1)) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(2)) with tf1.variable_scope('scoped'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(3)) with tf1.Session() as sess: sess.run(tf1.global_variables_initializer()) ckpt = tf.train.Checkpoint( var_list={v.name.split(':')[0]: v for v in tf1.global_variables()}) tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session') print_checkpoint(tf2_in_tf1_path)

在 TF1 中加载 TF2 检查点

with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.variable_scope('scoped'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.Session() as sess: sess.run(tf1.global_variables_initializer()) print("Initialized [a, b, c]: ", sess.run([a, b, c])) ckpt = tf.train.Checkpoint( var_list={v.name.split(':')[0]: v for v in tf1.global_variables()}) ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops() print("Restored [a, b, c]: ", sess.run([a, b, c]))

检查点转换

可以通过加载和重新保存检查点在 TF1 和 TF2 之间转换检查点。另一种方式是 tf.train.load_checkpoint,如下面的代码所示。

将 TF1 检查点转换为 TF2

def convert_tf1_to_tf2(checkpoint_path, output_prefix): """Converts a TF1 checkpoint to TF2. To load the converted checkpoint, you must build a dictionary that maps variable names to variable objects. ``` ckpt = tf.train.Checkpoint(vars={name: variable}) ckpt.restore(converted_ckpt_path) ``` Args: checkpoint_path: Path to the TF1 checkpoint. output_prefix: Path prefix to the converted checkpoint. Returns: Path to the converted checkpoint. """ vars = {} reader = tf.train.load_checkpoint(checkpoint_path) dtypes = reader.get_variable_to_dtype_map() for key in dtypes.keys(): vars[key] = tf.Variable(reader.get_tensor(key)) return tf.train.Checkpoint(vars=vars).save(output_prefix)

转换代码段 Save a TF1 checkpoint in TF2 中保存的检查点:

# Make sure to run the snippet in `Save a TF1 checkpoint in TF2`. print_checkpoint('tf1-ckpt-saved-in-eager') converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 'converted-tf1-to-tf2') print("\n[Converted]") print_checkpoint(converted_path) # Try loading the converted checkpoint. a = tf.Variable(0.) b = tf.Variable(0.) c = tf.Variable(0.) ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c}) ckpt.restore(converted_path).assert_consumed() print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

将 TF2 检查点转换为 TF1

def convert_tf2_to_tf1(checkpoint_path, output_prefix): """Converts a TF2 checkpoint to TF1. The checkpoint must be saved using a `tf.train.Checkpoint(var_list={name: variable})` To load the converted checkpoint with `tf.compat.v1.Saver`: ``` saver = tf.compat.v1.train.Saver(var_list={name: variable}) # An alternative, if the variable names match the keys: saver = tf.compat.v1.train.Saver(var_list=[variables]) saver.restore(sess, output_path) ``` """ vars = {} reader = tf.train.load_checkpoint(checkpoint_path) dtypes = reader.get_variable_to_dtype_map() for key in dtypes.keys(): # Get the "name" from the if key.startswith('var_list/'): var_name = key.split('/')[1] # TF2 checkpoint keys use '/', so if they appear in the user-defined name, # they are escaped to '.S'. var_name = var_name.replace('.S', '/') vars[var_name] = tf.Variable(reader.get_tensor(key)) return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)

转换代码段 Save a TF2 checkpoint in TF1 中保存的检查点:

# Make sure to run the snippet in `Save a TF2 checkpoint in TF1`. print_checkpoint('tf2-ckpt-saved-in-session-1') converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1', 'converted-tf2-to-tf1') print("\n[Converted]") print_checkpoint(converted_path) # Try loading the converted checkpoint. with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.variable_scope('scoped'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.Session() as sess: saver = tf1.train.Saver([a, b, c]) saver.restore(sess, converted_path) print("\nRestored [a, b, c]: ", sess.run([a, b, c]))

相关指南