Path: blob/master/site/zh-cn/guide/function.ipynb
25115 views
Copyright 2020 The TensorFlow Authors.
使用 tf.function 时提升性能
在 TensorFlow 2 中,Eager Execution 默认处于启用状态。界面非常灵活直观(执行一次性运算要简单快速得多),不过,这可能对性能和可部署性造成一定影响。
您可以使用 tf.function
将程序转换为计算图。这是一个转换工具,用于从 Python 代码创建独立于 Python 的数据流图。它可以帮助您创建高效且可移植的模型,并且如果要使用 SavedModel
,则必须使用此工具。
本指南介绍 tf.function
的底层工作原理,让您形成概念化理解,从而有效地加以利用。
要点和建议包括:
先在 Eager 模式下调试,然后使用
@tf.function
进行装饰。不依赖 Python 副作用,如对象变异或列表追加。
tf.function
最适合处理 TensorFlow 运算;NumPy 和 Python 调用会转换为常量。
安装
定义一个辅助函数来演示可能遇到的错误类型:
基础知识
用法
您定义的 Function
(例如,通过应用 @tf.function
装饰器)就像核心 TensorFlow 运算:您可以在 Eager 模式下执行它,可以计算梯度,等等。
Function
中可以嵌套其他 Function
。
Function
的执行速度比 Eager 代码快,尤其是对于包含很多简单运算的计算图。但是,对于包含一些复杂运算(如卷积)的计算图,速度提升不会太明显。
跟踪
本部分介绍了 Function
的幕后运作方式,包括未来可能会发生变化的实现细节。但是,当您了解跟踪的原因和时间后,就能够更轻松高效地使用 tf.function
!
什么是“跟踪”?
Function
在 TensorFlow 计算图中运行您的程序。但是,tf.Graph
不能代表您在 Eager TensorFlow 程序中编写的全部内容。例如,Python 支持多态,但是 tf.Graph
要求其输入具有指定的数据类型和维度。或者,您可能执行辅助任务,例如读取命令行参数、引发错误或使用更复杂的 Python 对象。这些内容均不能在 tf.Graph
中运行。
Function
通过将代码分为以下两个阶段填补了这一空缺:
第一阶段称为跟踪,在这一阶段中,
Function
会创建新的tf.Graph
。Python 代码可以正常运行,但是所有 TensorFlow 运算(例如添加两个张量)都会被推迟:它们会被tf.Graph
捕获而不运行。在第二阶段中,将运行包含第一阶段中推迟的全部内容的
tf.Graph
。此阶段比跟踪阶段快得多。
根据输入,Function
在调用时并非总会运行第一阶段。请参阅下方的跟踪规则以更好地了解其决定方式。跳过第一阶段并仅执行第二阶段,可以实现 TensorFlow 的高性能。
当 Function
决定跟踪时,在跟踪阶段完成后会立即运行第二阶段,因此调用 Function
会创建并运行 tf.Graph
。稍后,您将了解如何使用 get_concrete_function
来仅运行跟踪阶段。
当您将不同类型的参数传递给 Function
时,两个阶段都将运行:
请注意,如果重复使用同一参数类型调用 Function
,TensorFlow 会跳过跟踪阶段并重用之前跟踪的计算图,因为后面的调用生成的计算图可能相同。
您可以使用 pretty_printed_concrete_signatures()
查看所有可用跟踪记录:
目前,您已经了解 tf.function
通过 TensorFlow 的计算图跟踪逻辑创建缓存的动态调度层。对于术语的含义,更具体的解释如下:
tf.Graph
与语言无关,是 TensorFlow 计算的原始可移植表示。ConcreteFunction
封装tf.Graph
。Function
管理ConcreteFunction
的缓存,并为输入选择正确的缓存。tf.function
封装 Python 函数,并返回一个Function
对象。跟踪会创建
tf.Graph
并将其封装在ConcreteFunction
中,也称为跟踪。
跟踪规则
被调用时,Function
使用每个参数的 tf.types.experimental.TraceType
将调用参数与现有的 ConcreteFunction
匹配。如果找到匹配的 ConcreteFunction
,则将调用分派给它。如果未找到匹配项,则跟踪新的 ConcreteFunction
。
如果找到多个匹配项,则会选择最具体的签名。匹配是通过子类型化完成的,就像 C++ 或 Java 中的普通函数调用一样。例如,TensorShape([1, 2])
是 TensorShape([None, None])
的子类型,因此可以将使用 TensorShape([1, 2])
对 tf.function 进行的调用分派到使用 TensorShape([None, None])
生成的 ConcreteFunction
。但是,如果具有 TensorShape([1, None])
的 ConcreteFunction
也存在,那么它将被优先考虑,因为它更具体。
TraceType
由输入参数确定,具体如下所示:
对于
Tensor
,类型由Tensor
的dtype
和shape
参数化;有秩形状是无秩形状的子类型;固定维度是未知维度的子类型对于
Variable
,类型类似于Tensor
,但还包括变量的唯一资源 ID,这是正确连接控制依赖项所必需的对于 Python 基元值,类型对应于值本身。例如,值为
3
的TraceType
是LiteralTraceType<3>
,而不是int
。对于
list
和tuple
等 Python 有序容器,类型是通过其元素的类型来参数化的;例如,[1, 2]
的类型是ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>
,[2, 1]
的类型是ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>
,两者不同。对于
dict
等 Python 映射,类型也是从相同的键到值类型而不是实际值的映射。例如,{1: 2, 3: 4}
的类型为MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>
。但是,与有序容器不同的是,{1: 2, 3: 4}
和{3: 4, 1: 2}
具有等价的类型。对于实现
__tf_tracing_type__
方法的 Python 对象,类型为该方法返回的任何内容对于任何其他 Python 对象,类型是通用的
TraceType
,匹配过程如下:首先,它检查该对象与先前跟踪中使用的对象是否相同(使用
id()
或is
)。请注意,如果对象已更改,这仍然会匹配,因此如果您使用 Python 对象作为tf.function
参数,最好使用不可变对象。接下来,它检查该对象是否等于先前跟踪中使用的对象(使用 python
==
)。
请注意,此过程仅保留对象的弱引用,因此仅在对象处于范围内/未被删除时有效。)
控制回溯
回溯即 Function
创建多个跟踪记录的过程,可以确保 TensorFlow 为每组输入生成正确的计算图。但是,跟踪非常消耗资源!如果 Function
为每一次调用都回溯新的计算图,您会发现代码的执行速度远不如不使用 tf.function
时快。
要控制跟踪行为,可以采用以下技巧:
将固定的 input_signature
传递给 tf.function
使用未知维度以获得灵活性
由于 TensorFlow 根据其形状匹配张量,因此,对于可变大小输入,使用 None
维度作为通配符可以让 Function
重复使用跟踪记录。对于每个批次,如果有不同长度的序列或不同大小的图像,则会出现可变大小输入(请参阅 Transformer 和 Deep Dream 教程了解示例)。
传递张量而不是 Python 文字
通常,Python 参数用于控制超参数和计算图构造,例如 num_layers=10
、training=True
或 nonlinearity='relu'
。所以,如果 Python 参数改变,则有必要回溯计算图。
但是,Python 参数有可能并未用于控制计算图构造。在这些情况下,Python 值的改变可能触发非必要的回溯。例如,在此训练循环中,AutoGraph 会动态展开。尽管有多个跟踪,但生成的计算图实际上是相同的,所以没有必要进行回溯。
如果需要强制执行回溯,可以创建一个新的 Function
。单独的 Function
对象肯定不会共享跟踪记录。
使用跟踪协议
在可能的情况下,您应当首选将 Python 类型转换为 tf.experimental.ExtensionType
。此外,ExtensionType
的 TraceType
是与其关联的 tf.TypeSpec
。因此,如果需要,您只需重写默认的 tf.TypeSpec
即可控制 ExtensionType
的 Tracing Protocol
。请参阅扩展程序类型指南中的自定义 ExtensionType 的 TypeSpec部分以了解详情。
否则,要直接控制 Function
何时应针对特定 Python 类型进行重新跟踪,您可以自行为其实现 Tracing Protocol
。
获取具体函数
每次跟踪函数时都会创建一个新的具体函数。您可以使用 get_concrete_function
直接获取具体函数。
打印 ConcreteFunction
会显示其输入参数(及类型)和输出类型的摘要。
您也可以直接检索具体函数的签名。
对不兼容的类型使用具体跟踪记录会引发错误
您可能会注意到,在具体函数的输入签名中对 Python 参数进行了特别处理。TensorFlow 2.3 之前的版本会将 Python 参数直接从具体函数的签名中移除。从 TensorFlow 2.3 开始,Python 参数会保留在签名中,但是会受到约束,只能获取在跟踪期间设置的值。
获取计算图
每个具体函数都是 tf.Graph
的可调用封装容器。虽然一般不需要检索实际 tf.Graph
对象,不过,您可以从任何具体函数轻松获得实际对象。
调试
通常,在 Eager 模式下调试代码比在 tf.function
中简单。在使用 tf.function
进行装饰之前,您应该先确保代码可在 Eager 模式下无错误执行。为了帮助调试,您可以调用 tf.config.run_functions_eagerly(True)
来全局停用和重新启用 tf.function
。
追溯仅在 tf.function
中出现的问题时,可参考下面的几点提示:
普通旧 Python
print
调用仅在跟踪期间执行,可用于追溯(重新)跟踪函数的时间。tf.print
调用每次都会执行,可用于追溯执行过程中产生的中间值。利用
tf.debugging.enable_check_numerics
很容易追溯到 NaN 和 Inf 在何处创建。pdb
(Python 调试器)可以帮助您理解跟踪的详细过程。(提醒:使用pdb
调试时,AutoGraph 会自动转换 Python 源代码。)
AutoGraph 转换
AutoGraph 是一个库,在 tf.function
中默认处于启用状态。它可以将 Python Eager 代码的子集转换为与计算图兼容的 TensorFlow 运算。这包括 if
、for
、while
等控制流。
tf.cond
和 tf.while_loop
等 TensorFlow 运算仍然可以运行,但是使用 Python 编写时,控制流通常更易于编写,代码也更易于理解。
如果您有兴趣,可以检查 Autograph 生成的代码。
条件语句
AutoGraph 会将某些 if <condition>
语句转换为等效的 tf.cond
调用。如果 <condition>
是张量,则会执行这种替换,否则会将 if
语句作为 Python 条件语句执行。
Python 条件语句在跟踪时执行,因此会将该条件语句的一个分支添加到计算图。如果不使用 AutoGraph,当存在依赖于数据的控制流时,此跟踪计算图将无法选择替代分支。
tf.cond
跟踪并将条件的两个分支添加到计算图,在执行时动态选择分支。跟踪可能产生意外的副作用;请参阅 AutoGraph 跟踪作用以了解详情。
有关 AutoGraph 转换的 if 语句的其他限制,请参阅参考文档。
循环
AutoGraph 会将某些 for
和 while
语句转换为等效的 TensorFlow 循环运算,例如 tf.while_loop
。如果不转换,则会将 for
或 while
循环作为 Python 循环执行。
以下情形会执行这种替换:
for x in y
:如果y
是一个张量,则转换为tf.while_loop
。在特殊情况下,如果y
是tf.data.Dataset
,则会生成tf.data.Dataset
运算的组合。while <condition>
:如果<condition>
是张量,则转换为tf.while_loop
。
Python 循环在跟踪时执行,因而循环每迭代一次,都会将额外的运算添加到 tf.Graph
。
TensorFlow 循环会跟踪循环体,并在执行时动态选择迭代的运行次数。循环体仅在生成的 tf.Graph
中出现一次。
有关 AutoGraph 转换的 for
和 while
语句的其他限制,请参阅参考文档。
在 Python 数据上循环
一个常见陷阱是在 tf.function
中的 Python/Numpy 数据上循环。此循环在跟踪过程中执行,因而循环每迭代一次,都会将模型的一个副本添加到 tf.Graph
。
如果要在 tf.function
中封装整个训练循环,最安全的方式是将数据封装为 tf.data.Dataset
,以便 AutoGraph 动态展开训练循环。
在数据集中封装 Python/Numpy 数据时,要注意 tf.data.Dataset.from_generator
与 tf.data.Dataset.from_tensors
。前者将数据保留在 Python 中,并通过 tf.py_function
获取,这可能会影响性能;后者将数据的副本捆绑成计算图中的一个大 tf.constant()
节点,这可能会消耗较多内存。
通过 TFRecordDataset
、CsvDataset
等从文件中读取数据是最高效的数据使用方式,因为这样 TensorFlow 就可以自行管理数据的异步加载和预提取,不必利用 Python。要了解详细信息,请参阅 tf.data
:构建 TensorFlow 输入流水线指南。
累加循环值
一种常见模式是不断累加循环的中间值。通常,这可以通过将元素追加到 Python 列表或将条目添加到 Python 字典来实现。但是,由于存在 Python 副作用,在动态展开循环中,这些方式无法达到预期效果。要从动态展开循环累加结果,可以使用 tf.TensorArray
来实现。
限制
TensorFlow Function
有意设计了一些限制,在将 Python 函数转换为 Function
时需加以注意。
执行 Python 副作用
副作用(如打印、附加到列表、改变全局变量)在 Function
内部可能会出现异常行为,有时会执行两次或完全无法执行。它们只会在您第一次使用一组输入调用 Function
时发生。之后,将重新执行跟踪的 tf.Graph
,而不执行 Python 代码。
一般经验法则是避免在逻辑中依赖 Python 副作用,而仅使用它们来调试跟踪记录。否则,TensorFlow API(例如 tf.data
、tf.print
、tf.summary
、tf.Variable.assign
和 tf.TensorArray
)是确保在每次调用时 TensorFlow 运行时都能执行您的代码的最佳方式。
如果希望在每次调用 Function
时都执行 Python 代码,tf.py_function
可以作为退出点。tf.py_function
的缺点是不可移植,性能不高,无法使用 SavedModel 保存并且在分布式(多 GPU、TPU)设置中效果不佳。另外,由于 tf.py_function
必须连接到计算图中,它会将所有输入/输出转换为张量。
更改 Python 全局变量和自由变量
更改 Python 全局变量和自由变量视为 Python 副作用,因此仅在跟踪期间发生。
有时很难注意到意外行为。在下面的示例中,counter
旨在保护变量的增量。然而,由于它是一个 Python 整数而不是 TensorFlow 对象,它的值在第一次跟踪期间被捕获。使用 tf.function
时,assign_add
将被无条件记录在底层计算图中。因此,每次调用 tf.function
时 v
都会增加 1。当使用 Python 副作用(示例中的 counter
)确定要运行的运算(示例中的 assign_add
)时,此问题在尝试使用 tf.function
装饰器将其计算图模式 Tensorflow 代码迁移到 Tensorflow 2 的用户中十分常见。通常,用户只有在看到可疑的数值结果或明显低于预期的性能(例如,如果受保护运算的开销非常大)后才会意识到这一点。
实现预期行为的一种解决方法是使用 tf.init_scope
将运算提升到函数计算图以外。这样可以确保变量增量在跟踪期间只执行一次。应当注意的是,init_scope
还有其他副作用,包括清除控制流和梯度带。有时 init_scope
的使用会变得过于复杂而无法实际管理。
总之,根据经验,您应避免改变整数或容器(如位于 Function
外部的列表)等 Python 对象,而应使用参数和 TF 对象。例如,在循环中累加值部分中提供了一个如何实现类列表运算的示例。
在某些情况下,如果为 tf.Variable
,则您可以捕获和处理状态。这是通过重复调用相同的 ConcreteFunction
来更新 Keras 模型权重的方式。
使用 Python 迭代器和生成器
很多 Python 功能(如生成器和迭代器)依赖 Python 运行时来跟踪状态。通常,虽然这些构造在 Eager 模式下可以正常工作,但它们是 Python 副作用的示例,因此仅在跟踪期间发生。
就像 TensorFlow 具有用于列表构造的专用 tf.TensorArray
一样,它也具有用于迭代构造的专用 tf.data.Iterator
。有关概述,请参阅 AutoGraph 转换部分。此外,tf.data
API 也可帮助实现生成器模式:
tf.function 的所有输出都必须是返回值
除了 tf.Variable
外,一个 tf.function 必须返回其所有输出。尝试直接从函数访问任何张量而不遍历返回值会导致“泄漏”。
例如,下面的函数通过 Python 全局变量 x
“泄漏”张量 a
:
即使同时返回泄漏的值时也是如此:
通常,当您使用 Python 语句或数据结构时,会发生此类泄漏。除了泄漏不可访问的张量之外,此类语句也可能是错误的,因为它们被视为 Python 副作用,而且不能保证在每次函数调用时都执行。
泄漏局部张量的常见方法还包括改变外部 Python 集合或对象:
不支持递归 tf.functions
不支持递归 Function
,它们可能导致无限循环。例如:
即使递归 Function
看似有效,Python 函数也会被多次跟踪,并且可能会对性能产生影响。例如:
已知问题
如果您的 Function
评估不正确,则这些计划于将来得到修复的已知问题可能可以解释该问题。
取决于 Python 全局变量和自由变量
当使用 Python 参数的新值进行调用时,Function
会创建新的 ConcreteFunction
。但是,对于该 Function
的 Python 闭包、全局变量或非局部变量,则不会创建。如果它们的值在调用 Function
之间发生变化,则 Function
仍将使用其在跟踪时所具有的值。这与常规 Python 函数的工作方式不同。
因此,您应采用使用参数的函数式编程风格而非闭合外部名称。
更新全局值的另一种方法是使其成为 tf.Variable
并改用 Variable.assign
方法。
依赖于 Python 对象
使用相同的 Function
评估模型的修改实例并不合理,因为它仍然具有与原始模型相同的基于实例的 TraceType。
因此,建议您编写 Function
以避免依赖于可变对象特性,或者为对象实现跟踪协议以将此类特性通知给 Function
。
如果这不可行,则一种解决方法是,每次修改对象时都创建新的 Function
以强制回溯:
回溯可能十分耗费资源,您可以使用 tf.Variable
作为对象特性,可以对其进行改变(但非更改,请注意!) 以在无需回溯的情况下实现相似效果。
创建 tf.Variables
Function
仅支持在第一次调用时创建一次,并且在后续函数调用中重复使用的单例 tf.Variable
。下面的代码段会在每个函数调用中创建一个新的 tf.Variable
,这会导致 ValueError
异常。
示例:
用于解决这种限制的常见模式是从 Python None 值开始,随后,在值为 None 时,有条件地创建 tf.Variable
:
与多个 Keras 优化器一起使用
将多个 Keras 优化器与 tf.function
一起使用时,您可能会遇到 ValueError: tf.function only supports singleton tf.Variables created on the first call.
。发生此错误的原因是优化器在首次应用梯度时会在内部创建 tf.Variables
。
如果您需要在训练期间更改优化器,一种解决方法是为每个优化器创建一个新的 Function
,直接调用 ConcreteFunction
。
与多个 Keras 模型一起使用
将不同的模型实例传递给同一 Function
时,您也可能会遇到 ValueError: tf.function only supports singleton tf.Variables created on the first call.
。
发生此错误的原因是 Keras 模型(未定义其输入形状)和 Keras 层会在首次调用时创建 tf.Variables
。您可能正在尝试在已调用的 Function
中初始化这些变量。为避免此错误,请在训练模型之前尝试调用 model.build(input_shape)
以初始化所有权重。
延伸阅读
要了解如何导出和加载 Function
,请参阅 SavedModel 指南。要详细了解跟踪后执行的计算图优化,请参阅 Grappler 指南。要了解如何优化数据流水线和剖析模型性能,请参阅 Profiler 指南。