"""
Title: Object Detection with KerasCV
Author: [lukewood](https://twitter.com/luke_wood_ml), Ian Stenbit, Tirth Patel
Date created: 2023/04/08
Last modified: 2023/08/10
Description: Train an object detection model with KerasCV.
Accelerator: GPU
"""

"""
KerasCV offers a complete set of production-grade APIs to solve object detection
problems.
These APIs include object-detection-specific
data augmentation techniques, Keras native COCO metrics, bounding box format
conversion utilities, visualization tools, pretrained object detection models,
and everything you need to train your own state-of-the-art object detection
models!

Let's give KerasCV's object detection API a spin.
"""

"""shell
pip install -q --upgrade keras-cv
pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

from tensorflow import data as tf_data
import tensorflow_datasets as tfds
import keras
import keras_cv
import numpy as np
from keras_cv import bounding_box
from keras_cv import visualization
import tqdm

"""
## Object detection introduction

Object detection is the process of identifying, classifying,
and localizing objects within a given image. Typically, your inputs are
images, and your labels are bounding boxes with optional class
labels.
Object detection can be thought of as an extension of classification;
however, instead of one class label for the image, you must detect and
localize an arbitrary number of classes.

**For example:**

<img width="300" src="https://i.imgur.com/8xSEbQD.png">

The data for the above image may look something like this:
```python
image = [height, width, 3]
bounding_boxes = {
    "classes": [0],  # 0 is an arbitrary class ID representing "cat"
    "boxes": [[0.25, 0.4, 0.15, 0.1]],
    # bounding box is in "rel_xywh" format
    # so 0.25 represents the start of the bounding box 25% of
    # the way across the image.
    # The 0.15 represents that the width is 15% of the image width.
}
```

Since the inception of [*You Only Look Once*](https://arxiv.org/abs/1506.02640)
(aka YOLO),
object detection has primarily been solved using deep learning.
Most deep learning architectures do this by cleverly framing the object detection
problem as a combination of many small classification problems and
many regression problems.

More specifically, this is done by generating many anchor boxes of varying
shapes and sizes across the input images and assigning them each a class label,
as well as `x`, `y`, `width` and `height` offsets.
The model is trained to predict the class labels of each box, as well as the
`x`, `y`, `width`, and `height` offsets of each box that is predicted to be an
object.

**Visualization of some sample anchor boxes**:

<img width="400" src="https://i.imgur.com/cJIuiK9.jpg">

Object detection is a technically complex problem, but luckily we offer a
bulletproof approach to getting great results.
Let's do this!
"""

"""
## Perform detections with a pretrained model

![](https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_beginner.png)

The highest level API in the KerasCV Object Detection API is the `keras_cv.models` API.
This API includes fully pretrained object detection models, such as
`keras_cv.models.YOLOV8Detector`.

Let's get started by constructing a YOLOV8Detector pretrained on the `pascalvoc`
dataset.
"""

pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_pascalvoc", bounding_box_format="xywh"
)

"""
Notice the `bounding_box_format` argument?

Recall in the section above, the format of bounding boxes:

```
bounding_boxes = {
    "classes": [num_boxes],
    "boxes": [num_boxes, 4]
}
```

This argument describes *exactly* what format the values in the `"boxes"`
field of the label dictionary take in your pipeline.
For example, a box in `xywh` format with its top left corner at the coordinates
(100, 100) with a width of 55 and a height of 70 would be represented by:

```
[100, 100, 55, 70]
```

or equivalently in `xyxy` format:

```
[100, 100, 155, 170]
```

While this may seem simple, it is a critical piece of the KerasCV object
detection API!
Every component that processes bounding boxes requires a
`bounding_box_format` argument.
You can read more about
KerasCV bounding box formats [in the API docs](https://keras.io/api/keras_cv/bounding_box/formats/).

This is done because there is no one correct format for bounding boxes!
Components in different pipelines expect different formats, and so by requiring
them to be specified we ensure that our components remain readable, reusable,
and clear.
Box format conversion bugs are perhaps the most common bug surface in object
detection pipelines - by requiring this parameter we mitigate against these
bugs (especially when combining code from many sources).
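
Conversion between formats is handled by `keras_cv.bounding_box.convert_format()`,
which we will also use when loading data later in this guide. As a minimal
sketch, converting the example box above from `xywh` to `xyxy`:

```python
import numpy as np
import keras_cv

boxes_xywh = np.array([[100.0, 100.0, 55.0, 70.0]])
boxes_xyxy = keras_cv.bounding_box.convert_format(
    boxes_xywh, source="xywh", target="xyxy"
)
# boxes_xyxy is [[100., 100., 155., 170.]]
```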

Next let's load an image:
"""

filepath = keras.utils.get_file(origin="https://i.imgur.com/gCNcJJI.jpg")
image = keras.utils.load_img(filepath)
image = np.array(image)

visualization.plot_image_gallery(
    np.array([image]),
    value_range=(0, 255),
    rows=1,
    cols=1,
    scale=5,
)

"""
To use the `YOLOV8Detector` architecture with a ResNet50 backbone, you'll need to
resize your image to a size that is divisible by 64. This is to ensure
compatibility with the number of downscaling operations done by the convolution
layers in the ResNet.

If the resize operation distorts
the input's aspect ratio, the model will perform significantly worse. For the
pretrained `"yolo_v8_m_pascalvoc"` preset we are using, the final
`MeanAveragePrecision` on the `pascalvoc/2012` evaluation set drops to `0.15`
from `0.38` when using a naive resizing operation.

Additionally, if you crop to preserve the aspect ratio, as you might in
classification, your model may entirely miss some bounding boxes. As such, when
running inference on an object detection model we recommend padding to the
desired size, while resizing the longest side to match the aspect ratio.

KerasCV makes resizing properly easy; simply pass `pad_to_aspect_ratio=True` to
a `keras_cv.layers.Resizing` layer.

This can be implemented in one line of code:
"""

inference_resizing = keras_cv.layers.Resizing(
    640, 640, pad_to_aspect_ratio=True, bounding_box_format="xywh"
)

"""
This can be used as our inference preprocessing pipeline:
"""

image_batch = inference_resizing([image])

"""
`keras_cv.visualization.plot_bounding_box_gallery()` supports a `class_mapping`
parameter to highlight what class each box was assigned to. Let's assemble a
class mapping now.
"""

class_ids = [
    "Aeroplane",
    "Bicycle",
    "Bird",
    "Boat",
    "Bottle",
    "Bus",
    "Car",
    "Cat",
    "Chair",
    "Cow",
    "Dining Table",
    "Dog",
    "Horse",
    "Motorbike",
    "Person",
    "Potted Plant",
    "Sheep",
    "Sofa",
    "Train",
    "Tvmonitor",
    "Total",
]
class_mapping = dict(zip(range(len(class_ids)), class_ids))

"""
Just like any other `keras.Model` you can predict bounding boxes using the
`model.predict()` API.
"""

y_pred = pretrained_model.predict(image_batch)
# y_pred is a bounding box Tensor:
# {"classes": ..., "boxes": ...}
visualization.plot_bounding_box_gallery(
    image_batch,
    value_range=(0, 255),
    rows=1,
    cols=1,
    y_pred=y_pred,
    scale=5,
    font_scale=0.7,
    bounding_box_format="xywh",
    class_mapping=class_mapping,
)

"""
In order to support this easy and intuitive inference workflow, KerasCV
performs non-max suppression inside of the `YOLOV8Detector` class.
Non-max suppression is a classical algorithm that solves the problem
of a model detecting multiple boxes for the same object.

Non-max suppression is a highly configurable algorithm, and in most cases you
will want to customize the settings of your model's non-max
suppression operation.
This can be done by overriding the `prediction_decoder` argument.

To show this concept off, let's temporarily disable non-max suppression on our
YOLOV8Detector. This can be done by writing to the `prediction_decoder` attribute.
"""

# The following NonMaxSuppression layer is equivalent to disabling the operation
prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xywh",
    from_logits=True,
    iou_threshold=1.0,
    confidence_threshold=0.0,
)
pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_pascalvoc",
    bounding_box_format="xywh",
    prediction_decoder=prediction_decoder,
)

y_pred = pretrained_model.predict(image_batch)
visualization.plot_bounding_box_gallery(
    image_batch,
    value_range=(0, 255),
    rows=1,
    cols=1,
    y_pred=y_pred,
    scale=5,
    font_scale=0.7,
    bounding_box_format="xywh",
    class_mapping=class_mapping,
)


"""
Next, let's re-configure `keras_cv.layers.NonMaxSuppression` for our
use case!
In this case, we will tune the `iou_threshold` to `0.2`, and the
`confidence_threshold` to `0.7`.

Raising the `confidence_threshold` will cause the model to only output boxes
that have a higher confidence score. `iou_threshold` controls the threshold of
intersection over union (IoU) that two boxes must have in order for one to be
pruned out.
[More information on these parameters may be found in the TensorFlow API docs](https://www.tensorflow.org/api_docs/python/tf/image/combined_non_max_suppression)
"""

prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xywh",
    from_logits=True,
    # Lower the IoU threshold so overlapping boxes are pruned more aggressively
    iou_threshold=0.2,
    # Raise the confidence threshold so only high-confidence boxes pass NMS
    confidence_threshold=0.7,
)
pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_pascalvoc",
    bounding_box_format="xywh",
    prediction_decoder=prediction_decoder,
)

y_pred = pretrained_model.predict(image_batch)
visualization.plot_bounding_box_gallery(
    image_batch,
    value_range=(0, 255),
    rows=1,
    cols=1,
    y_pred=y_pred,
    scale=5,
    font_scale=0.7,
    bounding_box_format="xywh",
    class_mapping=class_mapping,
)

"""
That looks a lot better!

## Train a custom object detection model

![](https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_advanced.png)

Whether you're an object detection amateur or a well-seasoned veteran, assembling
an object detection pipeline from scratch is a massive undertaking.
Luckily, all KerasCV object detection APIs are built as modular components.
Whether you need a complete pipeline, just an object detection model, or even
just a conversion utility to transform your boxes from `xywh` format to `xyxy`,
KerasCV has you covered.

In this guide, we'll assemble a full training pipeline for a KerasCV object
detection model. This includes data loading, augmentation, metric evaluation,
and inference!

To get started, let's sort out all of our imports and define global
configuration parameters.
"""

BATCH_SIZE = 4

"""
## Data loading

To get started, let's discuss data loading and bounding box formatting.
KerasCV has a predefined format for bounding boxes.
To comply with this, you
should package your bounding boxes into a dictionary matching the
specification below:

```
bounding_boxes = {
    # num_boxes may be a Ragged dimension
    'boxes': Tensor(shape=[batch, num_boxes, 4]),
    'classes': Tensor(shape=[batch, num_boxes])
}
```

`bounding_boxes['boxes']` contains the coordinates of your bounding box in a KerasCV
supported `bounding_box_format`.
KerasCV requires a `bounding_box_format` argument in all components that process
bounding boxes.
This is done to maximize your ability to plug and play individual components
into your object detection pipelines, as well as to make code self-documenting
across object detection pipelines.

To match the KerasCV API style, it is recommended that when writing a
custom data loader, you also support a `bounding_box_format` argument.
This makes it clear to those invoking your data loader what format the bounding boxes
are in.
In this example, we format our boxes to `xywh` format.

For example:

```python
train_ds, ds_info = your_data_loader.load(
    split='train', bounding_box_format='xywh', batch_size=8
)
```

This clearly yields bounding boxes in the format `xywh`. You can read more about
KerasCV bounding box formats [in the API docs](https://keras.io/api/keras_cv/bounding_box/formats/).

Our data comes loaded into the format
`{"images": images, "bounding_boxes": bounding_boxes}`. This format is
supported in all KerasCV preprocessing components.

Let's load some data and verify that the data looks as we expect it to.
"""


def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format):
    inputs = next(iter(inputs.take(1)))
    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
    visualization.plot_bounding_box_gallery(
        images,
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=bounding_boxes,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )


def unpackage_raw_tfds_inputs(inputs, bounding_box_format):
    image = inputs["image"]
    boxes = keras_cv.bounding_box.convert_format(
        inputs["objects"]["bbox"],
        images=image,
        source="rel_yxyx",
        target=bounding_box_format,
    )
    bounding_boxes = {
        "classes": inputs["objects"]["label"],
        "boxes": boxes,
    }
    return {"images": image, "bounding_boxes": bounding_boxes}


def load_pascal_voc(split, dataset, bounding_box_format):
    ds = tfds.load(dataset, split=split, with_info=False, shuffle_files=True)
    ds = ds.map(
        lambda x: unpackage_raw_tfds_inputs(x, bounding_box_format=bounding_box_format),
        num_parallel_calls=tf_data.AUTOTUNE,
    )
    return ds


train_ds = load_pascal_voc(
    split="train", dataset="voc/2007", bounding_box_format="xywh"
)
eval_ds = load_pascal_voc(split="test", dataset="voc/2007", bounding_box_format="xywh")

train_ds = train_ds.shuffle(BATCH_SIZE * 4)

"""
Next, let's batch our data.

In KerasCV object detection tasks it is recommended that
users use ragged batches of inputs.
This is because images may be of different sizes in PascalVOC, and there may be
different numbers of bounding boxes per image.

To construct a ragged dataset in a `tf.data` pipeline, you can use the
`ragged_batch()` method.
"""

train_ds = train_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)
eval_ds = eval_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)

"""
Let's make sure our dataset is following the format KerasCV expects.
By using the `visualize_dataset()` function, you can visually verify
that your data is in the format that KerasCV expects. If the bounding boxes
are not visible or are visible in the wrong locations, that is a sign that your
data is mis-formatted.
"""

visualize_dataset(
    train_ds, bounding_box_format="xywh", value_range=(0, 255), rows=2, cols=2
)

"""
And for the eval set:
"""

visualize_dataset(
    eval_ds,
    bounding_box_format="xywh",
    value_range=(0, 255),
    rows=2,
    cols=2,
    # If you are not running your experiment on a local machine, you can also
    # make `visualize_dataset()` dump the plot to a file using `path`:
    # path="eval.png"
)

"""
Looks like everything is structured as expected.
Now we can move on to constructing our
data augmentation pipeline.

## Data augmentation

One of the most challenging tasks when constructing object detection
pipelines is data augmentation. Image augmentation techniques must be aware of the underlying
bounding boxes, and must update them accordingly.

Luckily, KerasCV natively supports bounding box augmentation with its extensive
library
of [data augmentation layers](https://keras.io/api/keras_cv/layers/preprocessing/).
The code below applies on-the-fly, bounding-box-friendly data augmentation to
the Pascal VOC dataset inside a `tf.data` pipeline.
"""

augmenters = [
    keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh"),
    keras_cv.layers.JitteredResize(
        target_size=(640, 640), scale_factor=(0.75, 1.3), bounding_box_format="xywh"
    ),
]


def create_augmenter_fn(augmenters):
    def augmenter_fn(inputs):
        for augmenter in augmenters:
            inputs = augmenter(inputs)
        return inputs

    return augmenter_fn


augmenter_fn = create_augmenter_fn(augmenters)

train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf_data.AUTOTUNE)
visualize_dataset(
    train_ds, bounding_box_format="xywh", value_range=(0, 255), rows=2, cols=2
)

"""
Great! We now have a bounding-box-friendly data augmentation pipeline.
Let's format our evaluation dataset to match. Instead of using
`JitteredResize`, let's use the deterministic `keras_cv.layers.Resizing()`
layer.
"""

inference_resizing = keras_cv.layers.Resizing(
    640, 640, bounding_box_format="xywh", pad_to_aspect_ratio=True
)
eval_ds = eval_ds.map(inference_resizing, num_parallel_calls=tf_data.AUTOTUNE)
550
"""
551
Due to the fact that the resize operation differs between the train dataset,
552
which uses `JitteredResize()` to resize images, and the inference dataset, which
553
uses `layers.Resizing(pad_to_aspect_ratio=True)`, it is good practice to
554
visualize both datasets:
555
"""
556
557
visualize_dataset(
558
eval_ds, bounding_box_format="xywh", value_range=(0, 255), rows=2, cols=2
559
)

"""
Finally, let's unpackage our inputs from the preprocessing dictionary, and
prepare to feed the inputs into our model. In order to be TPU compatible,
bounding box Tensors need to be `Dense` instead of `Ragged`.
"""


def dict_to_tuple(inputs):
    return inputs["images"], bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )


train_ds = train_ds.map(dict_to_tuple, num_parallel_calls=tf_data.AUTOTUNE)
eval_ds = eval_ds.map(dict_to_tuple, num_parallel_calls=tf_data.AUTOTUNE)

train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
eval_ds = eval_ds.prefetch(tf_data.AUTOTUNE)

"""
### Optimizer

In this guide, we use a standard SGD optimizer and rely on the
[`keras.callbacks.ReduceLROnPlateau`](https://keras.io/api/callbacks/reduce_lr_on_plateau/)
callback to reduce the learning rate.

You will always want to include a `global_clipnorm` when training object
detection models. This is to remedy exploding gradient problems that frequently
occur when training object detection models.
"""

base_lr = 0.005
# including a global_clipnorm is extremely important in object detection tasks
optimizer = keras.optimizers.SGD(
    learning_rate=base_lr, momentum=0.9, global_clipnorm=10.0
)
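
"""
The `ReduceLROnPlateau` callback mentioned above is not constructed elsewhere in
this guide; as a sketch (the `monitor`, `factor`, and `patience` values here are
illustrative, not tuned), it could be created and passed in the `callbacks`
list of `fit()`:

```python
reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
    monitor="loss", factor=0.5, patience=2
)
```
"""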

"""
To achieve the best results on your dataset, you'll likely want to hand craft a
`PiecewiseConstantDecay` learning rate schedule.
While `PiecewiseConstantDecay` schedules tend to perform better, they don't
translate between problems.
"""

"""
### Loss functions

You may not be familiar with the `"ciou"` loss. While not common in other
models, this loss is sometimes used in the object detection world.

In short, ["Complete IoU"](https://arxiv.org/abs/1911.08287) is a flavour of the
Intersection over Union loss and is used due to its convergence properties.

In KerasCV, you can use this loss simply by passing the string `"ciou"` to `compile()`.
We also use standard binary crossentropy loss for the class head.
"""

pretrained_model.compile(
    classification_loss="binary_crossentropy",
    box_loss="ciou",
)

"""
### Metric evaluation

The most popular object detection metrics are COCO metrics,
which were published alongside the MSCOCO dataset. KerasCV provides an
easy-to-use suite of COCO metrics under the `keras_cv.callbacks.PyCOCOCallback`
symbol. Note that we use a Keras callback instead of a Keras metric to compute
COCO metrics. This is because computing COCO metrics requires storing all of a
model's predictions for the entire evaluation dataset in memory at once, which
is impractical to do during training time.
"""

coco_metrics_callback = keras_cv.callbacks.PyCOCOCallback(
    eval_ds.take(20), bounding_box_format="xywh"
)


"""
Our data pipeline is now complete!
We can now move on to model creation and training.

## Model creation

Next, let's use the KerasCV API to construct an untrained YOLOV8Detector model.
In this tutorial we use a ResNet50 backbone pretrained on the ImageNet
dataset.

KerasCV makes it easy to construct a `YOLOV8Detector` with any of the KerasCV
backbones. Simply use one of the presets for the architecture you'd like!

For example:
"""

model = keras_cv.models.YOLOV8Detector.from_preset(
    "resnet50_imagenet",
    # For more info on supported bounding box formats, visit
    # https://keras.io/api/keras_cv/bounding_box/
    bounding_box_format="xywh",
    num_classes=20,
)

"""
That is all it takes to construct a KerasCV YOLOv8. The YOLOv8 accepts
tuples of dense image Tensors and bounding box dictionaries to `fit()` and
`train_on_batch()`.

This matches what we have constructed in our input pipeline above.
"""


"""
## Training our model

All that is left to do is train our model. KerasCV object detection models
follow the standard Keras workflow, leveraging `compile()` and `fit()`.

Let's compile our model:
"""
model.compile(
    classification_loss="binary_crossentropy",
    box_loss="ciou",
    optimizer=optimizer,
)

"""
If you want to fully train the model, remove `.take(20)` from all dataset
references (below and in the initialization of the metrics callback).
"""

model.fit(
    train_ds.take(20),
    # Run for ~10-35 epochs to achieve good scores.
    epochs=1,
    callbacks=[coco_metrics_callback],
)

"""
## Inference and plotting results

KerasCV makes object detection inference simple. `model.predict(images)`
returns a tensor of bounding boxes. By default, `YOLOV8Detector.predict()`
will perform a non-max suppression operation for you.

In this section, we will use a `keras_cv`-provided preset:
"""

model = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_pascalvoc", bounding_box_format="xywh"
)

"""
Next, for convenience we construct a dataset with larger batches:
"""

visualization_ds = eval_ds.unbatch()
visualization_ds = visualization_ds.ragged_batch(16)
visualization_ds = visualization_ds.shuffle(8)

"""
Let's create a simple function to plot our inferences:
"""


def visualize_detections(model, dataset, bounding_box_format):
    images, y_true = next(iter(dataset.take(1)))
    y_pred = model.predict(images)
    visualization.plot_bounding_box_gallery(
        images,
        value_range=(0, 255),
        bounding_box_format=bounding_box_format,
        y_true=y_true,
        y_pred=y_pred,
        scale=4,
        rows=2,
        cols=2,
        show=True,
        font_scale=0.7,
        class_mapping=class_mapping,
    )


"""
You may need to configure your NonMaxSuppression operation to achieve
visually appealing results.
"""

model.prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xywh",
    from_logits=True,
    iou_threshold=0.5,
    confidence_threshold=0.75,
)

visualize_detections(model, dataset=visualization_ds, bounding_box_format="xywh")

"""
Awesome!
One final helpful pattern to be aware of is to visualize
detections in a `keras.callbacks.Callback` to monitor training:
"""


class VisualizeDetections(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        visualize_detections(
            self.model, bounding_box_format="xywh", dataset=visualization_ds
        )
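
"""
To use it, pass an instance in the `callbacks` list of `fit()`, for example
`callbacks=[coco_metrics_callback, VisualizeDetections()]`.
"""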
765
766
767
"""
768
## Takeaways and next steps
769
770
KerasCV makes it easy to construct state-of-the-art object detection pipelines.
771
In this guide, we started off by writing a data loader using the KerasCV
772
bounding box specification.
773
Following this, we assembled a production grade data augmentation pipeline using
774
KerasCV preprocessing layers in <50 lines of code.
775
776
KerasCV object detection components can be used independently, but also have deep
777
integration with each other.
778
KerasCV makes authoring production grade bounding box augmentation,
779
model training, visualization, and
780
metric evaluation easy.
781
782
Some follow up exercises for the reader:
783
784
- add additional augmentation techniques to improve model performance
785
- tune the hyperparameters and data augmentation used to produce high quality results
786
- train an object detection model on your own dataset
787
788
One last fun code snippet to showcase the power of KerasCV's API!
789
"""

stable_diffusion = keras_cv.models.StableDiffusionV2(512, 512)
images = stable_diffusion.text_to_image(
    prompt="A zoomed out photograph of a cool looking cat. The cat stands in a beautiful forest",
    negative_prompt="unrealistic, bad looking, malformed",
    batch_size=4,
    seed=1231,
)
encoded_predictions = model(images)
y_pred = model.decode_predictions(encoded_predictions, images)
visualization.plot_bounding_box_gallery(
    images,
    value_range=(0, 255),
    y_pred=y_pred,
    rows=2,
    cols=2,
    scale=5,
    font_scale=0.7,
    bounding_box_format="xywh",
    class_mapping=class_mapping,
)