"""1Title: Object Detection with KerasCV2Author: [lukewood](https://twitter.com/luke_wood_ml), Ian Stenbit, Tirth Patel3Date created: 2023/04/084Last modified: 2023/08/105Description: Train an object detection model with KerasCV.6Accelerator: GPU7"""89"""10KerasCV offers a complete set of production grade APIs to solve object detection11problems.12These APIs include object-detection-specific13data augmentation techniques, Keras native COCO metrics, bounding box format14conversion utilities, visualization tools, pretrained object detection models,15and everything you need to train your own state of the art object detection16models!1718Let's give KerasCV's object detection API a spin.19"""2021"""shell22pip install -q --upgrade keras-cv23pip install -q --upgrade keras # Upgrade to Keras 3.24"""2526import os2728os.environ["KERAS_BACKEND"] = "jax" # @param ["tensorflow", "jax", "torch"]2930from tensorflow import data as tf_data31import tensorflow_datasets as tfds32import keras33import keras_cv34import numpy as np35from keras_cv import bounding_box36import os37from keras_cv import visualization38import tqdm3940"""41## Object detection introduction4243Object detection is the process of identifying, classifying,44and localizing objects within a given image. Typically, your inputs are45images, and your labels are bounding boxes with optional class46labels.47Object detection can be thought of as an extension of classification, however48instead of one class label for the image, you must detect and localize an49arbitrary number of classes.5051**For example:**5253<img width="300" src="https://i.imgur.com/8xSEbQD.png">5455The data for the above image may look something like this:56```python57image = [height, width, 3]58bounding_boxes = {59"classes": [0], # 0 is an arbitrary class ID representing "cat"60"boxes": [[0.25, 0.4, .15, .1]]61# bounding box is in "rel_xywh" format62# so 0.25 represents the start of the bounding box 25% of63# the way across the image.64# The .15 represents that the width is 15% of the image width.65}66```6768Since the inception of [*You Only Look Once*](https://arxiv.org/abs/1506.02640)69(aka YOLO),70object detection has primarily been solved using deep learning.71Most deep learning architectures do this by cleverly framing the object detection72problem as a combination of many small classification problems and73many regression problems.7475More specifically, this is done by generating many anchor boxes of varying76shapes and sizes across the input images and assigning them each a class label,77as well as `x`, `y`, `width` and `height` offsets.78The model is trained to predict the class labels of each box, as well as the79`x`, `y`, `width`, and `height` offsets of each box that is predicted to be an80object.8182**Visualization of some sample anchor boxes**:8384<img width="400" src="https://i.imgur.com/cJIuiK9.jpg">8586Objection detection is a technically complex problem but luckily we offer a87bulletproof approach to getting great results.88Let's do this!89"""9091"""92## Perform detections with a pretrained model93949596The highest level API in the KerasCV Object Detection API is the `keras_cv.models` API.97This API includes fully pretrained object detection models, such as98`keras_cv.models.YOLOV8Detector`.99100Let's get started by constructing a YOLOV8Detector pretrained on the `pascalvoc`101dataset.102"""103104pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(105"yolo_v8_m_pascalvoc", bounding_box_format="xywh"106)107108"""109Notice the `bounding_box_format` 

Next, let's load an image:
"""

filepath = keras.utils.get_file(origin="https://i.imgur.com/gCNcJJI.jpg")
image = keras.utils.load_img(filepath)
image = np.array(image)

visualization.plot_image_gallery(
    np.array([image]),
    value_range=(0, 255),
    rows=1,
    cols=1,
    scale=5,
)

"""
To use the `YOLOV8Detector` architecture with a ResNet50 backbone, you'll need to
resize your image to a size that is divisible by 64. This is to ensure
compatibility with the number of downscaling operations done by the convolution
layers in the ResNet.

If the resize operation distorts
the input's aspect ratio, the model will perform significantly worse. For the
pretrained `"yolo_v8_m_pascalvoc"` preset we are using, the final
`MeanAveragePrecision` on the `pascalvoc/2012` evaluation set drops to `0.15`
from `0.38` when using a naive resizing operation.

Additionally, if you crop to preserve the aspect ratio as you do in
classification, your model may entirely miss some bounding boxes. As such, when
running inference on an object detection model we recommend resizing the longest
side to preserve the aspect ratio, then padding to the desired size.

KerasCV makes resizing properly easy; simply pass `pad_to_aspect_ratio=True` to
a `keras_cv.layers.Resizing` layer.

This can be implemented in one line of code:
"""

inference_resizing = keras_cv.layers.Resizing(
    640, 640, pad_to_aspect_ratio=True, bounding_box_format="xywh"
)

"""
This can be used as our inference preprocessing pipeline:
"""

image_batch = inference_resizing([image])

"""
`keras_cv.visualization.plot_bounding_box_gallery()` supports a `class_mapping`
parameter to highlight what class each box was assigned to. Let's assemble a
class mapping now.
"""

class_ids = [
    "Aeroplane",
    "Bicycle",
    "Bird",
    "Boat",
    "Bottle",
    "Bus",
    "Car",
    "Cat",
    "Chair",
    "Cow",
    "Dining Table",
    "Dog",
    "Horse",
    "Motorbike",
    "Person",
    "Potted Plant",
    "Sheep",
    "Sofa",
    "Train",
    "Tvmonitor",
    "Total",
]
class_mapping = dict(zip(range(len(class_ids)), class_ids))
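
"""
As a quick sanity check, the mapping pairs each integer class ID with its
human-readable name (the printouts below are shown as comments):

```python
print(class_mapping[7])   # Cat
print(class_mapping[14])  # Person
```
"""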

"""
Just like any other `keras.Model`, you can predict bounding boxes using the
`model.predict()` API.
"""

y_pred = pretrained_model.predict(image_batch)
# y_pred is a bounding box Tensor:
# {"classes": ..., "boxes": ...}
visualization.plot_bounding_box_gallery(
    image_batch,
    value_range=(0, 255),
    rows=1,
    cols=1,
    y_pred=y_pred,
    scale=5,
    font_scale=0.7,
    bounding_box_format="xywh",
    class_mapping=class_mapping,
)

"""
In order to support this easy and intuitive inference workflow, KerasCV
performs non-max suppression inside of the `YOLOV8Detector` class.
Non-max suppression is a classical algorithm that solves the problem
of a model detecting multiple boxes for the same object.

Non-max suppression is a highly configurable algorithm, and in most cases you
will want to customize the settings of your model's non-max
suppression operation.
This can be done by overriding the `prediction_decoder` argument.

To show this concept off, let's temporarily disable non-max suppression on our
YOLOV8Detector. This can be done by writing to the `prediction_decoder` attribute.
"""

# The following NonMaxSuppression layer is equivalent to disabling the operation
prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xywh",
    from_logits=True,
    iou_threshold=1.0,
    confidence_threshold=0.0,
)
pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_pascalvoc",
    bounding_box_format="xywh",
    prediction_decoder=prediction_decoder,
)

y_pred = pretrained_model.predict(image_batch)
visualization.plot_bounding_box_gallery(
    image_batch,
    value_range=(0, 255),
    rows=1,
    cols=1,
    y_pred=y_pred,
    scale=5,
    font_scale=0.7,
    bounding_box_format="xywh",
    class_mapping=class_mapping,
)


"""
Next, let's re-configure `keras_cv.layers.NonMaxSuppression` for our
use case!
In this case, we will tune the `iou_threshold` to `0.2`, and the
`confidence_threshold` to `0.7`.

Raising the `confidence_threshold` will cause the model to only output boxes
whose confidence score exceeds the threshold. `iou_threshold` controls the
intersection over union (IoU) that two boxes must exceed for the
lower-confidence one to be pruned out.
[More information on these parameters may be found in the TensorFlow API docs](https://www.tensorflow.org/api_docs/python/tf/image/combined_non_max_suppression).
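
To build intuition for `iou_threshold`, here is a small worked example in plain
NumPy (our own sketch, not a KerasCV API):

```python
import numpy as np

# Two boxes in "xywh" format; the second is shifted 50 pixels to the right.
box_a = np.array([100.0, 100.0, 100.0, 100.0])
box_b = np.array([150.0, 100.0, 100.0, 100.0])

intersection = 50.0 * 100.0                 # overlap is 50px wide, 100px tall
union = 2 * (100.0 * 100.0) - intersection  # 15,000 square pixels
iou = intersection / union                  # ~0.33
```

Since `0.33` exceeds our `iou_threshold` of `0.2`, the lower-confidence of the
two boxes would be pruned.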
"""

prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xywh",
    from_logits=True,
    # Decrease the required threshold to make predictions get pruned out
    iou_threshold=0.2,
    # Tune confidence threshold for predictions to pass NMS
    confidence_threshold=0.7,
)
pretrained_model = keras_cv.models.YOLOV8Detector.from_preset(
    "yolo_v8_m_pascalvoc",
    bounding_box_format="xywh",
    prediction_decoder=prediction_decoder,
)

y_pred = pretrained_model.predict(image_batch)
visualization.plot_bounding_box_gallery(
    image_batch,
    value_range=(0, 255),
    rows=1,
    cols=1,
    y_pred=y_pred,
    scale=5,
    font_scale=0.7,
    bounding_box_format="xywh",
    class_mapping=class_mapping,
)

"""
That looks a lot better!

## Train a custom object detection model

Whether you're an object detection amateur or a well-seasoned veteran, assembling
an object detection pipeline from scratch is a massive undertaking.
Luckily, all KerasCV object detection APIs are built as modular components.
Whether you need a complete pipeline, just an object detection model, or even
just a conversion utility to transform your boxes from `xywh` format to `xyxy`,
KerasCV has you covered.

In this guide, we'll assemble a full training pipeline for a KerasCV object
detection model. This includes data loading, augmentation, metric evaluation,
and inference!

To get started, let's sort out all of our imports and define global
configuration parameters.
"""

BATCH_SIZE = 4

"""
## Data loading

To get started, let's discuss data loading and bounding box formatting.
KerasCV has a predefined format for bounding boxes.
To comply with this, you
should package your bounding boxes into a dictionary matching the
specification below:

```
bounding_boxes = {
    # num_boxes may be a Ragged dimension
    'boxes': Tensor(shape=[batch, num_boxes, 4]),
    'classes': Tensor(shape=[batch, num_boxes])
}
```

`bounding_boxes['boxes']` contains the coordinates of your bounding box in a
KerasCV supported `bounding_box_format`.
KerasCV requires a `bounding_box_format` argument in all components that process
bounding boxes.
This is done to maximize your ability to plug and play individual components
into your object detection pipelines, as well as to make code self-documenting
across object detection pipelines.

To match the KerasCV API style, it is recommended that when writing a
custom data loader, you also support a `bounding_box_format` argument.
This makes it clear to those invoking your data loader what format the bounding
boxes are in.
In this example, we format our boxes to `xywh` format.

For example:

```python
train_ds, ds_info = your_data_loader.load(
    split='train', bounding_box_format='xywh', batch_size=8
)
```

This clearly yields bounding boxes in `xywh` format. You can read more about
KerasCV bounding box formats [in the API docs](https://keras.io/api/keras_cv/bounding_box/formats/).

Our data comes loaded into the format
`{"images": images, "bounding_boxes": bounding_boxes}`. This format is
supported in all KerasCV preprocessing components.
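
For example, a single (unbatched) sample in this format might look like the
following sketch (made-up values; class `7` is "Cat" in the mapping we defined
earlier):

```python
import numpy as np

sample = {
    "images": np.zeros((640, 640, 3), dtype="float32"),
    "bounding_boxes": {
        "classes": np.array([7.0]),                        # one box, class "Cat"
        "boxes": np.array([[200.0, 300.0, 100.0, 50.0]]),  # "xywh" coordinates
    },
}
```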

Let's load some data and verify that the data looks as we expect it to.
"""


def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format):
    inputs = next(iter(inputs.take(1)))
    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
    visualization.plot_bounding_box_gallery(
        images,
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=bounding_boxes,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )


def unpackage_raw_tfds_inputs(inputs, bounding_box_format):
    image = inputs["image"]
    boxes = keras_cv.bounding_box.convert_format(
        inputs["objects"]["bbox"],
        images=image,
        source="rel_yxyx",
        target=bounding_box_format,
    )
    bounding_boxes = {
        "classes": inputs["objects"]["label"],
        "boxes": boxes,
    }
    return {"images": image, "bounding_boxes": bounding_boxes}


def load_pascal_voc(split, dataset, bounding_box_format):
    ds = tfds.load(dataset, split=split, with_info=False, shuffle_files=True)
    ds = ds.map(
        lambda x: unpackage_raw_tfds_inputs(x, bounding_box_format=bounding_box_format),
        num_parallel_calls=tf_data.AUTOTUNE,
    )
    return ds


train_ds = load_pascal_voc(
    split="train", dataset="voc/2007", bounding_box_format="xywh"
)
eval_ds = load_pascal_voc(split="test", dataset="voc/2007", bounding_box_format="xywh")

train_ds = train_ds.shuffle(BATCH_SIZE * 4)

"""
Next, let's batch our data.

In KerasCV object detection tasks, it is recommended that
users use ragged batches of inputs.
This is because images in Pascal VOC may be of different sizes, and because
there may be different numbers of bounding boxes per image.

To construct a ragged dataset in a `tf.data` pipeline, you can use the
`ragged_batch()` method.
"""

train_ds = train_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)
eval_ds = eval_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)

"""
Let's make sure our dataset is following the format KerasCV expects.
By using the `visualize_dataset()` function, you can visually verify
that your data is in the format that KerasCV expects. If the bounding boxes
are not visible or are visible in the wrong locations, that is a sign that your
data is mis-formatted.
"""

visualize_dataset(
    train_ds, bounding_box_format="xywh", value_range=(0, 255), rows=2, cols=2
)

"""
And for the eval set:
"""

visualize_dataset(
    eval_ds,
    bounding_box_format="xywh",
    value_range=(0, 255),
    rows=2,
    cols=2,
    # If you are not running your experiment on a local machine, you can also
    # make `visualize_dataset()` dump the plot to a file using `path`:
    # path="eval.png"
)

"""
Looks like everything is structured as expected.
Now we can move on to constructing our
data augmentation pipeline.

## Data augmentation

One of the most challenging tasks when constructing object detection
pipelines is data augmentation. Image augmentation techniques must be aware of
the underlying bounding boxes, and must update them accordingly.

Luckily, KerasCV natively supports bounding box augmentation with its extensive
library
of [data augmentation layers](https://keras.io/api/keras_cv/layers/preprocessing/).
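
As a rough sketch of what "bounding-box-aware" means (reusing the `sample`
dictionary sketched in the data loading section above):

```python
flip = keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh")
flipped = flip(sample)
# Whenever the image is mirrored, each box's x coordinate is updated to
# image_width - x - width, so the boxes stay glued to their objects.
```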
Image augmentation techniques must be aware of the underlying504bounding boxes, and must update them accordingly.505506Luckily, KerasCV natively supports bounding box augmentation with its extensive507library508of [data augmentation layers](https://keras.io/api/keras_cv/layers/preprocessing/).509The code below loads the Pascal VOC dataset, and performs on-the-fly,510bounding-box-friendly data augmentation inside a `tf.data` pipeline.511"""512513augmenters = [514keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh"),515keras_cv.layers.JitteredResize(516target_size=(640, 640), scale_factor=(0.75, 1.3), bounding_box_format="xywh"517),518]519520521def create_augmenter_fn(augmenters):522def augmenter_fn(inputs):523for augmenter in augmenters:524inputs = augmenter(inputs)525return inputs526527return augmenter_fn528529530augmenter_fn = create_augmenter_fn(augmenters)531532train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf_data.AUTOTUNE)533visualize_dataset(534train_ds, bounding_box_format="xywh", value_range=(0, 255), rows=2, cols=2535)536537"""538Great! We now have a bounding-box-friendly data augmentation pipeline.539Let's format our evaluation dataset to match. Instead of using540`JitteredResize`, let's use the deterministic `keras_cv.layers.Resizing()`541layer.542"""543544inference_resizing = keras_cv.layers.Resizing(545640, 640, bounding_box_format="xywh", pad_to_aspect_ratio=True546)547eval_ds = eval_ds.map(inference_resizing, num_parallel_calls=tf_data.AUTOTUNE)548549"""550Due to the fact that the resize operation differs between the train dataset,551which uses `JitteredResize()` to resize images, and the inference dataset, which552uses `layers.Resizing(pad_to_aspect_ratio=True)`, it is good practice to553visualize both datasets:554"""555556visualize_dataset(557eval_ds, bounding_box_format="xywh", value_range=(0, 255), rows=2, cols=2558)559560"""561Finally, let's unpackage our inputs from the preprocessing dictionary, and562prepare to feed the inputs into our model. In order to be TPU compatible,563bounding box Tensors need to be `Dense` instead of `Ragged`.564"""565566567def dict_to_tuple(inputs):568return inputs["images"], bounding_box.to_dense(569inputs["bounding_boxes"], max_boxes=32570)571572573train_ds = train_ds.map(dict_to_tuple, num_parallel_calls=tf_data.AUTOTUNE)574eval_ds = eval_ds.map(dict_to_tuple, num_parallel_calls=tf_data.AUTOTUNE)575576train_ds = train_ds.prefetch(tf_data.AUTOTUNE)577eval_ds = eval_ds.prefetch(tf_data.AUTOTUNE)578579"""580581### Optimizer582583In this guide, we use a standard SGD optimizer and rely on the584[`keras.callbacks.ReduceLROnPlateau`](https://keras.io/api/callbacks/reduce_lr_on_plateau/)585callback to reduce the learning rate.586587You will always want to include a `global_clipnorm` when training object588detection models. This is to remedy exploding gradient problems that frequently589occur when training object detection models.590"""591592base_lr = 0.005593# including a global_clipnorm is extremely important in object detection tasks594optimizer = keras.optimizers.SGD(595learning_rate=base_lr, momentum=0.9, global_clipnorm=10.0596)597598"""599To achieve the best results on your dataset, you'll likely want to hand craft a600`PiecewiseConstantDecay` learning rate schedule.601While `PiecewiseConstantDecay` schedules tend to perform better, they don't602translate between problems.603"""604605"""606### Loss functions607608You may not be familiar with the `"ciou"` loss. 
While not common in other609models, this loss is sometimes used in the object detection world.610611In short, ["Complete IoU"](https://arxiv.org/abs/1911.08287) is a flavour of the Intersection over Union loss and is used due to its convergence properties.612613In KerasCV, you can use this loss simply by passing the string `"ciou"` to `compile()`.614We also use standard binary crossentropy loss for the class head.615"""616617pretrained_model.compile(618classification_loss="binary_crossentropy",619box_loss="ciou",620)621622"""623### Metric evaluation624625The most popular object detection metrics are COCO metrics,626which were published alongside the MSCOCO dataset. KerasCV provides an627easy-to-use suite of COCO metrics under the `keras_cv.callbacks.PyCOCOCallback`628symbol. Note that we use a Keras callback instead of a Keras metric to compute629COCO metrics. This is because computing COCO metrics requires storing all of a630model's predictions for the entire evaluation dataset in memory at once, which631is impractical to do during training time.632"""633634coco_metrics_callback = keras_cv.callbacks.PyCOCOCallback(635eval_ds.take(20), bounding_box_format="xywh"636)637638639"""640Our data pipeline is now complete!641We can now move on to model creation and training.642643## Model creation644645Next, let's use the KerasCV API to construct an untrained YOLOV8Detector model.646In this tutorial we use a pretrained ResNet50 backbone from the imagenet647dataset.648649KerasCV makes it easy to construct a `YOLOV8Detector` with any of the KerasCV650backbones. Simply use one of the presets for the architecture you'd like!651652For example:653"""654655model = keras_cv.models.YOLOV8Detector.from_preset(656"resnet50_imagenet",657# For more info on supported bounding box formats, visit658# https://keras.io/api/keras_cv/bounding_box/659bounding_box_format="xywh",660num_classes=20,661)662663"""664That is all it takes to construct a KerasCV YOLOv8. The YOLOv8 accepts665tuples of dense image Tensors and bounding box dictionaries to `fit()` and666`train_on_batch()`667668This matches what we have constructed in our input pipeline above.669"""670671672"""673## Training our model674675All that is left to do is train our model. KerasCV object detection models676follow the standard Keras workflow, leveraging `compile()` and `fit()`.677678Let's compile our model:679"""680model.compile(681classification_loss="binary_crossentropy",682box_loss="ciou",683optimizer=optimizer,684)685"""686If you want to fully train the model, remove `.take(20)` from all dataset687references (below and in the initialization of the metrics callback).688"""689model.fit(690train_ds.take(20),691# Run for 10-35~ epochs to achieve good scores.692epochs=1,693callbacks=[coco_metrics_callback],694)695"""696697## Inference and plotting results698699KerasCV makes object detection inference simple. `model.predict(images)`700returns a tensor of bounding boxes. 
By default, `YOLOV8Detector.predict()`701will perform a non max suppression operation for you.702703In this section, we will use a `keras_cv` provided preset:704"""705model = keras_cv.models.YOLOV8Detector.from_preset(706"yolo_v8_m_pascalvoc", bounding_box_format="xywh"707)708709"""710Next, for convenience we construct a dataset with larger batches:711"""712visualization_ds = eval_ds.unbatch()713visualization_ds = visualization_ds.ragged_batch(16)714visualization_ds = visualization_ds.shuffle(8)715"""716Let's create a simple function to plot our inferences:717"""718719720def visualize_detections(model, dataset, bounding_box_format):721images, y_true = next(iter(dataset.take(1)))722y_pred = model.predict(images)723visualization.plot_bounding_box_gallery(724images,725value_range=(0, 255),726bounding_box_format=bounding_box_format,727y_true=y_true,728y_pred=y_pred,729scale=4,730rows=2,731cols=2,732show=True,733font_scale=0.7,734class_mapping=class_mapping,735)736737738"""739You may need to configure your NonMaxSuppression operation to achieve740visually appealing results.741"""742743model.prediction_decoder = keras_cv.layers.NonMaxSuppression(744bounding_box_format="xywh",745from_logits=True,746iou_threshold=0.5,747confidence_threshold=0.75,748)749750visualize_detections(model, dataset=visualization_ds, bounding_box_format="xywh")751752"""753Awesome!754One final helpful pattern to be aware of is to visualize755detections in a `keras.callbacks.Callback` to monitor training :756"""757758759class VisualizeDetections(keras.callbacks.Callback):760def on_epoch_end(self, epoch, logs):761visualize_detections(762self.model, bounding_box_format="xywh", dataset=visualization_ds763)764765766"""767## Takeaways and next steps768769KerasCV makes it easy to construct state-of-the-art object detection pipelines.770In this guide, we started off by writing a data loader using the KerasCV771bounding box specification.772Following this, we assembled a production grade data augmentation pipeline using773KerasCV preprocessing layers in <50 lines of code.774775KerasCV object detection components can be used independently, but also have deep776integration with each other.777KerasCV makes authoring production grade bounding box augmentation,778model training, visualization, and779metric evaluation easy.780781Some follow up exercises for the reader:782783- add additional augmentation techniques to improve model performance784- tune the hyperparameters and data augmentation used to produce high quality results785- train an object detection model on your own dataset786787One last fun code snippet to showcase the power of KerasCV's API!788"""789790stable_diffusion = keras_cv.models.StableDiffusionV2(512, 512)791images = stable_diffusion.text_to_image(792prompt="A zoomed out photograph of a cool looking cat. The cat stands in a beautiful forest",793negative_prompt="unrealistic, bad looking, malformed",794batch_size=4,795seed=1231,796)797encoded_predictions = model(images)798y_pred = model.decode_predictions(encoded_predictions, images)799visualization.plot_bounding_box_gallery(800images,801value_range=(0, 255),802y_pred=y_pred,803rows=2,804cols=2,805scale=5,806font_scale=0.7,807bounding_box_format="xywh",808class_mapping=class_mapping,809)810811812