Path: blob/master/site/zh-cn/guide/intro_to_modules.ipynb
25115 views
Copyright 2020 The TensorFlow Authors.
要进行 TensorFlow 机器学习,您可能需要定义、保存和恢复模型。
抽象地说,模型是:
一个在张量上进行某些计算的函数(前向传递)
一些可以更新以响应训练的变量
在本指南中,您将深入学习 Keras,了解如何定义 TensorFlow 模型。本文着眼于 TensorFlow 如何收集变量和模型,以及如何保存和恢复它们。
注:如果您想立即开始使用 Keras,请参阅 Keras 指南集合。
设置
TensorFlow 模块
大多数模型都由层组成。层是具有已知数学结构的函数,可以重复使用并具有可训练的变量。在 TensorFlow 中,层和模型的大多数高级实现(例如 Keras 或 Sonnet)都在以下同一个基础类上构建:tf.Module
。
构建模块
下面是一个在标量张量上运行的非常简单的 tf.Module
示例:
模块和引申而来的层是“对象”的深度学习术语:它们具有内部状态以及使用该状态的方法。
__call__
并无特殊之处,只是其行为与 Python 可调用对象类似;您可以使用任何函数来调用模型。
您可以出于任何原因开启和关闭变量的可训练性,包括在微调过程中冻结层和变量。
注:tf.Module
是 tf.keras.layers.Layer
和 tf.keras.Model
的基类,因此您在此处看到的一切内容也适用于 Keras。出于历史兼容性原因,Keras 层不会从模块收集变量,因此您的模型应仅使用模块或仅使用 Keras 层。不过,下面给出的用于检查变量的方法在这两种情况下相同。
通过将 tf.Module
子类化,将自动收集分配给该对象属性的任何 tf.Variable
或 tf.Module
实例。这样,您可以保存和加载变量,还可以创建 tf.Module
的集合。
下面是一个由模块组成的两层线性层模型的示例。
首先是一个密集(线性)层:
随后是完整的模型,此模型将创建并应用两个层实例:
tf.Module
实例将以递归方式自动收集分配给它的任何 tf.Variable
或 tf.Module
实例。这样,您可以使用单个模型实例管理 tf.Module
的集合,并保存和加载整个模型。
等待创建变量
您在这里可能已经注意到,必须定义层的输入和输出大小。这样,w
变量才会具有已知的形状并且可被分配。
通过将变量创建推迟到第一次使用特定输入形状调用模块时,您将无需预先指定输入大小。
这种灵活性是 TensorFlow 层通常仅需要指定其输出的形状(例如在 tf.keras.layers.Dense
中),而无需指定输入和输出大小的原因。
检查点由两种文件组成---数据本身以及元数据的索引文件。索引文件跟踪实际保存的内容和检查点的编号,而检查点数据包含变量值及其特性查找路径。
您可以查看检查点内部,以确保整个变量集合已由包含这些变量的 Python 对象保存并排序。
在分布式(多机)训练期间,可以将它们分片,这就是要对它们进行编号(例如 '00000-of-00001')的原因。不过,在本例中,只有一个分片。
重新加载模型时,将重写 Python 对象中的值。
注:由于检查点处于长时间训练工作流的核心位置,因此 tf.checkpoint.CheckpointManager
是一个可使检查点管理变得更简单的辅助类。有关更多详细信息,请参阅指南。
保存函数
TensorFlow 可以在不使用原始 Python 对象的情况下运行模型,如 TensorFlow Serving 和 TensorFlow Lite 所示,甚至当您从 TensorFlow Hub 下载经过训练的模型时也是如此。
TensorFlow 需要了解如何执行 Python 中描述的计算,但不需要原始代码。为此,您可以创建一个计算图,如计算图和函数简介指南中所述。
此计算图中包含实现函数的运算。
您可以通过添加 @tf.function
装饰器在上面的模型中定义计算图,以指示此代码应作为计算图运行。
您构建的模块的工作原理与之前完全相同。传递给函数的每个唯一签名都会创建一个单独的计算图。请参阅计算图和函数简介指南以了解详情。
您可以通过在 TensorBoard 摘要中跟踪计算图来将其可视化。
启动 Tensorboard 以查看生成的跟踪:
创建 SavedModel
共享经过完全训练的模型的推荐方式是使用 SavedModel
。SavedModel
包含函数集合与权重集合。
您可以按以下方式保存刚刚训练的模型:
saved_model.pb
文件是一个描述函数式 tf.Graph
的协议缓冲区。
可以从此表示加载模型和层,而无需实际构建创建该表示的类的实例。在您没有(或不需要)Python 解释器(例如大规模应用或在边缘设备上),或者在原始 Python 代码不可用或不实用的情况下,这样做十分理想。
您可以将模型作为新对象加载:
通过加载已保存模型创建的 new_model
是 TensorFlow 内部的用户对象,无需任何类知识。它不是 SequentialModule
类型的对象。
此新模型适用于已定义的输入签名。您不能向以这种方式恢复的模型添加更多签名。
因此,利用 SavedModel
,您可以使用 tf.Module
保存 TensorFlow 权重和计算图,随后再次加载它们。
Keras 模型和层
请注意,到目前为止,还没有提到 Keras。您可以在 tf.Module
上构建自己的高级 API,而我们已经拥有这些 API。
在本部分中,您将研究 Keras 如何使用 tf.Module
。可在 Keras 指南中找到有关 Keras 模型的完整用户指南。
Keras 层和模型具有许多额外功能,包括:
可选损失
对指标的支持
对可选
training
参数的内置支持,用于区分训练和推断用途保存和恢复 Python 对象而不仅仅是黑盒函数
get_config
和from_config
方法,允许您准确存储配置以在 Python 中克隆模型
这些功能通过子类化允许更复杂的模型,例如自定义 GAN 或变分自编码器 (VAE) 模型。在自定义层和模型的完整指南中阅读相关内容。
Keras 模型还附带额外的功能,使它们易于训练、评估、加载、保存,甚至在多台机器上进行训练。
Keras 层
tf.keras.layers.Layer
是所有 Keras 层的基类,它继承自 tf.Module
。
您只需换出父项,然后将 __call__
更改为 call
即可将模块转换为 Keras 层:
Keras 层有自己的 __call__
,它会进行下一部分中所述的某些簿记,然后调用 call()
。您应当不会看到功能上的任何变化。
build
步骤
如上所述,在您确定输入形状之前,等待创建变量在许多情况下十分方便。
Keras 层具有额外的生命周期步骤,可让您在定义层时获得更高的灵活性。这是在 build()
函数中定义的。
build
仅被调用一次,而且是使用输入的形状调用的。它通常用于创建变量(权重)。
您可以根据输入的大小灵活地重写上面的 MyDense
层:
此时,模型尚未构建,因此没有变量:
调用该函数会分配大小适当的变量。
由于仅调用一次 build
,因此如果输入形状与层的变量不兼容,输入将被拒绝。
Keras 模型
您可以将模型定义为嵌套的 Keras 层。
不过,Keras 还提供了称为 tf.keras.Model
的全功能模型类。它继承自 tf.keras.layers.Layer
,因此 Keras 模型支持以与 Keras 层相同的方式使用和嵌套。Keras 模型还具有额外的功能,这使它们可以轻松训练、评估、加载、保存,甚至在多台机器上进行训练。
您可以使用几乎相同的代码定义上面的 SequentialModule
,再次将 __call__
转换为 call()
并更改父项。
所有相同的功能都可用,包括跟踪变量和子模块。
注:嵌套在 Keras 层或模型中的原始 tf.Module
将不会收集其变量以用于训练或保存。相反,它会在 Keras 层内嵌套 Keras 层。
重写 tf.keras.Model
是一种构建 TensorFlow 模型的极 Python 化方式。如果要从其他框架迁移模型,这可能非常简单。
如果要构造的模型是现有层和输入的简单组合,则可以使用函数式 API 节省时间和空间,此 API 附带有关模型重构和架构的附加功能。
下面是使用函数式 API 构造的相同模型:
这里的主要区别在于,输入形状是作为函数构造过程的一部分预先指定的。在这种情况下,不必完全指定 input_shape
参数;您可以将某些维度保留为 None
。
注:您无需在子类化模型中指定 input_shape
或 InputLayer
;这些参数和层将被忽略。
保存 Keras 模型
Keras 模型拥有自己专门的 zip 归档保存格式,以 .keras
扩展名标记。调用 tf.keras.Model.save
时,在文件名中添加一个 .keras
扩展名。例如:
同样地,它们也可以轻松重新加载:
Keras zip 归档 .keras
文件还可以保存指标、损失和优化器状态。
可以使用此重构模型,并且在相同数据上调用时会产生相同的结果:
设置 Keras 模型检查点
也可以为 Keras 模型设置检查点,这看起来和 tf.Module
一样。
有关保存和序列化 Keras 模型,包括为自定义层提供配置方法来为功能提供支持的更多信息,请参阅保存和序列化指南。