Path: blob/master/site/zh-cn/tutorials/images/segmentation.ipynb
25118 views
Copyright 2019 The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
图像分割
这篇教程将重点讨论图像分割任务,使用的是改进版的 U-Net。
什么是图像分割?
在图像分类任务中,网络会为每个输入图像分配一个标签(或类)。但是,如何了解该对象的形状、哪个像素属于哪个对象等信息呢?在这种情况下,您需要为图像的每个像素分配一个类。此任务称为分割。分割模型会返回有关图像的更详细信息。图像分割在医学成像、自动驾驶汽车和卫星成像等方面有很多应用。
本教程使用 Oxford-IIIT Pet Dataset (Parkhi et al, 2012)。该数据集由 37 个宠物品种的图像组成,每个品种有 200 个图像(训练拆分和测试拆分各有 100 个)。每个图像都包含相应的标签和像素级掩码。掩码是每个像素的类标签。每个像素都会被划入以下三个类别之一:
第 1 类:属于宠物的像素。
第 2 类:宠物边缘的像素。
第 3 类:以上都不是/周围的像素。
下载 Oxford-IIIT Pets 数据集
该数据集可从 TensorFlow Datasets 获得。分割掩码包含在版本 3 以上的版本中。
此外,图像颜色值被归一化到 [0,1]
范围。最后,如上所述,分割掩码中的像素被标记为 {1, 2, 3}。为方便起见,从分割掩码中减去 1,得到的标签为:{0, 1, 2}。
数据集已包含所需的训练拆分和测试拆分,因此请继续使用相同的拆分。
下面的类通过随机翻转图像来执行简单的增强。请转到图像增强教程以了解更多信息。
构建输入流水线,在对输入进行批处理后应用增强:
呈现数据集中的图像样本及其对应的掩码:
定义模型
这里使用的模型是修改后的 U-Net。U-Net 由编码器(下采样器)和解码器(上采样器)组成。为了学习稳健的特征并减少可训练参数的数量,请使用预训练模型 MobileNetV2 作为编码器。对于解码器,您将使用上采样块,该块已在 TensorFlow Examples 仓库的 pix2pix 示例中实现。(请查看笔记本中的 pix2pix:使用条件 GAN 进行图像到图像转换教程。)
如前所述,编码器是一个预训练的 MobileNetV2 模型。您将使用来自 tf.keras.applications
的模型。编码器由模型中中间层的特定输出组成。请注意,在训练过程中不会训练编码器。
解码器/上采样器只是在 TensorFlow 示例中实现的一系列上采样块:
请注意,最后一层的筛选器数量设置为 output_channels
的数量。每个类将有一个输出通道。
训练模型
现在,剩下要做的是编译和训练模型。
由于这是一个多类分类问题,请使用 tf.keras.losses.SparseCategoricalCrossentropy
损失函数,并将 from_logits
参数设置为 True
,因为标签是标量整数,而不是每个类的每个像素的分数向量。
运行推断时,分配给像素的标签是具有最高值的通道。这就是 create_mask
函数的作用。
绘制最终的模型架构:
在训练前试用一下该模型,以检查其预测结果:
下面定义的回调用于观察模型在训练过程中的改进情况:
做出预测
接下来,进行一些预测。为了节省时间,保持较小周期数,但您也可以将其设置得更高以获得更准确的结果。
可选:不平衡的类和类权重
因此,在这种情况下,您需要自己实现加权。您将使用样本权重来执行此操作:除了 (data, label)
对之外,Model.fit
还接受 (data, label, sample_weight)
三元组。
Keras Model.fit
将 sample_weight
传播给损失和指标,它们也接受 sample_weight
参数。在归约步骤之前,将样本权重乘以样本值。例如:
因此,要为本教程设置样本权重,您需要一个函数,该函数接受 (data, label)
对并返回 (data, label, sample_weight)
三元组,其中 sample_weight
是包含每个像素的类权重的单通道图像。
最简单的可能实现是将标签用作 class_weight
列表的索引:
每个生成的数据集元素包含 3 个图像:
现在,您可以在此加权数据集上训练模型:
接下来
现在您已经了解了什么是图像分割及其工作原理,您可以使用不同的中间层输出,甚至不同的预训练模型来尝试本教程。您也可以通过尝试在 Kaggle 上托管的 Carvana 图像掩码挑战来挑战自己。
您可能还想查看另一个可以根据自己的数据重新训练的模型的 Tensorflow Object Detection API。TensorFlow Hub 上提供了预训练模型。