Path: blob/master/site/zh-cn/guide/effective_tf2.ipynb
25115 views
Copyright 2020 The TensorFlow Authors.
高效的 TensorFlow 2
概述
本指南提供了使用 TensorFlow 2 (TF2) 编写代码的最佳做法列表,此列表专为最近从 TensorFlow 1 (TF1) 切换过来的用户编写。有关将 TF1 代码迁移到 TF2 的更多信息,请参阅指南的迁移部分。
设置
为本指南中的示例导入 TensorFlow 和其他依赖项。
惯用 TensorFlow 2 的建议
将代码重构为更小的模块
一种良好做法是将代码重构为根据需要调用的更小函数。为了获得最佳性能,您应当尝试在 tf.function
中装饰最大的计算块(请注意,由 tf.function
调用的嵌套 Python 函数不需要自己单独的装饰,除非您想为 tf.function
使用不同的 jit_compile
设置)。根据您的用例,这可能是多个训练步骤,甚至是整个训练循环。对于推断用例,它可能是单个模型前向传递。
使用 tf.Module
和 Keras 层管理变量
tf.Module
和 tf.keras.layers.Layer
提供了方便的 variables
和 trainable_variables
属性,它们以递归方式收集所有因变量。这样便可轻松在使用变量的地方对它们进行本地管理。
Keras 层/模型继承自 tf.train.Checkpointable
并与 @tf.function
集成,这样便有可能从 Keras 对象直接导出 SavedModel 或为其添加检查点。您不必使用 Keras的 Model.fit
API 来利用这些集成。
阅读 Keras 指南中有关迁移学习和微调的部分,了解如何使用 Keras 收集相关变量的子集。
结合 tf.data.Dataset
和 tf.function
TensorFlow Datasets 软件包 (tfds) 包含用于将预定义数据集作为 tf.data.Dataset
对象加载的的实用工具。对于此示例,您可以使用 tfds
加载 MNIST 数据集:
然后,准备用于训练的数据:
重新缩放每个图像;
重排样本顺序。
收集图像和标签批次。
为了使样本简短,将数据集修剪为仅返回 5 个批次:
使用常规 Python 迭代来迭代适合装入内存的训练数据。除此之外,tf.data.Dataset
是从磁盘流式传输训练数据的最佳方式。数据集是可迭代对象(但不是迭代器),就像其他 Eager Execution 中的 Python 可迭代对象一样。您可以通过将代码封装在 tf.function
中来充分利用数据集异步预提取/流式传输功能,此代码将 Python 迭代替换为使用 AutoGraph 的等效计算图运算。
如果您使用 Keras Model.fit
API,则不必担心数据集迭代。
使用 Keras 训练循环
如果您不需要对训练过程进行低级控制,建议使用 Keras 的内置 fit
、evaluate
和 predict
方法。无论实现方式(顺序、函数或子类化)如何,这些方法都能提供统一的接口来训练模型。
这些方法的优点包括:
接受 Numpy 数组、Python 生成器和
tf.data.Datasets
。自动应用正则化和激活损失。
支持
tf.distribute
,无论硬件配置如何,训练代码都保持不变。支持将任意可调用对象作为损失和指标。
支持
tf.keras.callbacks.TensorBoard
之类的回调以及自定义回调。性能出色,可以自动使用 TensorFlow 计算图。
下面是使用 Dataset
训练模型的示例。要详细了解工作原理,请参阅教程。
自定义训练并编写自己的循环
如果 Keras 模型适合您,但您需要更大的灵活性和对训练步骤或外层训练循环的控制,您可以实现自己的训练步骤甚至整个训练循环。如需了解详情,请参阅有关自定义 fit
的 Keras 指南。
此外 ,您还可以将许多内容作为 tf.keras.callbacks.Callback
实现。
这种方法具有前面提到的许多优点,但可以让您控制训练步骤甚至外层循环。
标准训练循环分为三个步骤:
迭代 Python 生成器或
tf.data.Dataset
来获得样本批次。使用
tf.GradientTape
收集梯度。使用
tf.keras.optimizers
之一将权重更新应用于模型的变量。
请记住:
始终在子类化层和模型的
call
方法上包含一个training
参数。确保在
training
参数正确设置的情况下调用模型。根据用法,在对一批数据运行模型之前,模型变量可能不存在。
您需要手动处理模型的正则化损失这类问题。
无需运行变量初始值设定项或添加手动控制依赖项。tf.function
会在创建时为您处理自动控制依赖项和变量初始化。
通过 Python 控制流充分利用 tf.function
tf.function
提供了一种将依赖于数据的控制流转换为计算图模式等效项(如 tf.cond
和 tf.while_loop
)的方法。
数据依赖控制流出现的一个常见位置是序列模型。tf.keras.layers.RNN
封装一个 RNN 单元,允许您以静态或动态方式展开递归。例如,您可以按照下文所述重新实现动态展开。
阅读 tf.function
指南以了解更多信息。
新型指标和损失
指标和损失均为对象,两者都在 Eager 模式下工作,且都位于 tf.function
中。
损失对象是可调用对象,并使用 (y_true
, y_pred
) 作为参数:
使用指标收集和显示数据
您可以使用 tf.metrics
聚合数据,使用 tf.summary
记录摘要并使用上下文管理器将其重定向到编写器。摘要会直接发送到编写器,这意味着您必须在调用点提供 step
值。
要在将数据记录为摘要之前对其进行聚合,请使用 tf.metrics
。指标是有状态的;它们积累值并在您调用 result
方法(例如 Mean.result
)时返回累积结果。可以使用 Model.reset_states
清除累积值。
通过将 TensorBoard 指向摘要日志目录来呈现生成的摘要:
使用 tf.summary
API 编写要在 TensorBoard 中呈现的摘要数据。有关更多信息,请阅读 tf.summary
指南。
调试
使用 Eager Execution 可以分步运行代码来检查形状、数据类型和值。某些 API(如 tf.function
、tf.keras
等)设计为使用计算图执行来提高性能和可移植性。调试时,使用 tf.config.run_functions_eagerly(True)
可以在此代码内使用 Eager Execution。
例如:
这也可以在 Keras 模型和其他支持 Eager Execution 的 API 中使用:
不要在您的对象中保留 tf.Tensors
这些张量对象可能会在 tf.function
或 Eager 上下文中创建,并且这些张量的行为有所不同。始终仅将 tf.Tensor
用于中间值。
要跟踪状态,请使用 tf.Variable
,因为它们始终可用于两种上下文。阅读 tf.Variable
指南以了解更多信息。