Path: blob/master/guides/keras_hub/semantic_segmentation_deeplab_v3.py
3293 views
"""1Title: Semantic Segmentation with KerasHub2Authors: [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)3Date created: 2024/10/114Last modified: 2024/10/225Description: DeepLabV3 training and inference with KerasHub.6Accelerator: GPU7"""89"""101112## Background13Semantic segmentation is a type of computer vision task that involves assigning a14class label such as "person", "bike", or "background" to each individual pixel15of an image, effectively dividing the image into regions that correspond to16different object classes or categories.171819202122KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, etc., models for semantic23segmentation.2425This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed26by Google for image semantic segmentation with KerasHub. Its architecture27combines Atrous convolutions, contextual information aggregation, and powerful28backbones to achieve accurate and detailed semantic segmentation.2930DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to31refine the segmentation results, especially along object boundaries. Both models32have achieved state-of-the-art results on a variety of image segmentation33benchmarks.3435### References36[Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)37[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)38"""3940"""41## Setup and Imports4243Let's install the dependencies and import the necessary modules.4445To run this tutorial, you will need to install the following packages:4647* `keras-hub`48* `keras`49"""5051"""shell52pip install -q --upgrade keras-hub53pip install -q --upgrade keras54"""5556"""57After installing `keras` and `keras-hub`, set the backend for `keras`.58This guide can be run with any backend (Tensorflow, JAX, PyTorch).59"""6061import os6263os.environ["KERAS_BACKEND"] = "jax"64import keras65from keras import ops66import keras_hub67import numpy as np68import tensorflow as tf69import matplotlib.pyplot as plt7071"""72## Perform semantic segmentation with a pretrained DeepLabv3+ model7374The highest level API in the KerasHub semantic segmentation API is the75`keras_hub.models` API. This API includes fully pretrained semantic segmentation76models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.7778Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC79dataset.80Also, define the preprocessing function for the model to preprocess images and81labels.82**Note:** By default `from_preset()` method in KerasHub loads the pretrained83task weights with all the classes, 21 classes in this case.84"""8586model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(87"deeplab_v3_plus_resnet50_pascalvoc"88)8990image_converter = keras_hub.layers.DeepLabV3ImageConverter(91image_size=(512, 512),92interpolation="bilinear",93)94preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter)9596"""97Let us visualize the results of this pretrained model98"""99filepath = keras.utils.get_file(100origin="https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png"101)102image = keras.utils.load_img(filepath)103image = np.array(image)104105image = preprocessor(image)106image = keras.ops.expand_dims(image, axis=0)107preds = ops.expand_dims(ops.argmax(model.predict(image), axis=-1), axis=-1)108109110def plot_segmentation(original_image, predicted_mask):111plt.figure(figsize=(5, 5))112113plt.subplot(1, 2, 1)114plt.imshow(original_image[0] / 255)115plt.axis("off")116117plt.subplot(1, 2, 2)118plt.imshow(predicted_mask[0])119plt.axis("off")120121plt.tight_layout()122plt.show()123124125plot_segmentation(image, preds)126127"""128## Train a custom semantic segmentation model129In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3130semantic segmentation model. This includes data loading, augmentation, training,131metric evaluation, and inference!132"""133134"""135## Download the data136137We download Pascal VOC 2012 dataset with additional annotations provided here138[Semantic contours from inverse detectors](https://ieeexplore.ieee.org/document/6126343)139and split them into train dataset `train_ds` and `eval_ds`.140"""141142# @title helper functions143import logging144import multiprocessing145from builtins import open146import os.path147import random148import xml149150import tensorflow_datasets as tfds151152VOC_URL = "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"153154SBD_URL = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"155156# Note that this list doesn't contain the background class. In the157# classification use case, the label is 0 based (aeroplane -> 0), whereas in158# segmentation use case, the 0 is reserved for background, so aeroplane maps to159# 1.160CLASSES = [161"aeroplane",162"bicycle",163"bird",164"boat",165"bottle",166"bus",167"car",168"cat",169"chair",170"cow",171"diningtable",172"dog",173"horse",174"motorbike",175"person",176"pottedplant",177"sheep",178"sofa",179"train",180"tvmonitor",181]182# This is used to map between string class to index.183CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}184185# For the mask data in the PNG file, the encoded raw pixel value need to be186# converted to the proper class index. In the following map, [0, 0, 0] will be187# convert to 0, and [128, 0, 0] will be converted to 1, so on so forth. Also188# note that the mask class is 1 base since class 0 is reserved for the189# background. The [128, 0, 0] (class 1) is mapped to `aeroplane`.190VOC_PNG_COLOR_VALUE = [191[0, 0, 0],192[128, 0, 0],193[0, 128, 0],194[128, 128, 0],195[0, 0, 128],196[128, 0, 128],197[0, 128, 128],198[128, 128, 128],199[64, 0, 0],200[192, 0, 0],201[64, 128, 0],202[192, 128, 0],203[64, 0, 128],204[192, 0, 128],205[64, 128, 128],206[192, 128, 128],207[0, 64, 0],208[128, 64, 0],209[0, 192, 0],210[128, 192, 0],211[0, 64, 128],212]213# Will be populated by maybe_populate_voc_color_mapping() below.214VOC_PNG_COLOR_MAPPING = None215216217def maybe_populate_voc_color_mapping():218"""Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory."""219global VOC_PNG_COLOR_MAPPING220if VOC_PNG_COLOR_MAPPING is None:221VOC_PNG_COLOR_MAPPING = [0] * (256**3)222for i, colormap in enumerate(VOC_PNG_COLOR_VALUE):223VOC_PNG_COLOR_MAPPING[224(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]225] = i226# There is a special mapping with [224, 224, 192] -> 255227VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255228VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING)229return VOC_PNG_COLOR_MAPPING230231232def parse_annotation_data(annotation_file_path):233"""Parse the annotation XML file for the image.234235The annotation contains the metadata, as well as the object bounding box236information.237238"""239with open(annotation_file_path, "r") as f:240root = xml.etree.ElementTree.parse(f).getroot()241242size = root.find("size")243width = int(size.find("width").text)244height = int(size.find("height").text)245246objects = []247for obj in root.findall("object"):248# Get object's label name.249label = CLASS_TO_INDEX[obj.find("name").text.lower()]250# Get objects' pose name.251pose = obj.find("pose").text.lower()252is_truncated = obj.find("truncated").text == "1"253is_difficult = obj.find("difficult").text == "1"254bndbox = obj.find("bndbox")255xmax = int(bndbox.find("xmax").text)256xmin = int(bndbox.find("xmin").text)257ymax = int(bndbox.find("ymax").text)258ymin = int(bndbox.find("ymin").text)259objects.append(260{261"label": label,262"pose": pose,263"bbox": [ymin, xmin, ymax, xmax],264"is_truncated": is_truncated,265"is_difficult": is_difficult,266}267)268269return {"width": width, "height": height, "objects": objects}270271272def get_image_ids(data_dir, split):273"""To get image ids from the "train", "eval" or "trainval" files of VOC data."""274data_file_mapping = {275"train": "train.txt",276"eval": "val.txt",277"trainval": "trainval.txt",278}279with open(280os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]),281"r",282) as f:283image_ids = f.read().splitlines()284logging.info(f"Received {len(image_ids)} images for {split} dataset.")285return image_ids286287288def get_sbd_image_ids(data_dir, split):289"""To get image ids from the "sbd_train", "sbd_eval" from files of SBD data."""290data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"}291with open(292os.path.join(data_dir, data_file_mapping[split]),293"r",294) as f:295image_ids = f.read().splitlines()296logging.info(f"Received {len(image_ids)} images for {split} dataset.")297return image_ids298299300def parse_single_image(image_file_path):301"""Creates metadata of VOC images and path."""302data_dir, image_file_name = os.path.split(image_file_path)303data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))304image_id, _ = os.path.splitext(image_file_name)305class_segmentation_file_path = os.path.join(306data_dir, "SegmentationClass", image_id + ".png"307)308object_segmentation_file_path = os.path.join(309data_dir, "SegmentationObject", image_id + ".png"310)311annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml")312image_annotations = parse_annotation_data(annotation_file_path)313314result = {315"image/filename": image_id + ".jpg",316"image/file_path": image_file_path,317"segmentation/class/file_path": class_segmentation_file_path,318"segmentation/object/file_path": object_segmentation_file_path,319}320result.update(image_annotations)321# Labels field should be same as the 'object.label'322labels = list(set([o["label"] for o in result["objects"]]))323result["labels"] = sorted(labels)324return result325326327def parse_single_sbd_image(image_file_path):328"""Creates metadata of SBD images and path."""329data_dir, image_file_name = os.path.split(image_file_path)330data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))331image_id, _ = os.path.splitext(image_file_name)332class_segmentation_file_path = os.path.join(data_dir, "cls", image_id + ".mat")333object_segmentation_file_path = os.path.join(data_dir, "inst", image_id + ".mat")334result = {335"image/filename": image_id + ".jpg",336"image/file_path": image_file_path,337"segmentation/class/file_path": class_segmentation_file_path,338"segmentation/object/file_path": object_segmentation_file_path,339}340return result341342343def build_metadata(data_dir, image_ids):344"""Transpose the metadata which convert from list of dict to dict of list."""345# Parallel process all the images.346image_file_paths = [347os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids348]349pool_size = 10 if len(image_ids) > 10 else len(image_ids)350with multiprocessing.Pool(pool_size) as p:351metadata = p.map(parse_single_image, image_file_paths)352353keys = [354"image/filename",355"image/file_path",356"segmentation/class/file_path",357"segmentation/object/file_path",358"labels",359"width",360"height",361]362result = {}363for key in keys:364values = [value[key] for value in metadata]365result[key] = values366367# The ragged objects need some special handling368for key in ["label", "pose", "bbox", "is_truncated", "is_difficult"]:369values = []370objects = [value["objects"] for value in metadata]371for object in objects:372values.append([o[key] for o in object])373result["objects/" + key] = values374return result375376377def build_sbd_metadata(data_dir, image_ids):378"""Transpose the metadata which convert from list of dict to dict of list."""379# Parallel process all the images.380image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids]381pool_size = 10 if len(image_ids) > 10 else len(image_ids)382with multiprocessing.Pool(pool_size) as p:383metadata = p.map(parse_single_sbd_image, image_file_paths)384385keys = [386"image/filename",387"image/file_path",388"segmentation/class/file_path",389"segmentation/object/file_path",390]391result = {}392for key in keys:393values = [value[key] for value in metadata]394result[key] = values395return result396397398def decode_png_mask(mask):399"""Decode the raw PNG image and convert it to 2D tensor with probably400class."""401# Cast the mask to int32 since the original uint8 will overflow when402# multiplied with 256403mask = tf.cast(mask, tf.int32)404mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]405mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1)406mask = tf.cast(mask, tf.uint8)407return mask408409410def load_images(example):411"""Loads VOC images for segmentation task from the provided paths"""412image_file_path = example.pop("image/file_path")413segmentation_class_file_path = example.pop("segmentation/class/file_path")414segmentation_object_file_path = example.pop("segmentation/object/file_path")415image = tf.io.read_file(image_file_path)416image = tf.image.decode_jpeg(image)417418segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)419segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)420segmentation_class_mask = decode_png_mask(segmentation_class_mask)421422segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)423segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)424segmentation_object_mask = decode_png_mask(segmentation_object_mask)425426example.update(427{428"image": image,429"class_segmentation": segmentation_class_mask,430"object_segmentation": segmentation_object_mask,431}432)433return example434435436def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):437"""Loads SBD images for segmentation task from the provided paths"""438image = tf.io.read_file(image_file_path)439image = tf.image.decode_jpeg(image)440441segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path)442segmentation_class_mask = segmentation_class_mask["GTcls"]["Segmentation"][0][0]443segmentation_class_mask = segmentation_class_mask[..., np.newaxis]444445segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat(446seg_obj_file_path447)448segmentation_object_mask = segmentation_object_mask["GTinst"]["Segmentation"][0][0]449segmentation_object_mask = segmentation_object_mask[..., np.newaxis]450451return {452"image": image,453"class_segmentation": segmentation_class_mask,454"object_segmentation": segmentation_object_mask,455}456457458def build_dataset_from_metadata(metadata):459"""Builds TensorFlow dataset from the image metadata of VOC dataset."""460# The objects need some manual conversion to ragged tensor.461metadata["labels"] = tf.ragged.constant(metadata["labels"])462metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])463metadata["objects/pose"] = tf.ragged.constant(metadata["objects/pose"])464metadata["objects/is_truncated"] = tf.ragged.constant(465metadata["objects/is_truncated"]466)467metadata["objects/is_difficult"] = tf.ragged.constant(468metadata["objects/is_difficult"]469)470metadata["objects/bbox"] = tf.ragged.constant(471metadata["objects/bbox"], ragged_rank=1472)473474dataset = tf.data.Dataset.from_tensor_slices(metadata)475dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)476return dataset477478479def build_sbd_dataset_from_metadata(metadata):480"""Builds TensorFlow dataset from the image metadata of SBD dataset."""481img_filepath = metadata["image/file_path"]482cls_filepath = metadata["segmentation/class/file_path"]483obj_filepath = metadata["segmentation/object/file_path"]484485def md_gen():486c = list(zip(img_filepath, cls_filepath, obj_filepath))487# random shuffling for each generator boosts up the quality.488random.shuffle(c)489for fp in c:490img_fp, cls_fp, obj_fp = fp491yield load_sbd_images(img_fp, cls_fp, obj_fp)492493dataset = tf.data.Dataset.from_generator(494md_gen,495output_signature=(496{497"image": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),498"class_segmentation": tf.TensorSpec(499shape=(None, None, 1), dtype=tf.uint8500),501"object_segmentation": tf.TensorSpec(502shape=(None, None, 1), dtype=tf.uint8503),504}505),506)507508return dataset509510511def load(512split="sbd_train",513data_dir=None,514):515"""Load the Pacal VOC 2012 dataset.516517This function will download the data tar file from remote if needed, and518untar to the local `data_dir`, and build dataset from it.519520It supports both VOC2012 and Semantic Boundaries Dataset (SBD).521522The returned segmentation masks will be int ranging from [0, num_classes),523as well as 255 which is the boundary mask.524525Args:526split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or527'sbd_eval'. 'sbd_train' represents the training dataset for SBD528dataset, while 'train' represents the training dataset for VOC2012529dataset. Defaults to `sbd_train`.530data_dir: string, local directory path for the loaded data. This will be531used to download the data file, and unzip. It will be used as a532cache directory. Defaults to None, and `~/.keras/pascal_voc_2012`533will be used.534"""535supported_split_value = [536"train",537"eval",538"trainval",539"sbd_train",540"sbd_eval",541]542if split not in supported_split_value:543raise ValueError(544f"The support value for `split` are {supported_split_value}. "545f"Got: {split}"546)547548if data_dir is not None:549data_dir = os.path.expanduser(data_dir)550551if "sbd" in split:552return load_sbd(split, data_dir)553else:554return load_voc(split, data_dir)555556557def load_voc(558split="train",559data_dir=None,560):561"""This function will download VOC data from a URL. If the data is already562present in the cache directory, it will load the data from that directory563instead.564"""565extracted_dir = os.path.join("VOCdevkit", "VOC2012")566get_data = keras.utils.get_file(567fname=os.path.basename(VOC_URL),568origin=VOC_URL,569cache_dir=data_dir,570extract=True,571)572data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)573image_ids = get_image_ids(data_dir, split)574# len(metadata) = #samples, metadata[i] is a dict.575metadata = build_metadata(data_dir, image_ids)576maybe_populate_voc_color_mapping()577dataset = build_dataset_from_metadata(metadata)578579return dataset580581582def load_sbd(583split="sbd_train",584data_dir=None,585):586"""This function will download SBD data from a URL. If the data is already587present in the cache directory, it will load the data from that directory588instead.589"""590extracted_dir = os.path.join("benchmark_RELEASE", "dataset")591get_data = keras.utils.get_file(592fname=os.path.basename(SBD_URL),593origin=SBD_URL,594cache_dir=data_dir,595extract=True,596)597data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)598image_ids = get_sbd_image_ids(data_dir, split)599# len(metadata) = #samples, metadata[i] is a dict.600metadata = build_sbd_metadata(data_dir, image_ids)601602dataset = build_sbd_dataset_from_metadata(metadata)603return dataset604605606"""607## Load the dataset608609For training and evaluation, let's use "sbd_train" and "sbd_eval." You can also610choose any of these datasets for the `load` function: 'train', 'eval', 'trainval',611'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the612SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset.613"""614train_ds = load(split="sbd_train", data_dir="segmentation")615eval_ds = load(split="sbd_eval", data_dir="segmentation")616617"""618## Preprocess the data619620The preprocess_inputs utility function preprocesses inputs, converting them into621a dictionary containing images and segmentation_masks. Both images and622segmentation masks are resized to 512x512. The resulting dataset is then batched623into groups of four image and segmentation mask pairs.624"""625626627def preprocess_inputs(inputs):628def unpackage_inputs(inputs):629return {630"images": inputs["image"],631"segmentation_masks": inputs["class_segmentation"],632}633634outputs = inputs.map(unpackage_inputs)635outputs = outputs.map(keras.layers.Resizing(height=512, width=512))636outputs = outputs.batch(4, drop_remainder=True)637return outputs638639640train_ds = preprocess_inputs(train_ds)641batch = train_ds.take(1).get_single_element()642643"""644A batch of this preprocessed input training data can be visualized using the645`plot_images_masks` function. This function takes a batch of images and646segmentation masks and prediction masks as input and displays them in a grid.647"""648649650def plot_images_masks(images, masks, pred_masks=None):651num_images = len(images)652plt.figure(figsize=(8, 4))653rows = 3 if pred_masks is not None else 2654655for i in range(num_images):656plt.subplot(rows, num_images, i + 1)657plt.imshow(images[i] / 255)658plt.axis("off")659660plt.subplot(rows, num_images, num_images + i + 1)661plt.imshow(masks[i])662plt.axis("off")663664if pred_masks is not None:665plt.subplot(rows, num_images, i + 1 + 2 * num_images)666plt.imshow(pred_masks[i])667plt.axis("off")668669plt.show()670671672plot_images_masks(batch["images"], batch["segmentation_masks"])673674"""675The preprocessing is applied to the evaluation dataset `eval_ds`.676"""677eval_ds = preprocess_inputs(eval_ds)678679"""680## Data Augmentation681682Keras provides a variety of image augmentation options. In this example, we will683use the `RandomFlip` augmentation to augment the training dataset. The684`RandomFlip` augmentation randomly flips the images in the training dataset685horizontally or vertically. This can help to improve the model's robustness to686changes in the orientation of the objects in the images.687"""688689train_ds = train_ds.map(keras.layers.RandomFlip())690batch = train_ds.take(1).get_single_element()691692plot_images_masks(batch["images"], batch["segmentation_masks"])693694"""695## Model Configuration696697Please feel free to modify the configurations for model training and note how the698training results changes. This is an great exercise to get a better699understanding of the training pipeline.700701The learning rate schedule is used by the optimizer to calculate the learning702rate for each epoch. The optimizer then uses the learning rate to update the703weights of the model.704In this case, the learning rate schedule uses a cosine decay function. A cosine705decay function starts high and then decreases over time, eventually reaching706zero. The cardinality of the VOC dataset is 2124 with a batch size of 4. The707dataset cardinality is important for learning rate decay because it determines708how many steps the model will train for. The initial learning rate is709proportional to 0.007 and the decay steps are 2124. This means that the learning710rate will start at `INITIAL_LR` and then decrease to zero over 2124 steps.711712"""713714BATCH_SIZE = 4715INITIAL_LR = 0.007 * BATCH_SIZE / 16716EPOCHS = 1717NUM_CLASSES = 21718learning_rate = keras.optimizers.schedules.CosineDecay(719INITIAL_LR,720decay_steps=EPOCHS * 2124,721)722723"""724Let's take the `resnet_50_imagenet` pretrained weights as a image encoder for725the model, this implementation can be used both as DeepLabV3 and DeepLabV3+ with726additional decoder block.727For DeepLabV3+, we instantiate a DeepLabV3Backbone model by providing728`low_level_feature_key` as `P2` a pyramid level output to extract features from729`resnet_50_imagenet` which acts as a decoder block.730To use this model as DeepLabV3 architecture, ignore the `low_level_feature_key`731which defaults to `None`.732733Then we create DeepLabV3ImageSegmenter instance.734The `num_classes` parameter specifies the number of classes that the model will735be trained to segment. `preprocessor` argument to apply preprocessing to image736input and masks.737"""738739image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")740741deeplab_backbone = keras_hub.models.DeepLabV3Backbone(742image_encoder=image_encoder,743low_level_feature_key="P2",744spatial_pyramid_pooling_key="P5",745dilation_rates=[6, 12, 18],746upsampling_size=8,747)748749model = keras_hub.models.DeepLabV3ImageSegmenter(750backbone=deeplab_backbone,751num_classes=21,752activation="softmax",753preprocessor=preprocessor,754)755756"""757## Compile the model758759The model.compile() function sets up the training process for the model. It defines the760- optimization algorithm - Stochastic Gradient Descent (SGD)761- the loss function - categorical cross-entropy762- the evaluation metrics - Mean IoU and categorical accuracy763764Semantic segmentation evaluation metrics:765766Mean Intersection over Union (MeanIoU):767MeanIoU measures how well a semantic segmentation model accurately identifies768and delineates different objects or regions in an image. It calculates the769overlap between predicted and actual object boundaries, providing a score770between 0 and 1, where 1 represents a perfect match.771772Categorical Accuracy:773Categorical Accuracy measures the proportion of correctly classified pixels in774an image. It gives a simple percentage indicating how accurately the model775predicts the categories of pixels in the entire image.776777In essence, MeanIoU emphasizes the accuracy of identifying specific object778boundaries, while Categorical Accuracy gives a broad overview of overall779pixel-level correctness.780"""781782model.compile(783optimizer=keras.optimizers.SGD(784learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0785),786loss=keras.losses.CategoricalCrossentropy(from_logits=False),787metrics=[788keras.metrics.MeanIoU(789num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False790),791keras.metrics.CategoricalAccuracy(),792],793)794795model.summary()796797"""798The utility function `dict_to_tuple` effectively transforms the dictionaries of799training and validation datasets into tuples of images and one-hot encoded800segmentation masks, which is used during training and evaluation of the801DeepLabv3+ model.802"""803804805def dict_to_tuple(x):806807return x["images"], tf.one_hot(808tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32"), 21809)810811812train_ds = train_ds.map(dict_to_tuple)813eval_ds = eval_ds.map(dict_to_tuple)814815model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)816817"""818## Predictions with trained model819Now that the model training of DeepLabv3+ has completed, let's test it by making820predications821on a few sample images.822Note: For demonstration purpose the model has been trained on only 1 epoch, for823better accuracy and result train with more number of epochs.824"""825826test_ds = load(split="sbd_eval")827test_ds = preprocess_inputs(test_ds)828829images, masks = next(iter(test_ds.take(1)))830images = ops.convert_to_tensor(images)831masks = ops.convert_to_tensor(masks)832preds = ops.expand_dims(ops.argmax(model.predict(images), axis=-1), axis=-1)833masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)834835plot_images_masks(images, masks, preds)836837"""838Here are some additional tips for using the KerasHub DeepLabv3 model:839840- The model can be trained on a variety of datasets, including the COCO dataset, the841PASCAL VOC dataset, and the Cityscapes dataset.842- The model can be fine-tuned on a custom dataset to improve its performance on a843specific task.844- The model can be used to perform real-time inference on images.845- Also, check out KerasHub's other segmentation models.846"""847848849