Copyright 2018 The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
tf.data:构建 TensorFlow 输入流水线
借助 tf.data
API,您可以根据简单的可重用片段构建复杂的输入流水线。例如,图像模型的流水线可以聚合来自分布式文件系统中文件的数据,对每个图像应用随机扰动,并将随机选中的图像合并成一个批次进行训练。文本模型的流水线可能涉及从原始文本数据中提取符号,将提取的符号转换为带有查找表的嵌入向量标识符,并将不同长度的序列组合成一个批次。tf.data
API 可以处理大量数据、从不同的数据格式读取数据和执行复杂的转换。
tf.data
API 引入了一个 tf.data.Dataset
抽象,它表示一个元素序列,其中每个元素都由一个或多个组件组成。例如,在一个图像流水线中,一个元素可能是一个训练样本,有一对表示图像及其标签的张量组件。
您可以通过两种不同的方式创建数据集:
数据源从存储在内存中或存储在一个或多个文件中的数据构造
Dataset
。数据转换从一个或多个
tf.data.Dataset
对象构造数据集。
基本机制
要创建输入流水线,您必须从数据源开始。例如,要从内存中的数据构造一个 Dataset
,您可以使用 tf.data.Dataset.from_tensors()
或 tf.data.Dataset.from_tensor_slices()
。或者,如果您的输入数据以推荐的 TFRecord 格式存储在文件中,则您可以使用 tf.data.TFRecordDataset()
。
有了一个 Dataset
对象之后,您可以通过链接 tf.data.Dataset
对象上的方法调用将其转换成一个新的 Dataset
。例如,您可以应用逐元素转换(例如 Dataset.map
)和多元素转换(例如 Dataset.batch
)。有关完整的转换列表,请参阅 tf.data.Dataset
文档。
Dataset
对象是一个 Python 可迭代对象。这使得利用 for 循环使用它的元素成为可能:
或者使用 iter
显式创建一个 Python 迭代器,并利用 next
来使用它的元素:
或者,还可以利用 reduce
转换来使用数据集元素,从而减少所有元素以生成单个结果。以下示例演示了如何使用 reduce
转换来计算整数数据集的和。
数据集结构
数据集会生成一系列元素,其中每个元素都是组件的相同(嵌套)结构。结构的各个组件可以是可由 tf.TypeSpec
表示的任何类型,包括 tf.Tensor
、tf.sparse.SparseTensor
、tf.RaggedTensor
、tf.TensorArray
或 tf.data.Dataset
。
可用于表达元素的(嵌套)结构的 Python 构造包括 tuple
、dict
、NamedTuple
和 OrderedDict
。特别要指出的是,list
不是用于表达数据集元素结构的有效构造。这是因为早期的 tf.data
用户坚决要求将 list
输入(例如,传递到 tf.data.Dataset.from_tensors
时)自动打包为张量,并将 list
输出(例如,用户定义函数的返回值)强制转换为 tuple
。因此,如果您希望将 list
输入视为结构,则需要将其转换为 tuple
,而如果要将 list
输出作为单个组件,则需要使用 tf.stack
将其显式打包。
Dataset.element_spec
属性允许您检查每个元素组件的类型。该属性会返回 tf.TypeSpec
对象与元素结构相匹配的嵌套结构,可以是单个组件、组件元组,或者组件的嵌套元组。例如:
Dataset
转换支持任何结构的数据集。使用 Dataset.map
和 Dataset.filter
转换时会将函数应用于每个元素,而元素结构会决定函数的参数:
读取输入数据
使用 NumPy 数组
请参阅加载 NumPy 数组教程了解更多示例。
如果您所有的输入数据都适合装入内存,那么从这些数据创建 Dataset
的最简单方式是将它们转换为 tf.Tensor
对象并使用 Dataset.from_tensor_slices
。
注:上面的代码段会将 features
和 labels
数组作为 tf.constant()
运算嵌入到 TensorFlow 计算图中。这对于小数据集来说效果很好,但是会浪费内存(因为数组的内容会被多次复制),并且可能会达到 tf.GraphDef
协议缓冲区的 2GB 上限。
使用 Python 生成器
另一个可被轻松整合为 tf.data.Dataset
的常用数据源是 Python 生成器。
小心:虽然这种方式比较简便,但它的可移植性和可扩缩性有限。它必须在创建生成器的同一 Python 进程中运行,且仍受 Python GIL 约束。
Dataset.from_generator
构造函数会将 Python 生成器转换为具有完整功能的 tf.data.Dataset
。
构造函数会获取可调用对象作为输入,而非迭代器。这样,构造函数结束后便可重启生成器。构造函数会获取一个可选的 args
参数,作为可调用对象的参数。
output_types
参数是必需的,因为 tf.data
会在内部构建 tf.Graph
,而计算图边缘需要 tf.dtype
。
output_shapes
参数虽然不是必需的,但强烈建议添加,因为许多 TensorFlow 运算不支持秩未知的张量。如果特定轴的长度未知或可变,请在 output_shapes
中将其设置为 None
。
还需要注意的是,output_shapes
和 output_types
与其他数据集方法遵循相同的嵌套规则。
下面的示例生成器对两方面进行了演示,它会返回由数组组成的元组,其中第二个数组是长度未知的向量。
第一个输出是 int32
,第二个输出是 float32
。
第一个条目是标量,形状为 ()
,第二个条目是长度未知的向量,形状为 (None,)
现在,可以将它当作常规 tf.data.Dataset
使用。请注意,在批处理形状可变的数据集时,需要使用 Dataset.padded_batch
。
举个更加实际的例子,请尝试将 preprocessing.image.ImageDataGenerator
封装为 tf.data.Dataset
。
先下载数据:
创建 image.ImageDataGenerator
处理 TFRecord 数据
请参阅加载 TFRecord 教程了解端到端示例。
tf.data
API 支持多种文件格式,因此可以处理不适合存储在内存中的大型数据集。例如,TFRecord 文件格式是一种简单的、面向记录的二进制格式,许多 TensorFlow 应用都将其用于训练数据。您可以利用 tf.data.TFRecordDataset
类将一个或多个 TFRecord 文件的内容作为输入流水线的一部分进行流式传输。
下面的示例使用了来自 French Street Name Signs (FSNS) 的测试文件。
TFRecordDataset
初始值设定项的 filenames
参数可以是字符串、字符串列表,或由字符串组成的 tf.Tensor
。因此,如果您有两组分别用于训练和验证的文件,则可以创建一个工厂方法来生成数据集,并将文件名作为输入参数:
许多 TensorFlow 项目在其 TFRecord 文件中使用序列化的 tf.train.Example
记录。这些记录需要在检查前进行解码:
使用文本数据
请参阅加载文本教程了解端到端示例。
许多数据集会作为一个或多个文本文件进行分发。tf.data.TextLineDataset
提供了一种从一个或多个文本文件中提取行的简便方式。如果给定一个或多个文件名,TextLineDataset
会为这些文件的每一行生成一个字符串元素。
这是第一个文件的前几行:
要交错不同文件中的行,请使用 Dataset.interleave
。这样可以更轻松地重排文件。以下是来自每个转换的第一、第二和第三行:
默认情况下,TextLineDataset
会生成每个文件的每一行,但这可能并不理想(例如,有时文件会以标题行开始,或者包含注释)。可以使用 Dataset.skip()
或 Dataset.filter
转换移除这些行。如下所示,跳过第一行,然后过滤出剩余内容。
使用 CSV 数据
如果您的数据适合存储在内存中,那么 Dataset.from_tensor_slices
方法对字典同样有效,使这些数据可以被轻松导入:
更具可扩展性的方式是根据需要从磁盘加载。
tf.data
模块提供了从一个或多个符合 RFC 4180 的 CSV 文件提取记录的方法。
experimental.make_csv_dataset
函数是用来读取 CSV 文件集的高级接口。它支持列类型推断和许多其他功能,如批处理和重排,以简化使用。
如果只需要列的一个子集,您可以使用 select_columns
参数。
还有一个级别更低的 experimental.CsvDataset
类,该类可以提供粒度更细的控制,但它不支持列类型推断,您必须指定每个列的类型。
如果某些列为空,则此低级接口允许您提供默认值,而非列类型。
默认情况下,CsvDataset
会生成文件每一行的每一列,但这可能并不理想(例如,有时文件会以一个应该忽略的标题行开始,或者输入中不需要某些列)。可以分别使用 header
和 select_cols
参数移除这些行和字段。
使用文件集
许多数据集会作为文件集进行分发,其中,每个文件都是一个样本。
注:这些图像已获得 CC-BY 许可,请参阅 LICENSE.txt 以了解详情。
根目录包含每个类的路径:
每个类目录中的文件是样本:
使用 tf.io.read_file
函数读取数据,并从路径提取标签,返回 (image, label)
对:
批处理数据集元素
简单批处理
最简单的批处理方式是将数据集的 n
个连续元素堆叠成单个元素。Dataset.batch()
转换就负责执行此操作,它有和 tf.stack()
算子相同的约束,应用于元素的每个组件:也就是说,对于每个组件 i,所有元素都必须有一个形状完全相同的张量。
当 tf.data
试图传播形状信息时,Dataset.batch
的默认设置会导致未知的批次大小,因为最后一个批次可能不完整。请注意形状中的 None
:
使用 drop_remainder
参数忽略最后一个批次,以获得完整的形状传播:
批处理带填充的张量
上述方式适用于所有具有相同大小的张量。然而,许多模型(包括序列模型)处理的输入数据可能具有不同的大小(例如,长度不同的序列)。为了处理这种情况,可以通过 Dataset.padded_batch
转换指定一个或多个可能被填充的维度,从而批处理不同形状的张量。
您可以通过 Dataset.padded_batch
转换为每个组件的每个维度设置不同的填充,它可以是可变长度(在上面的示例中由 None
表示)或是固定长度。它还可以重写填充值,填充值默认为 0。
训练工作流
处理多个周期
tf.data
API 提供了两种主要方式来处理同一数据的多个周期。
要在多个周期内迭代数据集,最简单的方式是使用 Dataset.repeat()
转换。首先,创建一个由 Titanic 数据组成的数据集:
如果应用不带参数的 Dataset.repeat()
转换,将无限次地重复输入。
Dataset.repeat
转换会连接其参数,而不会在一个周期结束和下一个周期开始时发出信号。因此,在 Dataset.repeat
之后应用的 Dataset.batch
将生成跨越周期边界的批次:
如果您需要明确的周期分隔,请将 Dataset.batch
置于重复前:
如果您想在每个周期结束时执行自定义计算(例如,收集统计信息),最简单的方式是在每个周期上重新启动数据集迭代:
随机重排输入数据
Dataset.shuffle()
转换会维持一个固定大小的缓冲区,并从该缓冲区均匀地随机选择下一个元素:
注:虽然较大的 buffer_size 可以更彻底地重排,但可能会占用大量的内存和时间来填充。如果这成为问题,请考虑跨文件使用 Dataset.interleave
。
向数据集添加索引,便能看到效果:
由于 buffer_size
是 100,而批次大小是 20,第一个批次不包含索引大于 120 的元素。
对于 Dataset.batch
,与 Dataset.repeat
的相对顺序很重要。
在重排缓冲区为空之前,Dataset.shuffle
不会发出周期结束的信号。因此,置于重复之前的重排会先显示一个周期内的每个元素,然后移至下一个周期:
但在重排之前的重复会将周期边界混合在一起:
预处理数据
Dataset.map(f)
转换会通过对输入数据集的每个元素应用一个给定函数 f
来生成一个新的数据集。它基于 map()
函数,该函数通常应用于函数式编程语言中的列表(和其他结构)。函数 f
会获取在输入中表示单个元素的 tf.Tensor
对象,并返回在新数据集中表示单个元素的 tf.Tensor
对象。它的实现使用标准的 TensorFlow 运算来将一个元素转换为另一个元素。
本部分介绍了关于 Dataset.map()
使用方法的常见示例。
解码图像数据并调整大小
使用真实的图像数据训练神经网络时,常常需要将不同大小的图像转换为统一大小,以便将其批处理成某一固定大小。
重建花卉文件名数据集:
编写一个操作数据集元素的函数。
测试它的有效性。
将它映射到数据集。
应用任意 Python 逻辑
出于性能考虑,请尽可能使用 TensorFlow 运算预处理数据。不过,在解析输入数据时,调用外部 Python 库有时会很有帮助。您可以在 Dataset.map
转换中使用 tf.py_function
运算。
例如,如果您想应用一个随机旋转,而 tf.image
模块只有 tf.image.rot90
,这对图像增强不是很有帮助。
注:tensorflow_addons
在 tensorflow_addons.image.rotate
中有一个与 TensorFlow 兼容的 rotate
。
为了演示 tf.py_function
,请尝试使用 scipy.ndimage.rotate
函数:
要将此函数用于 Dataset.map
,需留意与 Dataset.from_generator
相同的注意事项,在应用该函数时需要描述返回的形状和类型:
解析 tf.Example
协议缓冲区消息
许多输入流水线从 TFRecord 格式文件中提取 tf.train.Example
协议缓冲区消息。每个 tf.train.Example
记录包含一个或多个“特征”,而输入流水线通常会将这些特征转换为张量。
您可以在 tf.data.Dataset
外部使用 tf.train.Example
协议来理解数据:
有关端到端的时间序列示例,请参阅:时间序列预测。
时间序列数据通常以保持完整时间轴的方式进行组织。
用一个简单的 Dataset.range
来演示:
通常,基于此类数据的模型需要一个连续的时间片。
最简单的方式是批处理这些数据:
使用 batch
或者,要对未来进行一步密集预测,您可以将特征和标签相对彼此移动一步:
要预测整个窗口而非一个固定偏移量,您可以将批次分成两部分:
要允许一个批次的特征和另一批次的标签部分重叠,请使用 Dataset.zip
:
使用 window
尽管使用 Dataset.batch
奏效,某些情况可能需要更精细的控制。Dataset.window
方法可以为您提供完全控制,但需要注意:它返回的是由 Datasets
组成的 Dataset
。请参阅数据集结构部分以了解详情。
Dataset.flat_map
方法可以获取由数据集组成的数据集,并将其合并为一个数据集:
在几乎所有情况下,您都需要先 Dataset.batch
数据集:
现在,您可以看到 shift
参数控制着每个窗口的移动量。
将上述内容整合起来,您就可以编写出以下函数:
然后,可以像之前一样轻松提取标签:
现在,检查类的分布,它是高度倾斜的:
使用不平衡的数据集进行训练的一种常用方式是使其平衡。tf.data
包括一些能够让此工作流变得可行的方法:
数据集采样
对数据集重新采样的一种方式是使用 sample_from_datasets
。当每个类都有单独的 tf.data.Dataset
时,这种方式更加适用。
下面用过滤器从信用卡欺诈数据中生成一个重采样数据集:
要使用 tf.data.Dataset.sample_from_datasets
传递数据集以及每个数据集的权重,请运行以下代码:
现在,数据集为每个类生成样本的概率是 50/50:
拒绝重采样
上述 Dataset.sample_from_datasets
方式的一个问题是每个类需要一个单独的 tf.data.Dataset
。您可以使用 Dataset.filter
创建这两个数据集,但这会导致所有数据被加载两次。
可以将 tf.data.Dataset.rejection_resample
方法应用于数据集以使其重新平衡,而且只需加载一次数据。元素将被删除或重复,以实现平衡。
rejection_resample
需要一个 class_func
参数。这个 class_func
参数会被应用至每个数据集元素,并且会被用来确定某个样本属于哪一类,以实现平衡的目的。
这里的目标是平衡标签分布,而 creditcard_ds
的元素已经形成 (features, label)
对。因此,class_func
只需返回这些标签:
重采样方法会处理单个样本,因此您必须在应用该方法前 unbatch
数据集。
该方法需要一个目标分布,以及一个可选的初始分布估计作为输入:
rejection_resample
方法会返回 (class, example)
对,其中的 class
为 class_func
的输出。在此例中,example
已经形成 (feature, label)
对,因此请使用 map
删除多余的标签副本:
现在,数据集为每个类生成样本的概率是 50/50:
迭代器检查点操作
TensorFlow 支持获取检查点 ,这样当训练过程重启时,可以还原至最新的检查点,从而恢复大部分进度。除了可以对模型变量进行检查点操作外,您还可以为数据集迭代器的进度设置检查点。如果您有一个很大的数据集,并且不想在每次重启后都从头开始,此功能会非常有用。但是请注意,迭代器检查点可能会很大,因为像 Dataset.shuffle
和 Dataset.prefetch
之类的转换需要在迭代器内缓冲元素。
要在检查点中包含迭代器,请将迭代器传递至 tf.train.Checkpoint
构造函数。
注:无法为依赖于外部状态的迭代器(例如 tf.py_function
)设置检查点。尝试这样做将引发外部状态异常。
结合使用 tf.data 与 tf.keras
tf.keras
API 在创建和执行机器学习模型的许多方面进行了简化。它的 Model.fit
、Model.evaluate
和 Model.predict
API 支持将数据集作为输入。下面是一个快速的数据集和模型设置:
只需向 Model.fit
和 Model.evaluate
传递一个由 (feature, label)
对组成的数据集:
如果您要传递一个无限大的数据集(比如通过调用 Dataset.repeat
),您只需要同时传递 steps_per_epoch
参数:
对于评估,可以传递评估步数:
对于长数据集,可以设置要评估的步数:
调用 Model.predict
时不需要标签。
但是,如果您传递了包含标签的数据集,则标签会被忽略: