"""
Title: Semantic Segmentation with KerasHub
Authors: [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)
Date created: 2024/10/11
Last modified: 2024/10/22
Description: DeepLabV3 training and inference with KerasHub.
Accelerator: GPU
"""

"""
![](https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_intermediate.png)

## Background
Semantic segmentation is a computer vision task that assigns a class label such
as "person", "bike", or "background" to each individual pixel of an image,
effectively dividing the image into regions that correspond to different object
classes or categories.

![](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*z6ch-2BliDGLIHpOPFY_Sw.png)

KerasHub offers models such as DeepLabv3, DeepLabv3+, and SegFormer for semantic
segmentation.

This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed
by Google, for semantic image segmentation with KerasHub. Its architecture
combines atrous convolutions, contextual information aggregation, and powerful
backbones to achieve accurate and detailed semantic segmentation.

DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to
refine the segmentation results, especially along object boundaries. Both models
have achieved state-of-the-art results on a variety of image segmentation
benchmarks.
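
The atrous (dilated) convolutions at the heart of both models enlarge the
receptive field without downsampling the feature map or adding parameters. As a
minimal illustrative sketch (not part of this guide's pipeline), an atrous
convolution in Keras is just a regular `Conv2D` with a `dilation_rate`:

```python
import keras

# A 3x3 kernel with dilation_rate=2 covers a 5x5 receptive field
# while still using only 9 weights per channel.
atrous_conv = keras.layers.Conv2D(
    filters=64, kernel_size=3, dilation_rate=2, padding="same"
)
```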

### References
* [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
* [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
"""

"""
## Setup and Imports

Let's install the dependencies and import the necessary modules.

To run this tutorial, you will need to install the following packages:

* `keras-hub`
* `keras`
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras
"""

"""
After installing `keras` and `keras-hub`, set the backend for `keras`.
This guide can be run with any backend (TensorFlow, JAX, PyTorch).
"""

import os

os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras import ops
import keras_hub
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

"""
## Perform semantic segmentation with a pretrained DeepLabv3+ model

The highest level API in the KerasHub semantic segmentation API is the
`keras_hub.models` API. This API includes fully pretrained semantic segmentation
models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.

Let's get started by constructing a DeepLabv3+ model pretrained on the Pascal
VOC dataset.
Also, define the preprocessing function for the model to preprocess images and
labels.
**Note:** By default, the `from_preset()` method in KerasHub loads the
pretrained task weights with all the classes, 21 classes in this case.
"""

model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
    "deeplab_v3_plus_resnet50_pascalvoc"
)

image_converter = keras_hub.layers.DeepLabV3ImageConverter(
    image_size=(512, 512),
    interpolation="bilinear",
)
preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter)

"""
Let's visualize the results of this pretrained model.
"""
filepath = keras.utils.get_file(
    origin="https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png"
)
image = keras.utils.load_img(filepath)
image = np.array(image)

image = preprocessor(image)
image = keras.ops.expand_dims(image, axis=0)
preds = ops.expand_dims(ops.argmax(model.predict(image), axis=-1), axis=-1)


def plot_segmentation(original_image, predicted_mask):
    """Plots the original image next to its predicted segmentation mask."""
    plt.figure(figsize=(5, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(original_image[0] / 255)
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(predicted_mask[0])
    plt.axis("off")

    plt.tight_layout()
    plt.show()


plot_segmentation(image, preds)

"""
## Train a custom semantic segmentation model
In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3
semantic segmentation model. This includes data loading, augmentation, training,
metric evaluation, and inference!
"""

"""
## Download the data

We download the Pascal VOC 2012 dataset with the additional annotations from
[Semantic contours from inverse detectors](https://ieeexplore.ieee.org/document/6126343)
and split it into a training dataset `train_ds` and an evaluation dataset
`eval_ds`.
"""

# @title helper functions
import logging
import multiprocessing
from builtins import open
import os.path
import random
import xml.etree.ElementTree

import tensorflow_datasets as tfds

VOC_URL = "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"

SBD_URL = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"

# Note that this list doesn't contain the background class. In the
# classification use case, the label is 0-based (aeroplane -> 0), whereas in the
# segmentation use case, 0 is reserved for the background, so aeroplane maps to
# 1.
CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
# This is used to map from a string class name to its index.
CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}

# For the mask data in the PNG file, the encoded raw pixel values need to be
# converted to proper class indices. In the following map, [0, 0, 0] is
# converted to 0, [128, 0, 0] is converted to 1, and so on. Also note that the
# mask classes are 1-based since class 0 is reserved for the background.
# [128, 0, 0] (class 1) maps to `aeroplane`.
VOC_PNG_COLOR_VALUE = [
    [0, 0, 0],
    [128, 0, 0],
    [0, 128, 0],
    [128, 128, 0],
    [0, 0, 128],
    [128, 0, 128],
    [0, 128, 128],
    [128, 128, 128],
    [64, 0, 0],
    [192, 0, 0],
    [64, 128, 0],
    [192, 128, 0],
    [64, 0, 128],
    [192, 0, 128],
    [64, 128, 128],
    [192, 128, 128],
    [0, 64, 0],
    [128, 64, 0],
    [0, 192, 0],
    [128, 192, 0],
    [0, 64, 128],
]
# Will be populated by maybe_populate_voc_color_mapping() below.
VOC_PNG_COLOR_MAPPING = None


def maybe_populate_voc_color_mapping():
    """Lazy creation of VOC_PNG_COLOR_MAPPING, which can take ~64MB of memory."""
    global VOC_PNG_COLOR_MAPPING
    if VOC_PNG_COLOR_MAPPING is None:
        VOC_PNG_COLOR_MAPPING = [0] * (256**3)
        for i, colormap in enumerate(VOC_PNG_COLOR_VALUE):
            VOC_PNG_COLOR_MAPPING[
                (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]
            ] = i
        # There is a special mapping with [224, 224, 192] -> 255.
        VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255
        VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING)
    return VOC_PNG_COLOR_MAPPING


def parse_annotation_data(annotation_file_path):
    """Parse the annotation XML file for the image.

    The annotation contains the metadata, as well as the object bounding box
    information.
    """
    with open(annotation_file_path, "r") as f:
        root = xml.etree.ElementTree.parse(f).getroot()

        size = root.find("size")
        width = int(size.find("width").text)
        height = int(size.find("height").text)

        objects = []
        for obj in root.findall("object"):
            # Get the object's label name.
            label = CLASS_TO_INDEX[obj.find("name").text.lower()]
            # Get the object's pose name.
            pose = obj.find("pose").text.lower()
            is_truncated = obj.find("truncated").text == "1"
            is_difficult = obj.find("difficult").text == "1"
            bndbox = obj.find("bndbox")
            xmax = int(bndbox.find("xmax").text)
            xmin = int(bndbox.find("xmin").text)
            ymax = int(bndbox.find("ymax").text)
            ymin = int(bndbox.find("ymin").text)
            objects.append(
                {
                    "label": label,
                    "pose": pose,
                    "bbox": [ymin, xmin, ymax, xmax],
                    "is_truncated": is_truncated,
                    "is_difficult": is_difficult,
                }
            )

        return {"width": width, "height": height, "objects": objects}


def get_image_ids(data_dir, split):
    """Get image ids from the "train", "eval" or "trainval" files of VOC data."""
    data_file_mapping = {
        "train": "train.txt",
        "eval": "val.txt",
        "trainval": "trainval.txt",
    }
    with open(
        os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]),
        "r",
    ) as f:
        image_ids = f.read().splitlines()
        logging.info(f"Received {len(image_ids)} images for {split} dataset.")
        return image_ids


def get_sbd_image_ids(data_dir, split):
    """Get image ids from the "sbd_train" or "sbd_eval" files of SBD data."""
    data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"}
    with open(
        os.path.join(data_dir, data_file_mapping[split]),
        "r",
    ) as f:
        image_ids = f.read().splitlines()
        logging.info(f"Received {len(image_ids)} images for {split} dataset.")
        return image_ids


def parse_single_image(image_file_path):
    """Creates metadata for a single VOC image from its paths."""
    data_dir, image_file_name = os.path.split(image_file_path)
    data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
    image_id, _ = os.path.splitext(image_file_name)
    class_segmentation_file_path = os.path.join(
        data_dir, "SegmentationClass", image_id + ".png"
    )
    object_segmentation_file_path = os.path.join(
        data_dir, "SegmentationObject", image_id + ".png"
    )
    annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml")
    image_annotations = parse_annotation_data(annotation_file_path)

    result = {
        "image/filename": image_id + ".jpg",
        "image/file_path": image_file_path,
        "segmentation/class/file_path": class_segmentation_file_path,
        "segmentation/object/file_path": object_segmentation_file_path,
    }
    result.update(image_annotations)
    # The labels field should match the 'objects/label' values.
    labels = list(set([o["label"] for o in result["objects"]]))
    result["labels"] = sorted(labels)
    return result


def parse_single_sbd_image(image_file_path):
    """Creates metadata for a single SBD image from its paths."""
    data_dir, image_file_name = os.path.split(image_file_path)
    data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
    image_id, _ = os.path.splitext(image_file_name)
    class_segmentation_file_path = os.path.join(data_dir, "cls", image_id + ".mat")
    object_segmentation_file_path = os.path.join(data_dir, "inst", image_id + ".mat")
    result = {
        "image/filename": image_id + ".jpg",
        "image/file_path": image_file_path,
        "segmentation/class/file_path": class_segmentation_file_path,
        "segmentation/object/file_path": object_segmentation_file_path,
    }
    return result


def build_metadata(data_dir, image_ids):
    """Transpose the metadata, converting a list of dicts into a dict of lists."""
    # Parallel process all the images.
    image_file_paths = [
        os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids
    ]
    pool_size = 10 if len(image_ids) > 10 else len(image_ids)
    with multiprocessing.Pool(pool_size) as p:
        metadata = p.map(parse_single_image, image_file_paths)

    keys = [
        "image/filename",
        "image/file_path",
        "segmentation/class/file_path",
        "segmentation/object/file_path",
        "labels",
        "width",
        "height",
    ]
    result = {}
    for key in keys:
        values = [value[key] for value in metadata]
        result[key] = values

    # The ragged objects need some special handling.
    for key in ["label", "pose", "bbox", "is_truncated", "is_difficult"]:
        values = []
        objects = [value["objects"] for value in metadata]
        for obj in objects:
            values.append([o[key] for o in obj])
        result["objects/" + key] = values
    return result


def build_sbd_metadata(data_dir, image_ids):
    """Transpose the metadata, converting a list of dicts into a dict of lists."""
    # Parallel process all the images.
    image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids]
    pool_size = 10 if len(image_ids) > 10 else len(image_ids)
    with multiprocessing.Pool(pool_size) as p:
        metadata = p.map(parse_single_sbd_image, image_file_paths)

    keys = [
        "image/filename",
        "image/file_path",
        "segmentation/class/file_path",
        "segmentation/object/file_path",
    ]
    result = {}
    for key in keys:
        values = [value[key] for value in metadata]
        result[key] = values
    return result


def decode_png_mask(mask):
    """Decode the raw PNG image into a 2D tensor of per-pixel class indices."""
    # Cast the mask to int32 since the original uint8 would overflow when
    # multiplied by 256.
    mask = tf.cast(mask, tf.int32)
    mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]
    mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1)
    mask = tf.cast(mask, tf.uint8)
    return mask


def load_images(example):
    """Loads VOC images for the segmentation task from the provided paths."""
    image_file_path = example.pop("image/file_path")
    segmentation_class_file_path = example.pop("segmentation/class/file_path")
    segmentation_object_file_path = example.pop("segmentation/object/file_path")
    image = tf.io.read_file(image_file_path)
    image = tf.image.decode_jpeg(image)

    segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)
    segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)
    segmentation_class_mask = decode_png_mask(segmentation_class_mask)

    segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)
    segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)
    segmentation_object_mask = decode_png_mask(segmentation_object_mask)

    example.update(
        {
            "image": image,
            "class_segmentation": segmentation_class_mask,
            "object_segmentation": segmentation_object_mask,
        }
    )
    return example


def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
    """Loads SBD images for the segmentation task from the provided paths."""
    image = tf.io.read_file(image_file_path)
    image = tf.image.decode_jpeg(image)

    segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path)
    segmentation_class_mask = segmentation_class_mask["GTcls"]["Segmentation"][0][0]
    segmentation_class_mask = segmentation_class_mask[..., np.newaxis]

    segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat(
        seg_obj_file_path
    )
    segmentation_object_mask = segmentation_object_mask["GTinst"]["Segmentation"][0][0]
    segmentation_object_mask = segmentation_object_mask[..., np.newaxis]

    return {
        "image": image,
        "class_segmentation": segmentation_class_mask,
        "object_segmentation": segmentation_object_mask,
    }


def build_dataset_from_metadata(metadata):
    """Builds a TensorFlow dataset from the image metadata of the VOC dataset."""
    # The objects need some manual conversion to ragged tensors.
    metadata["labels"] = tf.ragged.constant(metadata["labels"])
    metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
    metadata["objects/pose"] = tf.ragged.constant(metadata["objects/pose"])
    metadata["objects/is_truncated"] = tf.ragged.constant(
        metadata["objects/is_truncated"]
    )
    metadata["objects/is_difficult"] = tf.ragged.constant(
        metadata["objects/is_difficult"]
    )
    metadata["objects/bbox"] = tf.ragged.constant(
        metadata["objects/bbox"], ragged_rank=1
    )

    dataset = tf.data.Dataset.from_tensor_slices(metadata)
    dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset


def build_sbd_dataset_from_metadata(metadata):
    """Builds a TensorFlow dataset from the image metadata of the SBD dataset."""
    img_filepath = metadata["image/file_path"]
    cls_filepath = metadata["segmentation/class/file_path"]
    obj_filepath = metadata["segmentation/object/file_path"]

    def md_gen():
        c = list(zip(img_filepath, cls_filepath, obj_filepath))
        # Random shuffling in each generator improves the shuffle quality.
        random.shuffle(c)
        for fp in c:
            img_fp, cls_fp, obj_fp = fp
            yield load_sbd_images(img_fp, cls_fp, obj_fp)

    dataset = tf.data.Dataset.from_generator(
        md_gen,
        output_signature=(
            {
                "image": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
                "class_segmentation": tf.TensorSpec(
                    shape=(None, None, 1), dtype=tf.uint8
                ),
                "object_segmentation": tf.TensorSpec(
                    shape=(None, None, 1), dtype=tf.uint8
                ),
            }
        ),
    )

    return dataset


def load(
    split="sbd_train",
    data_dir=None,
):
    """Load the Pascal VOC 2012 dataset.

    This function will download the data tar file from the remote source if
    needed, untar it to the local `data_dir`, and build a dataset from it.

    It supports both VOC2012 and the Semantic Boundaries Dataset (SBD).

    The returned segmentation masks contain ints in the range [0, num_classes),
    plus the value 255, which marks the boundary mask.

    Args:
        split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or
            'sbd_eval'. 'sbd_train' represents the training dataset for the SBD
            dataset, while 'train' represents the training dataset for the
            VOC2012 dataset. Defaults to `sbd_train`.
        data_dir: string, local directory path for the loaded data. This will
            be used to download and unzip the data file, and acts as a cache
            directory. Defaults to None, in which case
            `~/.keras/pascal_voc_2012` will be used.
    """
    supported_split_value = [
        "train",
        "eval",
        "trainval",
        "sbd_train",
        "sbd_eval",
    ]
    if split not in supported_split_value:
        raise ValueError(
            f"The supported values for `split` are {supported_split_value}. "
            f"Got: {split}"
        )

    if data_dir is not None:
        data_dir = os.path.expanduser(data_dir)

    if "sbd" in split:
        return load_sbd(split, data_dir)
    else:
        return load_voc(split, data_dir)


def load_voc(
    split="train",
    data_dir=None,
):
    """This function will download VOC data from a URL. If the data is already
    present in the cache directory, it will load the data from that directory
    instead.
    """
    extracted_dir = os.path.join("VOCdevkit", "VOC2012")
    get_data = keras.utils.get_file(
        fname=os.path.basename(VOC_URL),
        origin=VOC_URL,
        cache_dir=data_dir,
        extract=True,
    )
    data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
    image_ids = get_image_ids(data_dir, split)
    # len(metadata) = #samples, metadata[i] is a dict.
    metadata = build_metadata(data_dir, image_ids)
    maybe_populate_voc_color_mapping()
    dataset = build_dataset_from_metadata(metadata)

    return dataset


def load_sbd(
    split="sbd_train",
    data_dir=None,
):
    """This function will download SBD data from a URL. If the data is already
    present in the cache directory, it will load the data from that directory
    instead.
    """
    extracted_dir = os.path.join("benchmark_RELEASE", "dataset")
    get_data = keras.utils.get_file(
        fname=os.path.basename(SBD_URL),
        origin=SBD_URL,
        cache_dir=data_dir,
        extract=True,
    )
    data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
    image_ids = get_sbd_image_ids(data_dir, split)
    # len(metadata) = #samples, metadata[i] is a dict.
    metadata = build_sbd_metadata(data_dir, image_ids)

    dataset = build_sbd_dataset_from_metadata(metadata)
    return dataset


"""
## Load the dataset

For training and evaluation, let's use "sbd_train" and "sbd_eval". You can also
pass any of these splits to the `load` function: 'train', 'eval', 'trainval',
'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the
SBD dataset, while 'train' represents the training dataset for the VOC2012
dataset.
"""
train_ds = load(split="sbd_train", data_dir="segmentation")
eval_ds = load(split="sbd_eval", data_dir="segmentation")

"""
## Preprocess the data

The `preprocess_inputs` utility function preprocesses the inputs, converting
them into a dictionary containing `images` and `segmentation_masks`. Both images
and segmentation masks are resized to 512x512. The resulting dataset is then
batched into groups of four image and segmentation mask pairs.
"""


def preprocess_inputs(inputs):
    def unpackage_inputs(inputs):
        return {
            "images": inputs["image"],
            "segmentation_masks": inputs["class_segmentation"],
        }

    outputs = inputs.map(unpackage_inputs)
    outputs = outputs.map(keras.layers.Resizing(height=512, width=512))
    outputs = outputs.batch(4, drop_remainder=True)
    return outputs


train_ds = preprocess_inputs(train_ds)
batch = train_ds.take(1).get_single_element()

"""
A batch of this preprocessed input training data can be visualized using the
`plot_images_masks` function. This function takes a batch of images and
segmentation masks, and optionally prediction masks, and displays them in a
grid.
"""


def plot_images_masks(images, masks, pred_masks=None):
    num_images = len(images)
    plt.figure(figsize=(8, 4))
    rows = 3 if pred_masks is not None else 2

    for i in range(num_images):
        plt.subplot(rows, num_images, i + 1)
        plt.imshow(images[i] / 255)
        plt.axis("off")

        plt.subplot(rows, num_images, num_images + i + 1)
        plt.imshow(masks[i])
        plt.axis("off")

        if pred_masks is not None:
            plt.subplot(rows, num_images, i + 1 + 2 * num_images)
            plt.imshow(pred_masks[i])
            plt.axis("off")

    plt.show()


plot_images_masks(batch["images"], batch["segmentation_masks"])

"""
The same preprocessing is applied to the evaluation dataset `eval_ds`.
"""
eval_ds = preprocess_inputs(eval_ds)

"""
## Data Augmentation

Keras provides a variety of image augmentation options. In this example, we will
use the `RandomFlip` augmentation to augment the training dataset. The
`RandomFlip` augmentation randomly flips the images in the training dataset
horizontally or vertically. This can help improve the model's robustness to
changes in the orientation of the objects in the images.
"""

train_ds = train_ds.map(keras.layers.RandomFlip())
batch = train_ds.take(1).get_single_element()

plot_images_masks(batch["images"], batch["segmentation_masks"])

"""
## Model Configuration

Please feel free to modify the configuration for model training and note how the
training results change. This is a great exercise to get a better understanding
of the training pipeline.

The learning rate schedule is used by the optimizer to calculate the learning
rate for each training step. The optimizer then uses the learning rate to update
the weights of the model.
In this case, the learning rate schedule uses a cosine decay function. A cosine
decay function starts high and then decreases over time, eventually reaching
zero. With a batch size of 4, the training dataset has a cardinality of 2124
batches. The dataset cardinality is important for learning rate decay because it
determines how many steps the model will train for. The initial learning rate is
the base rate of 0.007 scaled by `BATCH_SIZE / 16`, and the decay steps are
`EPOCHS * 2124`. This means that the learning rate will start at `INITIAL_LR`
and then decrease to zero over the course of training.
![png](/img/guides/semantic_segmentation_deeplab_v3_plus/learning_rate_schedule.png)
"""

BATCH_SIZE = 4
INITIAL_LR = 0.007 * BATCH_SIZE / 16
EPOCHS = 1
NUM_CLASSES = 21
learning_rate = keras.optimizers.schedules.CosineDecay(
    INITIAL_LR,
    decay_steps=EPOCHS * 2124,
)
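
"""
As a quick sanity check (optional, not required for training), we can evaluate
the schedule at a few step values to see the cosine decay in action:
"""

for step in [0, 1000, 2124]:
    print(f"Learning rate at step {step}: {float(learning_rate(step)):.6f}")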

"""
Let's use the `resnet_50_imagenet` pretrained weights as the image encoder for
the model. This implementation can be used both as DeepLabV3 and as DeepLabV3+
(with the additional decoder block).
For DeepLabV3+, we instantiate a `DeepLabV3Backbone` model, passing
`low_level_feature_key` as `P2`, a pyramid-level output of `resnet_50_imagenet`
whose features feed the decoder block.
To use this model as the DeepLabV3 architecture, omit the
`low_level_feature_key` argument, which defaults to `None`.

Then we create a `DeepLabV3ImageSegmenter` instance.
The `num_classes` parameter specifies the number of classes that the model will
be trained to segment. The `preprocessor` argument applies preprocessing to the
image input and masks.
"""

image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")

deeplab_backbone = keras_hub.models.DeepLabV3Backbone(
    image_encoder=image_encoder,
    low_level_feature_key="P2",
    spatial_pyramid_pooling_key="P5",
    dilation_rates=[6, 12, 18],
    upsampling_size=8,
)
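
"""
For reference, the plain DeepLabV3 variant mentioned above would simply drop the
low-level feature path. A minimal sketch, not built or trained in this guide
(the `upsampling_size` value here is an illustrative assumption; check the
`DeepLabV3Backbone` documentation for your encoder's output stride):

```python
deeplab_v3_backbone = keras_hub.models.DeepLabV3Backbone(
    image_encoder=image_encoder,
    spatial_pyramid_pooling_key="P5",  # ASPP runs on the deepest features
    dilation_rates=[6, 12, 18],
    upsampling_size=8,  # assumption: chosen to match the encoder's output stride
)
```
"""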

model = keras_hub.models.DeepLabV3ImageSegmenter(
    backbone=deeplab_backbone,
    num_classes=21,
    activation="softmax",
    preprocessor=preprocessor,
)

"""
## Compile the model

The `model.compile()` function sets up the training process for the model. It
defines:
- the optimization algorithm - Stochastic Gradient Descent (SGD)
- the loss function - categorical cross-entropy
- the evaluation metrics - Mean IoU and categorical accuracy

Semantic segmentation evaluation metrics:

Mean Intersection over Union (MeanIoU):
MeanIoU measures how well a semantic segmentation model accurately identifies
and delineates different objects or regions in an image. It calculates the
overlap between predicted and actual object regions, providing a score between
0 and 1, where 1 represents a perfect match.

Categorical Accuracy:
Categorical Accuracy measures the proportion of correctly classified pixels in
an image. It gives a simple percentage indicating how accurately the model
predicts the categories of pixels in the entire image.

In essence, MeanIoU emphasizes the accuracy of identifying specific object
boundaries, while Categorical Accuracy gives a broad overview of overall
pixel-level correctness.
"""

model.compile(
    optimizer=keras.optimizers.SGD(
        learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0
    ),
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[
        keras.metrics.MeanIoU(
            num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False
        ),
        keras.metrics.CategoricalAccuracy(),
    ],
)

model.summary()

"""
The utility function `dict_to_tuple` transforms the dictionaries of the training
and validation datasets into tuples of images and one-hot encoded segmentation
masks, which are used during training and evaluation of the DeepLabv3+ model.
"""


def dict_to_tuple(x):
    return x["images"], tf.one_hot(
        tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32"), 21
    )


train_ds = train_ds.map(dict_to_tuple)
eval_ds = eval_ds.map(dict_to_tuple)

model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)

"""
## Predictions with the trained model
Now that the training of DeepLabv3+ has completed, let's test it by making
predictions on a few sample images.
Note: For demonstration purposes the model has been trained for only 1 epoch;
for better accuracy and results, train for more epochs.
"""

test_ds = load(split="sbd_eval")
test_ds = preprocess_inputs(test_ds)
# Apply the same dict-to-tuple conversion as for training, so each element is
# an (images, one-hot masks) pair.
test_ds = test_ds.map(dict_to_tuple)

images, masks = next(iter(test_ds.take(1)))
images = ops.convert_to_tensor(images)
masks = ops.convert_to_tensor(masks)
preds = ops.expand_dims(ops.argmax(model.predict(images), axis=-1), axis=-1)
masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)

plot_images_masks(images, masks, preds)

"""
Here are some additional tips for using the KerasHub DeepLabv3 model:

- The model can be trained on a variety of datasets, including the COCO dataset,
the PASCAL VOC dataset, and the Cityscapes dataset.
- The model can be fine-tuned on a custom dataset to improve its performance on
a specific task.
- The model can be used to perform real-time inference on images, as sketched
after this list.
- Also, check out KerasHub's other segmentation models.
"""