Path: blob/master/site/zh-cn/datasets/tfless_tfds.ipynb
25115 views
Copyright 2023 The TensorFlow Datasets Authors.
适用于 Jax 和 PyTorch 的 TFDS
TFDS 一直都独立于框架。例如,您可以轻松地加载 NumPy 格式的数据集以在 Jax 和 PyTorch 中使用。
TensorFlow 及其数据加载解决方案 (tf.data
) 按照设计是我们 API 中的一等公民。
我们扩展了 TFDS 以支持仅使用 NumPy 而无需 TensorFlow 的数据加载。这对于在 Jax 和 PyTorch 等机器学习框架中使用非常方便。事实上,对于后者的用户来说,TensorFlow:
会保留 GPU/TPU 内存;
会在 CI/CD 中增加构建时间;
在运行时需要花费时间导入。
TensorFlow 不再是读取数据集的依赖项。
机器学习流水线需要一个数据加载器来加载样本,将其解码并呈现给模型。数据加载器使用“源/采样器/加载器”范式:
数据源负责实时访问和解码来自 TFDS 数据集的样本。
索引采样器负责确定记录处理的顺序。在读取任何记录之前,实现全局转换(例如全局重排、分片、重复多个周期)非常重要。
数据加载器通过利用数据源和索引采样器来编排加载。它可以实现性能优化(例如,预提取、多进程或多线程)。
速览
tfds.data_source
是一个用于创建数据源的 API:
用于纯 Python 流水线的快速原型设计;
用于大规模管理数据密集型机器学习流水线。
安装
让我们安装并导入所需依赖项:
数据源
数据源基本上是 Python 序列。因此,它们需要实现以下协议:
警告:该 API 仍在积极开发中。特别是,__getitem__
目前在输入中必须支持 int
和 list[int]
。将来,按照标准,它可能仅支持 int
。
底层文件格式需要支持高效的随机访问。目前,TFDS 依赖于 array_record
。
array_record
是一种衍生自 Riegeli 的新文件格式,实现了 IO 效率的新前沿。特别是,ArrayRecord 支持按记录索引并行读取、写入和随机访问。ArrayRecord 建立在 Riegeli 之上,并支持相同的压缩算法。
fashion_mnist
是一个常见的计算机视觉数据集。要使用 TFDS 检索基于 ArrayRecord 的数据源,只需使用以下命令:
tfds.data_source
是一个方便的包装器。它等同于:
这将输出一个数据源字典:
一旦 download_and_prepare
运行并在您生成记录文件后,我们就不再需要 TensorFlow 了。一切都将在 Python/NumPy 中完成!
让我们通过卸载 TensorFlow 并在另一个子进程中重新加载数据源对此进行检查:
在未来的版本中,我们还将使数据集准备不再依赖 TensorFlow。
数据源的长度为:
访问数据集的第一个元素:
…开销与访问任何其他元素一样低。下面是随机访问的定义:
特征现在使用 NumPy DType(而不是 TensorFlow DType)。您可以使用以下命令检查特征:
您可以在我们的文档中找到有关特征的更多信息。在这里,我们特别可以检索图像的形状和类别数量:
在纯 Python 中使用
您可以通过迭代来使用 Python 中的数据源:
与 PyTorch 结合使用
PyTorch 使用源/采样器/加载器范式。在 Torch 中,“数据源”称为“数据集”。torch.utils.data
包含构建高效输入流水线所需的所有详细信息。
TFDS 数据源可以像常规的映射样式数据集一样使用。
首先,我们安装并导入Torch:
我们已经为训练和测试分别定义了数据源(分别是 ds['train']
和 ds['test']
)。现在,我们可以定义采样器和加载器:
使用 PyTorch,我们在第一个样本上进行训练,并评估简单的逻辑回归:
即将推出:与 JAX 结合使用
我们正在与 Grain 密切合作。Grain 是适用于 Python 的开源、快速和确定性数据加载器。敬请关注!
阅读更多内容
有关详情,请参阅 tfds.data_source
API 文档。