# GitHub Repository: keras-team/keras-io
# Path: blob/master/guides/keras_hub/object_detection_retinanet.py
"""
Title: Object Detection with KerasHub
Authors: [Sachin Prasad](https://github.com/sachinprasadhs), [Siva Sravana Kumar Neeli](https://github.com/sineeli)
Date created: 2026/03/27
Last modified: 2026/03/27
Description: RetinaNet Object Detection: Training, Fine-tuning, and Inference.
Accelerator: GPU
"""

"""
![](https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_intermediate.png)

## Introduction

Object detection is a crucial computer vision task that goes beyond simple image
classification. It requires models not only to identify the types of objects
present in an image, but also to pinpoint their locations using bounding boxes.
This dual requirement of classification and localization makes object detection
a more complex, and more powerful, tool.

Object detection models are broadly classified into two categories: "two-stage"
and "single-stage" detectors. Two-stage detectors often achieve higher accuracy
by first proposing regions of interest and then classifying them, but this
approach can be computationally expensive. Single-stage detectors, on the other
hand, aim for speed by directly predicting object classes and bounding boxes in
a single pass.

In this tutorial, we'll dive into `RetinaNet`, an object detection model known
for its speed and precision. `RetinaNet` is a single-stage detector, a design
choice that makes it remarkably efficient. Its strong performance stems from two
key architectural innovations:

1. **Feature Pyramid Network (FPN):** The FPN equips `RetinaNet` to seamlessly
detect objects at all scales, from distant, tiny instances to large, prominent
ones.
2. **Focal Loss:** This loss function tackles the common challenge of class
imbalance by focusing the model's learning on the hardest, most informative
object examples, improving accuracy without compromising speed.

![retinanet](/img/guides/object_detection_retinanet/retinanet_architecture.png)
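
As a quick illustration of the focal loss idea, here is a minimal NumPy sketch
of the binary focal loss from the paper (the `alpha` and `gamma` values below
are the paper's defaults; this is a hypothetical sketch, not the exact KerasHub
implementation):

```python
import numpy as np

def binary_focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    # FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    p_t = np.where(y_true == 1, y_pred, 1.0 - y_pred)
    alpha_t = np.where(y_true == 1, alpha, 1.0 - alpha)
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)

# An easy, well-classified example contributes far less than a hard one.
easy = binary_focal_loss(np.array([1.0]), np.array([0.99]))
hard = binary_focal_loss(np.array([1.0]), np.array([0.10]))
```

With `gamma=0`, the modulating factor `(1 - p_t)**gamma` disappears and the loss
reduces to alpha-weighted cross-entropy.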

### References

- [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
- [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
"""

"""
## Setup and Imports

Let's install the dependencies and import the necessary modules.

To run this tutorial, you will need to install the following packages:

* `keras-hub`
* `keras`
* `opencv-python`
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras
pip install -q opencv-python
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"
import keras
import keras_hub
import tensorflow as tf

"""
73
### Helper functions
74
75
We download the Pascal VOC 2012 and 2007 datasets using these helper functions,
76
prepare them for the object detection task, and split them into training and
77
validation datasets.
78
"""
79
# @title Helper functions
80
import logging
81
import multiprocessing
82
import xml
83
84
import tensorflow_datasets as tfds
85
VOC_2007_URL = (
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar"
)
VOC_2012_URL = (
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
)
VOC_2007_test_URL = (
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar"
)

# Note that this list doesn't contain the background class. In the
# classification use case, the label is 0-based (aeroplane -> 0), whereas in the
# segmentation use case, 0 is reserved for background, so aeroplane maps to 1.
CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
COCO_90_CLASS_MAPPING = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}
# These are used to map between class names and integer indices.
CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}
INDEX_TO_CLASS = {index: name for index, name in enumerate(CLASSES)}


def get_image_ids(data_dir, split):
    """Gets image ids from the "train", "eval", "trainval" or "test" file of VOC data."""
    data_file_mapping = {
        "train": "train.txt",
        "eval": "val.txt",
        "trainval": "trainval.txt",
        "test": "test.txt",
    }
    with open(
        os.path.join(data_dir, "ImageSets", "Main", data_file_mapping[split]),
        "r",
    ) as f:
        image_ids = f.read().splitlines()
        logging.info(f"Received {len(image_ids)} images for {split} dataset.")
        return image_ids


def load_images(example):
    """Loads VOC images for the object detection task from the provided paths."""
    image_file_path = example.pop("image/file_path")
    image = tf.io.read_file(image_file_path)
    image = tf.image.decode_jpeg(image)

    example.update(
        {
            "image": image,
        }
    )
    return example


def parse_annotation_data(annotation_file_path):
    """Parses the annotation XML file for the image.

    The annotation contains the metadata, as well as the object bounding box
    information.
    """
    with open(annotation_file_path, "r") as f:
        root = xml.etree.ElementTree.parse(f).getroot()

    size = root.find("size")
    width = int(size.find("width").text)
    height = int(size.find("height").text)
    filename = root.find("filename").text

    objects = []
    for obj in root.findall("object"):
        # Get the object's label name.
        label = CLASS_TO_INDEX[obj.find("name").text.lower()]
        bndbox = obj.find("bndbox")
        xmax = int(float(bndbox.find("xmax").text))
        xmin = int(float(bndbox.find("xmin").text))
        ymax = int(float(bndbox.find("ymax").text))
        ymin = int(float(bndbox.find("ymin").text))
        objects.append(
            {
                "label": label,
                "bbox": [ymin, xmin, ymax, xmax],
            }
        )

    return {
        "image/filename": filename,
        "width": width,
        "height": height,
        "objects": objects,
    }


def parse_single_image(annotation_file_path):
    """Creates metadata of VOC images and paths."""
    data_dir, annotation_file_name = os.path.split(annotation_file_path)
    data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
    image_annotations = parse_annotation_data(annotation_file_path)

    result = {
        "image/file_path": os.path.join(
            data_dir, "JPEGImages", image_annotations["image/filename"]
        )
    }
    result.update(image_annotations)
    # The "labels" field is the sorted set of unique values in "objects/label".
    labels = list({o["label"] for o in result["objects"]})
    result["labels"] = sorted(labels)
    return result


def build_metadata(data_dir, image_ids):
    """Transposes the metadata from a list of dicts to a dict of lists."""
    # Parallel process all the images.
    annotation_file_paths = [
        os.path.join(data_dir, "Annotations", f"{image_id}.xml")
        for image_id in image_ids
    ]
    pool_size = min(10, len(image_ids))
    with multiprocessing.Pool(pool_size) as p:
        metadata = p.map(parse_single_image, annotation_file_paths)

    keys = [
        "image/filename",
        "image/file_path",
        "labels",
        "width",
        "height",
    ]
    result = {}
    for key in keys:
        values = [value[key] for value in metadata]
        result[key] = values

    # The ragged objects need some special handling.
    for key in ["label", "bbox"]:
        values = []
        objects = [value["objects"] for value in metadata]
        for obj_list in objects:
            values.append([o[key] for o in obj_list])
        result["objects/" + key] = values
    return result


def build_dataset_from_metadata(metadata):
    """Builds a TensorFlow dataset from the image metadata of the VOC dataset."""
    # The objects need some manual conversion to ragged tensors.
    metadata["labels"] = tf.ragged.constant(metadata["labels"])
    metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
    metadata["objects/bbox"] = tf.ragged.constant(
        metadata["objects/bbox"], ragged_rank=1
    )

    dataset = tf.data.Dataset.from_tensor_slices(metadata)
    dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset


def load_voc(
    year="2007",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2007_URL,
):
    """Downloads a VOC archive and builds a dataset for the given year and split."""
    extracted_dir = os.path.join("VOCdevkit", f"VOC{year}")
    get_data = keras.utils.get_file(
        fname=os.path.basename(voc_url),
        origin=voc_url,
        cache_dir=data_dir,
        extract=True,
    )
    data_dir = os.path.join(get_data, extracted_dir)
    image_ids = get_image_ids(data_dir, split)
    metadata = build_metadata(data_dir, image_ids)
    dataset = build_dataset_from_metadata(metadata)

    return dataset


"""
## Load the dataset

Let's load the training data. Here, we load the VOC 2007 and 2012 `trainval`
splits for training, and the VOC 2007 `test` split for evaluation.
"""
train_ds_2007 = load_voc(
    year="2007",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2007_URL,
)
train_ds_2012 = load_voc(
    year="2012",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2012_URL,
)
eval_ds = load_voc(
    year="2007",
    split="test",
    data_dir="./",
    voc_url=VOC_2007_test_URL,
)

"""
## Inference using a pre-trained object detector

Let's begin with the simplest `KerasHub` API: a pre-trained object detector. In
this example, we will construct an object detector that was pre-trained on the
`COCO` dataset, and use it to detect objects in a sample image.

The highest-level module in KerasHub is a `task`. A `task` is a `keras.Model`
consisting of a (generally pre-trained) backbone model and task-specific layers.
Here's an example using `keras_hub.models.ImageObjectDetector` with the
`RetinaNet` model architecture and `ResNet50` as the backbone.

`ResNet` is a great starting model when constructing an image processing
pipeline: this architecture manages to achieve high accuracy while using a
relatively small number of parameters. If a ResNet isn't powerful enough for the
task you are hoping to solve, be sure to check out KerasHub's other available
backbones at https://keras.io/keras_hub/presets/.
"""

object_detector = keras_hub.models.ImageObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)
object_detector.summary()

"""
## Preprocessing Layers

Let's define the following preprocessing layers:

- Resizing layer: Resizes the image while maintaining the aspect ratio by
applying padding when `pad_to_aspect_ratio=True`, and sets the default bounding
box format for representing the data.
- Max bounding box layer: Limits the maximum number of bounding boxes per image.
"""
image_size = (800, 800)
batch_size = 4
bbox_format = "yxyx"
epochs = 5

resizing = keras.layers.Resizing(
    height=image_size[0],
    width=image_size[1],
    interpolation="bilinear",
    pad_to_aspect_ratio=True,
    bounding_box_format=bbox_format,
)

max_box_layer = keras.layers.MaxNumBoundingBoxes(
    max_number=100, bounding_box_format=bbox_format
)
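
"""
The `"yxyx"` format stores each box as `[ymin, xmin, ymax, xmax]`, matching the
order in which the VOC parser above assembles its `bbox` lists. As a small
sketch (`xyxy_to_yxyx` is a hypothetical helper, not part of Keras or KerasHub),
converting from the common `"xyxy"` layout is just a reordering:

```python
def xyxy_to_yxyx(box):
    # [xmin, ymin, xmax, ymax] -> [ymin, xmin, ymax, xmax]
    xmin, ymin, xmax, ymax = box
    return [ymin, xmin, ymax, xmax]

print(xyxy_to_yxyx([10, 20, 30, 40]))
```
"""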

"""
### Predict and Visualize

Next, let's obtain predictions from our object detector by loading the image and
visualizing them. We'll apply the preprocessing pipeline defined in the
preprocessing layers step.
"""

filepath = keras.utils.get_file(
    origin="http://images.cocodataset.org/val2017/000000039769.jpg",
)
image = keras.utils.load_img(filepath)
image = keras.ops.cast(image, "float32")
image = keras.ops.expand_dims(image, axis=0)

predictions = object_detector.predict(image, batch_size=1)

keras.visualization.plot_bounding_box_gallery(
    resizing(image),  # resize image as per prediction preprocessing pipeline
    bounding_box_format=bbox_format,
    y_pred=predictions,
    scale=4,
    class_mapping=COCO_90_CLASS_MAPPING,
)

"""
## Fine-tuning a pre-trained object detector

In this guide, we'll assemble a full training pipeline for a KerasHub
`RetinaNet` object detection model. This includes data loading, augmentation,
training, and inference using the Pascal VOC 2007 & 2012 datasets!
"""

"""
## TFDS Preprocessing

This preprocessing step prepares the TFDS dataset for object detection. It
includes:
- Merging the Pascal VOC 2007 and 2012 datasets.
- Resizing all images to a resolution of 800x800 pixels.
- Limiting the number of bounding boxes per image to a maximum of 100.
- Finally, batching the resulting dataset into sets of 4 images and bounding
box annotations.
"""


def decode_custom_tfds(record):
    """Decodes a custom TFDS record into a dictionary.

    Args:
        record: A dictionary representing a single TFDS record.

    Returns:
        A dictionary with "images" and "bounding_boxes".
    """
    image = record["image"]
    boxes = record["objects/bbox"]
    labels = record["objects/label"]

    bounding_boxes = {"boxes": boxes, "labels": labels}

    return {"images": image, "bounding_boxes": bounding_boxes}


def convert_to_tuple(record):
    """Converts a decoded TFDS record to a tuple for KerasHub.

    Args:
        record: A dictionary returned by `decode_custom_tfds`.

    Returns:
        A tuple (image, bounding_boxes).
    """
    return record["images"], {
        "boxes": record["bounding_boxes"]["boxes"],
        "labels": record["bounding_boxes"]["labels"],
    }


def preprocess_tfds(ds, resizing, max_box_layer, batch_size):
    """Preprocesses a TFDS dataset for object detection.

    Args:
        ds: The TFDS dataset.
        resizing: A resizing function.
        max_box_layer: A max box processing function.
        batch_size: The batch size.

    Returns:
        A preprocessed TFDS dataset.
    """
    ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


"""
Now, concatenate the 2007 and 2012 VOC data.
"""
train_ds = train_ds_2007.concatenate(train_ds_2012)
train_ds = train_ds.map(decode_custom_tfds, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = preprocess_tfds(train_ds, resizing, max_box_layer, batch_size)

"""
Load the eval data.
"""
eval_ds = eval_ds.map(decode_custom_tfds, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = preprocess_tfds(eval_ds, resizing, max_box_layer, batch_size)

"""
### Let's visualize a batch of training data
"""
record = next(iter(train_ds.shuffle(100).take(1)))
keras.visualization.plot_bounding_box_gallery(
    record["images"],
    bounding_box_format=bbox_format,
    y_true=record["bounding_boxes"],
    scale=3,
    rows=2,
    cols=2,
    class_mapping=INDEX_TO_CLASS,
)

"""
### Decode TFDS records to a tuple for KerasHub
"""
train_ds = train_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

eval_ds = eval_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = eval_ds.prefetch(tf.data.AUTOTUNE)

"""
## Configure RetinaNet Model

Configure the model with `backbone`, `num_classes`, and `preprocessor`.
Use callbacks for recording logs and saving checkpoints.
"""


def get_callbacks(experiment_path):
    """Creates a list of callbacks for model training.

    Args:
        experiment_path (str): Path to the experiment directory.

    Returns:
        List of keras callback instances.
    """
    tb_logs_path = os.path.join(experiment_path, "logs")
    backup_path = os.path.join(experiment_path, "backup")
    ckpt_path = os.path.join(experiment_path, "weights")
    return [
        keras.callbacks.BackupAndRestore(backup_path, delete_checkpoint=False),
        keras.callbacks.TensorBoard(
            tb_logs_path,
            update_freq=1,
        ),
        keras.callbacks.ModelCheckpoint(
            os.path.join(ckpt_path, "{epoch:04d}-{val_loss:.2f}.weights.h5"),
            save_best_only=True,
            save_weights_only=True,
            verbose=1,
        ),
    ]


"""
## Load backbone weights and preprocessor config

Let's use the "retinanet_resnet50_fpn_coco" pretrained weights as the backbone
model, applying the predefined preprocessor configuration from the same preset.

Define a RetinaNet object detector model with the backbone and preprocessor
specified above, and set `num_classes` to 20 to represent the object categories
from Pascal VOC.

Finally, compile the model using Mean Absolute Error (MAE) as the box loss.
"""

backbone = keras_hub.models.Backbone.from_preset("retinanet_resnet50_fpn_coco")

preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor.from_preset(
    "retinanet_resnet50_fpn_coco"
)
model = keras_hub.models.RetinaNetObjectDetector(
    backbone=backbone, num_classes=len(CLASSES), preprocessor=preprocessor
)
model.compile(box_loss=keras.losses.MeanAbsoluteError(reduction="sum"))

"""
## Train the model

Now that the object detector model is compiled, let's train it using the
training and validation data we created earlier. For demonstration purposes, we
use a small number of epochs; you can increase the number of epochs to achieve
better results.

**Note:** This model was trained on an L4 GPU; training for 5 epochs on a T4 GPU
takes approximately 7 hours.
"""

model.fit(
    train_ds,
    epochs=epochs,
    validation_data=eval_ds,
    callbacks=get_callbacks("fine_tuning"),
)

"""
### Prediction on evaluation data

Let's make predictions using our model on the evaluation dataset.
"""
images, y_true = next(iter(eval_ds.shuffle(50).take(1)))
y_pred = model.predict(images)

"""
### Plot the predictions
"""
keras.visualization.plot_bounding_box_gallery(
    images,
    bounding_box_format=bbox_format,
    y_true=y_true,
    y_pred=y_pred,
    scale=3,
    rows=2,
    cols=2,
    class_mapping=INDEX_TO_CLASS,
)

"""
## Custom training object detector

Additionally, you can customize the object detector by modifying the image
converter, selecting a different image encoder, etc.

### Image Converter

The `RetinaNetImageConverter` class prepares images for use with the `RetinaNet`
object detection model. Here's what it does:

- Scaling and offsetting
- ImageNet normalization
- Resizing
"""

image_converter = keras_hub.layers.RetinaNetImageConverter(scale=1 / 255)

preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor(
    image_converter=image_converter
)

"""
### Image Encoder and RetinaNet Backbone

The image encoder, while typically initialized with pre-trained weights
(e.g., from ImageNet), can also be instantiated without them. This results in
the image encoder (and, consequently, the entire object detection network built
upon it) having randomly initialized weights.

Here, we load a pre-trained ResNet50 model to serve as the base for extracting
image features, and then build the RetinaNet Feature Pyramid Network (FPN) on
top of the ResNet50 backbone. The FPN creates multi-scale feature maps for
better object detection at different sizes.

**Note:**
`use_p5`: If `True`, the output of the last backbone layer (typically `P5` in an
`FPN`) is used as input to create higher-level feature maps (e.g., `P6`, `P7`)
through additional convolutional layers. If `False`, the original `P5` feature
map from the backbone is directly used as input for creating the coarser levels,
bypassing any further processing of `P5` within the feature pyramid. Defaults to
`False`.
"""

image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")

backbone = keras_hub.models.RetinaNetBackbone(
    image_encoder=image_encoder, min_level=3, max_level=5, use_p5=True
)
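
"""
Under the standard FPN convention, pyramid level `P{l}` has a stride of `2**l`
relative to the input image. So with `min_level=3`, `max_level=5`, and 800x800
inputs, the levels above span roughly 100x100 down to 25x25. A quick
sanity-check sketch (illustrative only; the exact shapes come from the backbone
itself):

```python
image_size = 800
for level in range(3, 6):  # P3..P5, as configured above
    stride = 2**level
    # Each level downsamples the input by its stride.
    print(f"P{level}: stride {stride}, feature map "
          f"{image_size // stride}x{image_size // stride}")
```
"""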

"""
### Train and visualize RetinaNet model

**Note:** For demonstration purposes, we train for only 5 epochs. In a real
scenario, you would train for many more epochs (often hundreds) to achieve good
results.
"""
model = keras_hub.models.RetinaNetObjectDetector(
    backbone=backbone,
    num_classes=len(CLASSES),
    preprocessor=preprocessor,
    use_prediction_head_norm=True,
)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    box_loss=keras.losses.MeanAbsoluteError(reduction="sum"),
)

model.fit(
    train_ds,
    epochs=epochs,
    validation_data=eval_ds,
    callbacks=get_callbacks("custom_training"),
)

images, y_true = next(iter(eval_ds.shuffle(50).take(1)))
y_pred = model.predict(images)

keras.visualization.plot_bounding_box_gallery(
    images,
    bounding_box_format=bbox_format,
    y_true=y_true,
    y_pred=y_pred,
    scale=3,
    rows=2,
    cols=2,
    class_mapping=INDEX_TO_CLASS,
)

"""
## Conclusion

In this tutorial, you learned how to custom train and fine-tune the RetinaNet
object detector.

You can experiment with different existing backbones trained on ImageNet as the
image encoder, or you can fine-tune your own backbone. Using a randomly
initialized backbone is equivalent to training the model from scratch, as
opposed to fine-tuning a pre-trained model; training from scratch generally
requires significantly more data and computational resources to achieve
performance comparable to fine-tuning.

To achieve better results when fine-tuning the model, you can increase the
number of epochs and experiment with different hyperparameter values. In
addition to the training data used here, you can also use other object
detection datasets, but keep in mind that custom training on these requires a
large amount of GPU memory.
"""