Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/datasets/common_gotchas.md
25115 views

常见实现陷阱

本页介绍了实现新数据集时的常见实现问题。

应避免使用旧版 SplitGenerator

旧的 tfds.core.SplitGenerator API 已弃用。

def _split_generator(...): return [ tfds.core.SplitGenerator(name='train', gen_kwargs={'path': train_path}), tfds.core.SplitGenerator(name='test', gen_kwargs={'path': test_path}), ]

应替换为:

def _split_generator(...): return { 'train': self._generate_examples(path=train_path), 'test': self._generate_examples(path=test_path), }

理由:新的 API 不那么冗长而且更加明确。旧 API 将在未来版本中移除。

新数据集应位于独立的文件夹中

tensorflow_datasets/ 仓库中添加数据集时,请确保遵循数据集即文件夹结构(所有校验和、虚拟数据、实现代码均位于独立的文件夹中)。

  • 旧数据集(不良):<category>/<ds_name>.py

  • 新数据集(良好):<category>/<ds_name>/<ds_name>.py

使用 TFDS CLItfds new,或对于 Google 员工为 gtfds new)生成模板。

理由:旧结构需要校验和、假数据的绝对路径,并且在许多地方分发数据集文件。这样便导致在 TFDS 仓库之外实现数据集变得更加困难。为了保持一致性,现在应当在任何地方使用新结构。

描述列表应格式化为 markdown

DatasetInfo.description str 被格式化为 markdown。Markdown 列表在第一项之前需要一个空行:

_DESCRIPTION = """ Some text. # << Empty line here !!! 1. Item 1 2. Item 1 3. Item 1 # << Empty line here !!! Some other text. """

理由:格式错误的描述会在我们的目录文档中创建可视化工件。如果没有空行,上面的文本将呈现为:

Some text. 1. Item 1 2. Item 1 3. Item 1 Some other text

忘记了 ClassLabel 名称

使用 tfds.features.ClassLabel 时,尝试使用 names=names_file=(而不是 num_classes=10)提供人类可读的标签 str

features = { 'label': tfds.features.ClassLabel(names=['dog', 'cat', ...]), }

理由:许多地方都使用了人类可读的标签:

  • 允许直接在 _generate_examples 中产生 stryield {'label': 'dog'}

  • info.features['label'].names 等用户中公开(转换方法 .str2int('dog')... 也可用)

  • 可视化实用工具 tfds.show_examplestfds.as_dataframe 中使用

忘记了图片形状

使用 tfds.features.Imagetfds.features.Video 时,如果图片具有静态形状,则应明确指定:

features = { 'image': tfds.features.Image(shape=(256, 256, 3)), }

理由:它允许静态形状推断(例如 ds.element_spec['image'].shape),这是批处理所必需的(要批处理未知形状的图片,需要先调整它们的大小)。

首选更具体的类型而不是 tfds.features.Tensor

如果可能,首选更具体的类型 tfds.features.ClassLabeltfds.features.BBoxFeatures,而不是通用的 tfds.features.Tensor

理由:除了在语义上更准确之外,具体特征还为用户提供了额外的元数据并可被工具检测到。

全局空间中的延迟导入

不应从全局空间调用延迟导入。例如,下面的代码是错误的:

tfds.lazy_imports.apache_beam # << Error: Import beam in the global scope def f() -> beam.Map: ...

理由:在全局范围内使用延迟导入会为所有 tf​​ds 用户导入模块,这违背了延迟导入的目的。

动态计算训练/测试拆分

如果数据集不提供官方拆分,TFDS 也不应提供。应避免以下情况:

_TRAIN_TEST_RATIO = 0.7 def _split_generator(): ids = list(range(num_examples)) np.random.RandomState(seed).shuffle(ids) # Split train/test train_ids = ids[_TRAIN_TEST_RATIO * num_examples:] test_ids = ids[:_TRAIN_TEST_RATIO * num_examples] return { 'train': self._generate_examples(train_ids), 'test': self._generate_examples(test_ids), }

理由:TFDS 尝试提供尽可能接近原始数据的数据集。应当使用 sub-split API 来让用户动态创建他们想要的子拆分:

ds_train, ds_test = tfds.load(..., split=['train[:80%]', 'train[80%:]'])

Python 风格指南

首选使用 pathlib API

最好使用 pathlib API 而不是 tf.io.gfile API。所有 dl_manager 方法都返回与 GCS、S3 等兼容的类似路径库的对象

path = dl_manager.download_and_extract('http://some-website/my_data.zip') json_path = path / 'data/file.json' json.loads(json_path.read_text())

理由:pathlib API 是一个移除了样板的面向现代对象的文件 API。使用 .read_text() / .read_bytes() 还可以保证文件正确关闭。

如果方法不使用 self,它应当是一个函数

如果一个类方法不使用 self,它应当是一个简单的函数(定义在类之外)。

理由:它向读者明确表明该函数没有副作用,也没有隐藏的输入/输出:

x = f(y) # Clear inputs/outputs x = self.f(y) # Does f depend on additional hidden variables ? Is it stateful ?

Python 中的延迟导入

我们延迟导入像 TensorFlow 这样的大模块。延迟导入会将模块的实际导入推迟到模块的第一次使用。因此,不需要这个大模块的用户永远不会导入。

from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf # After this statement, TensorFlow is not imported yet ... features = tfds.features.Image(dtype=tf.uint8) # After using it (`tf.uint8`), TensorFlow is now imported

在幕后,LazyModule充当工厂,只有在访问属性 (__getattr__) 时才会实际导入模块。

您还可以通过上下文管理器方便地使用它:

from tensorflow_datasets.core.utils.lazy_imports_utils import lazy_imports with lazy_imports(error_callback=..., success_callback=...): import some_big_module