"""
Title: Object Detection with KerasHub
Authors: [Sachin Prasad](https://github.com/sachinprasadhs), [Siva Sravana Kumar Neeli](https://github.com/sineeli)
Date created: 2026/03/27
Last modified: 2026/03/27
Description: RetinaNet Object Detection: Training, Fine-tuning, and Inference.
Accelerator: GPU
"""

"""
## Introduction

Object detection is a crucial computer vision task that goes beyond simple image
classification. It requires models to not only identify the types of objects
present in an image but also pinpoint their locations using bounding boxes. This
dual requirement of classification and localization makes object detection a
more complex and powerful tool.

Object detection models are broadly classified into two categories: "two-stage"
and "single-stage" detectors. Two-stage detectors often achieve higher accuracy
by first proposing regions of interest and then classifying them, but this
approach can be computationally expensive. Single-stage detectors, on the other
hand, aim for speed by directly predicting object classes and bounding boxes in
a single pass.

In this tutorial, we'll be diving into `RetinaNet`, a powerful object detection
model known for its speed and precision. `RetinaNet` is a single-stage detector,
a design choice that allows it to be remarkably efficient. Its impressive
performance stems from two key architectural innovations:

1. **Feature Pyramid Network (FPN):** The FPN equips `RetinaNet` with the
ability to detect objects at all scales, from distant, tiny instances to large,
prominent ones.
2. **Focal Loss:** This loss function tackles the common challenge of class
imbalance by focusing the model's learning on the hardest, most informative
object examples, leading to enhanced accuracy without compromising speed.

### References

- [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
- [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
"""

"""
## Setup and Imports

Let's install the dependencies and import the necessary modules.

To run this tutorial, you will need to install the following packages:

* `keras-hub`
* `keras`
* `opencv-python`
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras
pip install -q opencv-python
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"
import keras
import keras_hub
import tensorflow as tf

"""
### Helper functions

We download the Pascal VOC 2012 and 2007 datasets using these helper functions
and prepare them for the object detection task.
"""
# @title Helper functions
import logging
import multiprocessing
import xml.etree.ElementTree

import tensorflow_datasets as tfds

VOC_2007_URL = (
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar"
)
VOC_2012_URL = (
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
)
VOC_2007_test_URL = (
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar"
)
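"""
Before setting up the class mappings below, here is a quick numerical
illustration of the focal loss described in the introduction. This is a plain
NumPy sketch of the alpha-balanced focal loss from the paper, not KerasHub's
internal implementation:

```python
import numpy as np


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    # Binary focal loss: FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t),
    # where p_t is the predicted probability assigned to the true class.
    p_t = np.where(y == 1, p, 1.0 - p)
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)


# A well-classified example (p_t = 0.9) contributes almost nothing, while a
# hard example (p_t = 0.1) dominates, so training focuses on hard examples.
easy = focal_loss(np.array([0.9]), np.array([1]))
hard = focal_loss(np.array([0.1]), np.array([1]))
```

With `gamma = 0` and `alpha = 0.5`, this reduces to (half) the usual binary
cross-entropy, so `gamma` directly controls how strongly easy examples are
down-weighted.
"""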
# Note that this list doesn't contain the background class. In the
# classification use case, the label is 0-based (aeroplane -> 0), whereas in
# the segmentation use case, 0 is reserved for the background, so aeroplane
# maps to 1.
CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
COCO_90_CLASS_MAPPING = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}
# These are used to map between string classes and indices.
CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}
INDEX_TO_CLASS = {index: name for index, name in enumerate(CLASSES)}


def get_image_ids(data_dir, split):
    """Gets image ids from the "train", "eval" or "trainval" files of VOC data."""
    data_file_mapping = {
        "train": "train.txt",
        "eval": "val.txt",
        "trainval": "trainval.txt",
        "test": "test.txt",
    }
    with open(
        os.path.join(data_dir, "ImageSets", "Main", data_file_mapping[split]),
        "r",
    ) as f:
        image_ids = f.read().splitlines()
    logging.info(f"Received {len(image_ids)} images for {split} dataset.")
    return image_ids


def load_images(example):
    """Loads VOC images for the object detection task from the provided paths."""
    image_file_path = example.pop("image/file_path")
    image = tf.io.read_file(image_file_path)
    image = tf.image.decode_jpeg(image)

    example.update(
        {
            "image": image,
        }
    )
    return example


def parse_annotation_data(annotation_file_path):
    """Parses the annotation XML file for the image.

    The annotation contains the metadata, as well as the object bounding box
    information.
    """
    with open(annotation_file_path, "r") as f:
        root = xml.etree.ElementTree.parse(f).getroot()

    size = root.find("size")
    width = int(size.find("width").text)
    height = int(size.find("height").text)
    filename = root.find("filename").text

    objects = []
    for obj in root.findall("object"):
        # Get the object's label name.
        label = CLASS_TO_INDEX[obj.find("name").text.lower()]
        bndbox = obj.find("bndbox")
        xmax = int(float(bndbox.find("xmax").text))
        xmin = int(float(bndbox.find("xmin").text))
        ymax = int(float(bndbox.find("ymax").text))
        ymin = int(float(bndbox.find("ymin").text))
        objects.append(
            {
                "label": label,
                "bbox": [ymin, xmin, ymax, xmax],
            }
        )

    return {
        "image/filename": filename,
        "width": width,
        "height": height,
        "objects": objects,
    }


def parse_single_image(annotation_file_path):
    """Creates metadata of VOC images and paths."""
    data_dir, annotation_file_name = os.path.split(annotation_file_path)
    data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
    image_annotations = parse_annotation_data(annotation_file_path)

    result = {
        "image/file_path": os.path.join(
            data_dir, "JPEGImages", image_annotations["image/filename"]
        )
    }
    result.update(image_annotations)
    # The "labels" field should match the labels in "objects".
    labels = list({o["label"] for o in result["objects"]})
    result["labels"] = sorted(labels)
    return result


def build_metadata(data_dir, image_ids):
    """Transposes the metadata, converting a list of dicts to a dict of lists."""
    # Parallel process all the images.
    annotation_file_paths = [
        os.path.join(data_dir, "Annotations", f"{image_id}.xml")
        for image_id in image_ids
    ]
    pool_size = min(10, len(image_ids))
    with multiprocessing.Pool(pool_size) as p:
        metadata = p.map(parse_single_image, annotation_file_paths)

    keys = [
        "image/filename",
        "image/file_path",
        "labels",
        "width",
        "height",
    ]
    result = {}
    for key in keys:
        values = [value[key] for value in metadata]
        result[key] = values

    # The ragged objects need some special handling.
    for key in ["label", "bbox"]:
        values = []
        objects = [value["objects"] for value in metadata]
        for obj_list in objects:
            values.append([o[key] for o in obj_list])
        result["objects/" + key] = values
    return result


def build_dataset_from_metadata(metadata):
    """Builds a TensorFlow dataset from the image metadata of the VOC dataset."""
    # The objects need some manual conversion to ragged tensors.
    metadata["labels"] = tf.ragged.constant(metadata["labels"])
    metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
    metadata["objects/bbox"] = tf.ragged.constant(
        metadata["objects/bbox"], ragged_rank=1
    )

    dataset = tf.data.Dataset.from_tensor_slices(metadata)
    dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset


def load_voc(
    year="2007",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2007_URL,
):
    extracted_dir = os.path.join("VOCdevkit", f"VOC{year}")
    get_data = keras.utils.get_file(
        fname=os.path.basename(voc_url),
        origin=voc_url,
        cache_dir=data_dir,
        extract=True,
    )
    data_dir = os.path.join(get_data, extracted_dir)
    image_ids = get_image_ids(data_dir, split)
    metadata = build_metadata(data_dir, image_ids)
    dataset = build_dataset_from_metadata(metadata)

    return dataset


"""
## Load the dataset

Let's load the data. Here, we use the VOC 2007 and 2012 `trainval` splits for
training and the VOC 2007 `test` split for evaluation.
"""
train_ds_2007 = load_voc(
    year="2007",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2007_URL,
)
train_ds_2012 = load_voc(
    year="2012",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2012_URL,
)
eval_ds = load_voc(
    year="2007",
    split="test",
    data_dir="./",
    voc_url=VOC_2007_test_URL,
)

"""
## Inference using a pre-trained object detector

Let's begin with the simplest `KerasHub` API: a pre-trained object detector. In
this example, we will construct an object detector that was pre-trained on the
`COCO` dataset. We'll use this model to detect objects in a sample image.

The highest-level module in KerasHub is a `task`. A `task` is a `keras.Model`
consisting of a (generally pre-trained) backbone model and task-specific layers.
Here's an example using `keras_hub.models.ImageObjectDetector` with the
`RetinaNet` model architecture and `ResNet50` as the backbone.

`ResNet` is a great starting model when constructing an object detection
pipeline. This architecture manages to achieve high accuracy while using a
relatively small number of parameters.
If a ResNet isn't powerful enough for the task you are hoping to solve, be sure
to check out [KerasHub's other available backbones](https://keras.io/keras_hub/presets/).
"""

object_detector = keras_hub.models.ImageObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)
object_detector.summary()

"""
## Preprocessing Layers

Let's define the preprocessing layers below:

- Resizing layer: Resizes the image and maintains the aspect ratio by applying
padding when `pad_to_aspect_ratio=True`. It also sets the default bounding box
format for representing the data.
- Max bounding box layer: Limits the maximum number of bounding boxes per image.
"""
image_size = (800, 800)
batch_size = 4
bbox_format = "yxyx"
epochs = 5

resizing = keras.layers.Resizing(
    height=image_size[0],
    width=image_size[1],
    interpolation="bilinear",
    pad_to_aspect_ratio=True,
    bounding_box_format=bbox_format,
)

max_box_layer = keras.layers.MaxNumBoundingBoxes(
    max_number=100, bounding_box_format=bbox_format
)

"""
### Predict and Visualize

Next, let's obtain predictions from our object detector by loading the image and
visualizing them.
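As a reminder, `bbox_format = "yxyx"` means every box is stored as
`[y_min, x_min, y_max, x_max]` in pixel coordinates. A tiny helper
(hypothetical, for illustration only) shows how this relates to the more
common `xyxy` layout:

```python
def xyxy_to_yxyx(box):
    # [x_min, y_min, x_max, y_max] -> [y_min, x_min, y_max, x_max]
    x_min, y_min, x_max, y_max = box
    return [y_min, x_min, y_max, x_max]


# A box spanning x = 10..50 and y = 20..80:
print(xyxy_to_yxyx([10, 20, 50, 80]))  # [20, 10, 80, 50]
```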
We'll apply the preprocessing pipeline defined in the preprocessing layers step.
"""

filepath = keras.utils.get_file(
    origin="http://images.cocodataset.org/val2017/000000039769.jpg",
)
image = keras.utils.load_img(filepath)
image = keras.ops.cast(image, "float32")
image = keras.ops.expand_dims(image, axis=0)

predictions = object_detector.predict(image, batch_size=1)

keras.visualization.plot_bounding_box_gallery(
    resizing(image),  # Resize the image as in the prediction preprocessing pipeline.
    bounding_box_format=bbox_format,
    y_pred=predictions,
    scale=4,
    class_mapping=COCO_90_CLASS_MAPPING,
)

"""
## Fine-tuning a pretrained object detector

In this guide, we'll assemble a full training pipeline for a KerasHub
`RetinaNet` object detection model. This includes data loading, augmentation,
training, and inference using the Pascal VOC 2007 & 2012 datasets!
"""

"""
## TFDS Preprocessing

This preprocessing step prepares the TFDS dataset for object detection.
It477includes:478- Merging the Pascal VOC 2007 and 2012 datasets.479- Resizing all images to a resolution of 800x800 pixels.480- Limiting the number of bounding boxes per image to a maximum of 100.481- Finally, the resulting dataset is batched into sets of 4 images and bounding482box annotations.483"""484485486def decode_custom_tfds(record):487"""Decodes a custom TFDS record into a dictionary.488489Args:490record: A dictionary representing a single TFDS record.491492Returns:493A dictionary with "images" and "bounding_boxes".494"""495image = record["image"]496boxes = record["objects/bbox"]497labels = record["objects/label"]498499bounding_boxes = {"boxes": boxes, "labels": labels}500501return {"images": image, "bounding_boxes": bounding_boxes}502503504def convert_to_tuple(record):505"""Converts a decoded TFDS record to a tuple for KerasHub.506507Args:508record: A dictionary returned by `decode_custom_tfds`.509510Returns:511A tuple (image, bounding_boxes).512"""513return record["images"], {514"boxes": record["bounding_boxes"]["boxes"],515"labels": record["bounding_boxes"]["labels"],516}517518519def preprocess_tfds(ds, resizing, max_box_layer, batch_size):520"""Preprocesses a TFDS dataset for object detection.521522Args:523ds: The TFDS dataset.524resizing: A resizing function.525max_box_layer: A max box processing function.526batch_size: The batch size.527528Returns:529A preprocessed TFDS dataset.530"""531ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)532ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)533ds = ds.batch(batch_size, drop_remainder=True)534return ds535536537"""538Now concatenate both 2007 and 2012 VOC data539"""540train_ds = train_ds_2007.concatenate(train_ds_2012)541train_ds = train_ds.map(decode_custom_tfds, num_parallel_calls=tf.data.AUTOTUNE)542train_ds = preprocess_tfds(train_ds, resizing, max_box_layer, batch_size)543544"""545Load the eval data546"""547eval_ds = eval_ds.map(decode_custom_tfds, 
num_parallel_calls=tf.data.AUTOTUNE)548eval_ds = preprocess_tfds(eval_ds, resizing, max_box_layer, batch_size)549550"""551### Let's visualize a batch of training data552"""553record = next(iter(train_ds.shuffle(100).take(1)))554keras.visualization.plot_bounding_box_gallery(555record["images"],556bounding_box_format=bbox_format,557y_true=record["bounding_boxes"],558scale=3,559rows=2,560cols=2,561class_mapping=INDEX_TO_CLASS,562)563564"""565### Decode TFDS records to a tuple for KerasHub566"""567train_ds = train_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)568train_ds = train_ds.prefetch(tf.data.AUTOTUNE)569570eval_ds = eval_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)571eval_ds = eval_ds.prefetch(tf.data.AUTOTUNE)572573"""574## Configure RetinaNet Model575576Configure the model with `backbone`, `num_classes` and `preprocessor`.577Use callbacks for recording logs and saving checkpoints.578"""579580581def get_callbacks(experiment_path):582"""Creates a list of callbacks for model training.583584Args:585experiment_path (str): Path to the experiment directory.586587Returns:588List of keras callback instances.589"""590tb_logs_path = os.path.join(experiment_path, "logs")591backup_path = os.path.join(experiment_path, "backup")592ckpt_path = os.path.join(experiment_path, "weights")593return [594keras.callbacks.BackupAndRestore(backup_path, delete_checkpoint=False),595keras.callbacks.TensorBoard(596tb_logs_path,597update_freq=1,598),599keras.callbacks.ModelCheckpoint(600os.path.join(ckpt_path, "{epoch:04d}-{val_loss:.2f}.weights.h5"),601save_best_only=True,602save_weights_only=True,603verbose=1,604),605]606607608"""609## Load backbone weights and preprocessor config610611Let's use the "retinanet_resnet50_fpn_coco" pretrained weights as the backbone612model, applying its predefined configuration from the preprocessor of the613"retinanet_resnet50_fpn_coco" preset.614Define a RetinaNet object detector model with the backbone and 
preprocessor615specified above, and set `num_classes` to 20 to represent the object categories616from Pascal VOC.617Finally, compile the model using Mean Absolute Error (MAE) as the box loss.618"""619620backbone = keras_hub.models.Backbone.from_preset("retinanet_resnet50_fpn_coco")621622preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor.from_preset(623"retinanet_resnet50_fpn_coco"624)625model = keras_hub.models.RetinaNetObjectDetector(626backbone=backbone, num_classes=len(CLASSES), preprocessor=preprocessor627)628model.compile(box_loss=keras.losses.MeanAbsoluteError(reduction="sum"))629630"""631## Train the model632633Now that the object detector model is compiled, let's train it using the634training and validation data we created earlier.635For demonstration purposes, we have used a small number of epochs. You can636increase the number of epochs to achieve better results.637638**Note:** The model is trained on an L4 GPU. Training for 5 epochs on a T4 GPU639takes approximately 7 hours.640"""641642model.fit(643train_ds,644epochs=epochs,645validation_data=eval_ds,646callbacks=get_callbacks("fine_tuning"),647)648649"""650### Prediction on evaluation data651652Let's make predictions using our model on the evaluation dataset.653"""654images, y_true = next(iter(eval_ds.shuffle(50).take(1)))655y_pred = model.predict(images)656657"""658### Plot the predictions659"""660keras.visualization.plot_bounding_box_gallery(661images,662bounding_box_format=bbox_format,663y_true=y_true,664y_pred=y_pred,665scale=3,666rows=2,667cols=2,668class_mapping=INDEX_TO_CLASS,669)670671"""672## Custom training object detector673674Additionally, you can customize the object detector by modifying the image675converter, selecting a different image encoder, etc.676677### Image Converter678679The `RetinaNetImageConverter` class prepares images for use with the `RetinaNet`680object detection model. 
Here's what it does:

- Scaling and Offsetting
- ImageNet Normalization
- Resizing
"""

image_converter = keras_hub.layers.RetinaNetImageConverter(scale=1 / 255)

preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor(
    image_converter=image_converter
)

"""
### Image Encoder and RetinaNet Backbone

The image encoder, while typically initialized with pre-trained weights
(e.g., from ImageNet), can also be instantiated without them. In that case, the
image encoder (and, consequently, the entire object detection network built
upon it) starts with randomly initialized weights.

Here, we load a pre-trained ResNet50 model, which will serve as the base for
extracting image features.

We then build the RetinaNet Feature Pyramid Network (FPN) on top of the
ResNet50 backbone. The FPN creates multi-scale feature maps for better object
detection at different sizes.

**Note:**
`use_p5`: If `True`, the output of the last backbone layer (typically `P5` in
an `FPN`) is used as input to create higher-level feature maps (e.g., `P6`,
`P7`) through additional convolutional layers. If `False`, the original `P5`
feature map from the backbone is directly used as input for creating the
coarser levels, bypassing any further processing of `P5` within the feature
pyramid. Defaults to `False`.
"""

image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")

backbone = keras_hub.models.RetinaNetBackbone(
    image_encoder=image_encoder, min_level=3, max_level=5, use_p5=True
)

"""
### Train and visualize RetinaNet model

**Note:** For demonstration purposes, we train the model for only 5 epochs.
In a727real scenario, you would train for many more epochs (often hundreds) to achieve728good results.729"""730model = keras_hub.models.RetinaNetObjectDetector(731backbone=backbone,732num_classes=len(CLASSES),733preprocessor=preprocessor,734use_prediction_head_norm=True,735)736model.compile(737optimizer=keras.optimizers.Adam(learning_rate=0.001),738box_loss=keras.losses.MeanAbsoluteError(reduction="sum"),739)740741model.fit(742train_ds,743epochs=epochs,744validation_data=eval_ds,745callbacks=get_callbacks("custom_training"),746)747748images, y_true = next(iter(eval_ds.shuffle(50).take(1)))749y_pred = model.predict(images)750751keras.visualization.plot_bounding_box_gallery(752images,753bounding_box_format=bbox_format,754y_true=y_true,755y_pred=y_pred,756scale=3,757rows=2,758cols=2,759class_mapping=INDEX_TO_CLASS,760)761762"""763## Conclusion764765In this tutorial, you learned how to custom train and fine-tune the RetinaNet766object detector.767768You can experiment with different existing backbones trained on ImageNet as the769image encoder, or you can fine-tune your own backbone.770771This configuration is equivalent to training the model from scratch, as opposed772to fine-tuning a pre-trained model.773774Training from scratch generally requires significantly more data and775computational resources to achieve performance comparable to fine-tuning.776777To achieve better results when fine-tuning the model, you can increase the778number of epochs and experiment with different hyperparameter values.779In addition to the training data used here, you can also use other object780detection datasets, but keep in mind that custom training these requires781high GPU memory.782"""783784785