Path: blob/master/site/zh-cn/tutorials/load_data/images.ipynb
25118 views
Copyright 2020 The TensorFlow Authors.
加载和预处理图像
本教程介绍如何以三种方式加载和预处理图像数据集:
首先,您将使用高级 Keras 预处理效用函数(例如
tf.keras.utils.image_dataset_from_directory
)和层(例如tf.keras.layers.Rescaling
)来读取磁盘上的图像目录。然后,您将使用 tf.data 从头编写自己的输入流水线。
最后,您将从 TensorFlow Datasets 中的大型目录下载数据集。
配置
检索图片
本教程使用一个包含数千张花卉照片的数据集。该花卉数据集包含 5 个子目录,每个子目录对应一个类:
注:所有图像均获得 CC-BY 许可,创作者在 LICENSE.txt 文件中列出。
下载 (218MB) 后,您现在应该拥有花卉照片的副本。总共有 3670 个图像:
每个目录都包含该类型花卉的图像。下面是一些玫瑰:
使用 Keras 效用函数加载数据
让我们使用实用的 tf.keras.utils.image_dataset_from_directory
效用函数从磁盘加载这些图像。
创建数据集
为加载器定义一些参数:
开发模型时,最好使用验证拆分。您将使用 80% 的图像进行训练,20% 的图像进行验证。
您可以在这些数据集的 class_names
特性中找到类名称。
呈现数据
下面是训练数据集中的前 9 个图像。
您可以使用这些数据集来训练模型,方法是将它们传递给 model.fit
(在本教程后面展示)。如果愿意,您还可以手动迭代数据集并检索批量图像:
image_batch
是形状为 (32, 180, 180, 3)
的张量。这是由 32 个形状为 180x180x3
(最后一个维度是指颜色通道 RGB)的图像组成的批次。label_batch
是形状为 (32,)
的张量,这些是 32 个图像的对应标签。
您可以对这些张量中的任何一个调用 .numpy()
以将它们转换为 numpy.ndarray
。
标准化数据
RGB 通道值在 [0, 255]
范围内。这对于神经网络来说并不理想;一般而言,您应当设法使您的输入值变小。
在这里,我们通过使用 tf.keras.layers.Rescaling
将值标准化为在 [0, 1]
范围内。
可以通过两种方式使用该层。您可以通过调用 Dataset.map
将其应用于数据集:
或者,您也可以在模型定义中包含该层以简化部署。在这里,您将使用第二种方式。
注:如果您想将像素值缩放到 [-1,1]
,则可以改为编写 tf.keras.layers.Rescaling(1./127.5, offset=-1)
注:您之前使用 tf.keras.utils.image_dataset_from_directory
的 image_size
参数调整了图像大小。如果您还希望在模型中包括调整大小的逻辑,可以使用 tf.keras.layers.Resizing
层。
配置数据集以提高性能
我们确保使用缓冲预获取,以便您可以从磁盘生成数据,而不会导致 I/O 阻塞。下面是加载数据时应当使用的两个重要方法。
在第一个周期期间从磁盘加载图像后,
Dataset.cache()
会将这些图像保留在内存中。这将确保在训练模型时数据集不会成为瓶颈。如果数据集太大无法装入内存,您也可以使用此方法创建高性能的磁盘缓存。Dataset.prefetch()
会在训练时将数据预处理和模型执行重叠。
感兴趣的读者可以在使用 tf.data API 提升性能指南的预提取部分了解更多有关这两种方法的详细信息,以及如何将数据缓存到磁盘。
选择 tf.keras.optimizers.Adam
优化器和 tf.keras.losses.SparseCategoricalCrossentropy
损失函数。要查看每个训练周期的训练和验证准确率,请将 metrics
参数传递给 Model.compile
。
注:您将仅训练几个周期,因此本教程的运行速度很快。
注:您也可以编写自定义训练循环而不是使用 Model.fit
。要了解详情,请访问从头编写训练循环教程。
您可能会注意到,与训练准确率相比,验证准确率较低,这表明我们的模型存在过拟合。您可以在此教程中详细了解过拟合以及如何减少过拟合。
使用 tf.data 进行更精细的控制
利用上面的 Keras 预处理效用函数 tf.keras.utils.image_dataset_from_directory
,可以方便地从头创建 tf.data.Dataset
。
要实现更精细的控制,您可以使用 tf.data
编写自己的输入流水线。本部分展示了如何做到这一点,从我们之前下载的 TGZ 文件中的文件路径开始。
文件的树结构可用于编译 class_names
列表。
将数据集拆分为训练集和测试集:
您可以按照如下方式打印每个数据集的长度:
编写一个将文件路径转换为 (img, label)
对的短函数:
使用 Dataset.map
创建 image, label
对的数据集:
训练的基本方法
呈现数据
您可以通过与之前创建的数据集类似的方式呈现此数据集:
继续训练模型
您现在已经手动构建了一个与由上面的 keras.preprocessing
创建的数据集类似的 tf.data.Dataset
。您可以继续用它来训练模型。和之前一样,您将只训练几个周期以确保较短的运行时间。
使用 TensorFlow Datasets
到目前为止,本教程的重点是从磁盘加载数据。此外,您还可以通过在 TensorFlow Datasets 上探索易于下载的大型数据集目录来查找要使用的数据集。
由于您之前已经从磁盘加载了花卉数据集,接下来看看如何使用 TensorFlow Datasets 导入它。
使用 TensorFlow Datasets 下载花卉数据集:
花卉数据集有五个类:
从数据集中检索图像:
和以前一样,请记得对训练集、验证集和测试集进行批处理、打乱顺序和配置以提高性能。
您可以通过访问数据增强教程找到使用花卉数据集和 TensorFlow Datasets 的完整示例。
后续步骤
本教程展示了从磁盘加载图像的两种方式。首先,您学习了如何使用 Keras 预处理层和效用函数加载和预处理图像数据集。接下来,您学习了如何使用 tf.data
从头开始编写输入流水线。最后,您学习了如何从 TensorFlow Datasets 下载数据集。
后续步骤:
您可以学习如何添加数据增强。
要详细了解
tf.data
,您可以访问 tf.data:构建 TensorFlow 输入流水线指南。