Path: blob/master/guides/keras_hub/segment_anything_in_keras_hub.py
3293 views
"""1Title: Segment Anything in KerasHub!2Author: Tirth Patel, Ian Stenbit, Divyashree Sreepathihalli<br>3Date created: 2024/10/1<br>4Last modified: 2024/10/1<br>5Description: Segment anything using text, box, and points prompts in KerasHub.6Accelerator: GPU7"""89"""10## Overview1112The Segment Anything Model (SAM) produces high quality object masks from input prompts13such as points or boxes, and it can be used to generate masks for all objects in an14image. It has been trained on a15[dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.116billion masks, and has strong zero-shot performance on a variety of segmentation tasks.1718In this guide, we will show how to use KerasHub's implementation of the19[Segment Anything Model](https://github.com/facebookresearch/segment-anything)20and show how powerful TensorFlow's and JAX's performance boost is.2122First, let's get all our dependencies and images for our demo.23"""2425"""shell26!pip install -Uq git+https://github.com/keras-team/keras-hub.git27!pip install -Uq keras28"""2930"""shell31!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg32"""3334"""35## Choose your backend3637With Keras 3, you can choose to use your favorite backend!38"""3940import os4142os.environ["KERAS_BACKEND"] = "jax"4344import timeit45import numpy as np46import matplotlib.pyplot as plt47import keras48from keras import ops49import keras_hub5051"""52## Helper functions5354Let's define some helper functions for visulazing the images, prompts, and the55segmentation results.56"""575859def show_mask(mask, ax, random_color=False):60if random_color:61color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)62else:63color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])64h, w = mask.shape[-2:]65mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)66ax.imshow(mask_image)676869def show_points(coords, labels, ax, marker_size=375):70pos_points = coords[labels == 1]71neg_points = coords[labels == 0]72ax.scatter(73pos_points[:, 0],74pos_points[:, 1],75color="green",76marker="*",77s=marker_size,78edgecolor="white",79linewidth=1.25,80)81ax.scatter(82neg_points[:, 0],83neg_points[:, 1],84color="red",85marker="*",86s=marker_size,87edgecolor="white",88linewidth=1.25,89)909192def show_box(box, ax):93box = box.reshape(-1)94x0, y0 = box[0], box[1]95w, h = box[2] - box[0], box[3] - box[1]96ax.add_patch(97plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)98)99100101def inference_resizing(image, pad=True):102# Compute Preprocess Shape103image = ops.cast(image, dtype="float32")104old_h, old_w = image.shape[0], image.shape[1]105scale = 1024 * 1.0 / max(old_h, old_w)106new_h = old_h * scale107new_w = old_w * scale108preprocess_shape = int(new_h + 0.5), int(new_w + 0.5)109110# Resize the image111image = ops.image.resize(image[None, ...], preprocess_shape)[0]112113# Pad the shorter side114if pad:115pixel_mean = ops.array([123.675, 116.28, 103.53])116pixel_std = ops.array([58.395, 57.12, 57.375])117image = (image - pixel_mean) / pixel_std118h, w = image.shape[0], image.shape[1]119pad_h = 1024 - h120pad_w = 1024 - w121image = ops.pad(image, [(0, pad_h), (0, pad_w), (0, 0)])122# KerasHub now rescales the images and normalizes them.123# Just unnormalize such that when KerasHub normalizes them124# again, the padded values map to 0.125image = image * pixel_std + pixel_mean126return image127128129"""130## Get the pretrained SAM model131132We can initialize a trained SAM model using KerasHub's `from_preset` factory method. Here,133we use the huge ViT backbone trained on the SA-1B dataset (`sam_huge_sa1b`) for134high-quality segmentation masks. You can also use one of the `sam_large_sa1b` or135`sam_base_sa1b` for better performance (at the cost of decreasing quality of segmentation136masks).137"""138139model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")140141"""142## Understanding Prompts143144Segment Anything allows prompting an image using points, boxes, and masks:1451461. Point prompts are the most basic of all: the model tries to guess the object given a147point on an image. The point can either be a foreground point (i.e. the desired148segmentation mask contains the point in it) or a backround point (i.e. the point lies149outside the desired mask).1502. Another way to prompt the model is using boxes. Given a bounding box, the model tries151to segment the object contained in it.1523. Finally, the model can also be prompted using a mask itself. This is useful, for153instance, to refine the borders of a previously predicted or known segmentation mask.154155What makes the model incredibly powerful is the ability to combine the prompts above.156Point, box, and mask prompts can be combined in several different ways to achieve the157best result.158159Let's see the semantics of passing these prompts to the Segment Anything model in160KerasHub. Input to the SAM model is a dictionary with keys:1611621. `"images"`: A batch of images to segment. Must be of shape `(B, 1024, 1024, 3)`.1632. `"points"`: A batch of point prompts. Each point is an `(x, y)` coordinate originating164from the top-left corner of the image. In other works, each point is of the form `(r, c)`165where `r` and `c` are the row and column of the pixel in the image. Must be of shape `(B,166N, 2)`.1673. `"labels"`: A batch of labels for the given points. `1` represents foreground points168and `0` represents background points. Must be of shape `(B, N)`.1694. `"boxes"`: A batch of boxes. Note that the model only accepts one box per batch.170Hence, the expected shape is `(B, 1, 2, 2)`. Each box is a collection of 2 points: the171top left corner and the bottom right corner of the box. The points here follow the same172semantics as the point prompts. Here the `1` in the second dimension represents the173presence of box prompts. If the box prompts are missing, a placeholder input of shape174`(B, 0, 2, 2)` must be passed.1755. `"masks"`: A batch of masks. Just like box prompts, only one mask prompt per image is176allowed. The shape of the input mask must be `(B, 1, 256, 256, 1)` if they are present177and `(B, 0, 256, 256, 1)` for missing mask prompt.178179Placeholder prompts are only required when calling the model directly (i.e.180`model(...)`). When calling the `predict` method, missing prompts can be omitted from the181input dictionary.182183## Point prompts184185First, let's segment an image using point prompts. We load the image and resize it to186shape `(1024, 1024)`, the image size the pretrained SAM model expects.187"""188189# Load our image190image = np.array(keras.utils.load_img("truck.jpg"))191image = inference_resizing(image)192193plt.figure(figsize=(10, 10))194plt.imshow(ops.convert_to_numpy(image) / 255.0)195plt.axis("on")196plt.show()197198"""199Next, we will define the point on the object we want to segment. Let's try to segment the200truck's window pane at coordinates `(284, 213)`.201"""202203# Define the input point prompt204input_point = np.array([[284, 213.5]])205input_label = np.array([1])206207plt.figure(figsize=(10, 10))208plt.imshow(ops.convert_to_numpy(image) / 255.0)209show_points(input_point, input_label, plt.gca())210plt.axis("on")211plt.show()212213"""214Now let's call the `predict` method of our model to get the segmentation masks.215216**Note**: We don't call the model directly (`model(...)`) since placeholder prompts are217required to do so. Missing prompts are handled automatically by the predict method so we218call it instead. Also, when no box prompts are present, the points and labels need to be219padded with a zero point prompt and `-1` label prompt respectively. The cell below220demonstrates how this works.221"""222223outputs = model.predict(224{225"images": image[np.newaxis, ...],226"points": np.concatenate(227[input_point[np.newaxis, ...], np.zeros((1, 1, 2))], axis=1228),229"labels": np.concatenate(230[input_label[np.newaxis, ...], np.full((1, 1), fill_value=-1)], axis=1231),232}233)234235"""236`SegmentAnythingModel.predict` returns two outputs. First are logits (segmentation masks)237of shape `(1, 4, 256, 256)` and the other are the IoU confidence scores (of shape `(1,2384)`) for each mask predicted. The pretrained SAM model predicts four masks: the first is239the best mask the model could come up with for the given prompts, and the other 3 are the240alternative masks which can be used in case the best prediction doesn't contain the241desired object. The user can choose whichever mask they prefer.242243Let's visualize the masks returned by the model!244"""245246# Resize the mask to our image shape i.e. (1024, 1024)247mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]248# Convert the logits to a numpy array249# and convert the logits to a boolean mask250mask = ops.convert_to_numpy(mask) > 0.0251iou_score = ops.convert_to_numpy(outputs["iou_pred"][0][0])252253plt.figure(figsize=(10, 10))254plt.imshow(ops.convert_to_numpy(image) / 255.0)255show_mask(mask, plt.gca())256show_points(input_point, input_label, plt.gca())257plt.title(f"IoU Score: {iou_score:.3f}", fontsize=18)258plt.axis("off")259plt.show()260261"""262As expected, the model returns a segmentation mask for the truck's window pane. But, our263point prompt can also mean a range of other things. For example, another possible mask264that contains our point is just the right side of the window pane or the whole truck.265"""266267"""268Let's also visualize the other masks the model has predicted.269"""270271fig, ax = plt.subplots(1, 3, figsize=(20, 60))272masks, scores = outputs["masks"][0][1:], outputs["iou_pred"][0][1:]273for i, (mask, score) in enumerate(zip(masks, scores)):274mask = inference_resizing(mask[..., None], pad=False)[..., 0]275mask, score = map(ops.convert_to_numpy, (mask, score))276mask = 1 * (mask > 0.0)277ax[i].imshow(ops.convert_to_numpy(image) / 255.0)278show_mask(mask, ax[i])279show_points(input_point, input_label, ax[i])280ax[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=12)281ax[i].axis("off")282plt.show()283284"""285Nice! SAM was able to capture the ambiguity of our point prompt and also returned other286possible segmentation masks.287"""288289"""290## Box Prompts291292Now, let's see how we can prompt the model using boxes. The box is specified using two293points, the top-left corner and the bottom-right corner of the bounding box in xyxy294format. Let's prompt the model using a bounding box around the left front tyre of the295truck.296"""297298# Let's specify the box299input_box = np.array([[240, 340], [400, 500]])300301outputs = model.predict(302{"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}303)304mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]305mask = ops.convert_to_numpy(mask) > 0.0306307plt.figure(figsize=(10, 10))308plt.imshow(ops.convert_to_numpy(image) / 255.0)309show_mask(mask, plt.gca())310show_box(input_box, plt.gca())311plt.axis("off")312plt.show()313314"""315Boom! The model perfectly segments out the left front tyre in our bounding box.316317## Combining prompts318319To get the true potential of the model out, let's combine box and point prompts and see320what the model does.321"""322323# Let's specify the box324input_box = np.array([[240, 340], [400, 500]])325# Let's specify the point and mark it background326input_point = np.array([[325, 425]])327input_label = np.array([0])328329outputs = model.predict(330{331"images": image[np.newaxis, ...],332"points": input_point[np.newaxis, ...],333"labels": input_label[np.newaxis, ...],334"boxes": input_box[np.newaxis, np.newaxis, ...],335}336)337mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]338mask = ops.convert_to_numpy(mask) > 0.0339340plt.figure(figsize=(10, 10))341plt.imshow(ops.convert_to_numpy(image) / 255.0)342show_mask(mask, plt.gca())343show_box(input_box, plt.gca())344show_points(input_point, input_label, plt.gca())345plt.axis("off")346plt.show()347348"""349Voila! The model understood that the object we wanted to exclude from our mask was the350rim of the tyre.351352## Text prompts353354Finally, let's see how text prompts can be used along with KerasHub's355`SegmentAnythingModel`.356357For this demo, we will use the358[offical Grounding DINO model](https://github.com/IDEA-Research/GroundingDINO).359Grounding DINO is a model that360takes as input a `(image, text)` pair and generates a bounding box around the object in361the `image` described by the `text`. You can refer to the362[paper](https://arxiv.org/abs/2303.05499) for more details on the implementation of the363model.364365For this part of the demo, we will need to install the `groundingdino` package from366source:367368```369pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git370```371372Then, we can install the pretrained model's weights and config:373"""374375"""shell376!pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git377"""378379"""shell380!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth381!wget -q https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/v0.1.0-alpha2/groundingdino/config/GroundingDINO_SwinT_OGC.py382"""383384from groundingdino.util.inference import Model as GroundingDINO385386CONFIG_PATH = "GroundingDINO_SwinT_OGC.py"387WEIGHTS_PATH = "groundingdino_swint_ogc.pth"388389grounding_dino = GroundingDINO(CONFIG_PATH, WEIGHTS_PATH)390391"""392Let's load an image of a dog for this part!393"""394395filepath = keras.utils.get_file(396origin="https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg"397)398image = np.array(keras.utils.load_img(filepath))399image = ops.convert_to_numpy(inference_resizing(image))400401plt.figure(figsize=(10, 10))402plt.imshow(image / 255.0)403plt.axis("on")404plt.show()405406"""407We first predict the bounding box of the object we want to segment using the Grounding408DINO model. Then, we prompt the SAM model using the bounding box to get the segmentation409mask.410411Let's try to segment out the harness of the dog. Change the image and text below to412segment whatever you want using text from your image!413"""414415# Let's predict the bounding box for the harness of the dog416boxes = grounding_dino.predict_with_caption(image.astype(np.uint8), "harness")417boxes = np.array(boxes[0].xyxy)418419outputs = model.predict(420{421"images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),422"boxes": boxes.reshape(-1, 1, 2, 2),423},424batch_size=1,425)426427"""428And that's it! We got a segmentation mask for our text prompt using the combination of429Gounding DINO + SAM! This is a very powerful technique to combine different models to430expand the applications!431432Let's visualize the results.433"""434435plt.figure(figsize=(10, 10))436plt.imshow(image / 255.0)437438for mask in outputs["masks"]:439mask = inference_resizing(mask[0][..., None], pad=False)[..., 0]440mask = ops.convert_to_numpy(mask) > 0.0441show_mask(mask, plt.gca())442show_box(boxes, plt.gca())443444plt.axis("off")445plt.show()446447"""448## Optimizing SAM449450You can use `mixed_float16` or `bfloat16` dtype policies to gain huge speedups and memory451optimizations at releatively low precision loss.452"""453454# Load our image455image = np.array(keras.utils.load_img("truck.jpg"))456image = inference_resizing(image)457458# Specify the prompt459input_box = np.array([[240, 340], [400, 500]])460461# Let's first see how fast the model is with float32 dtype462time_taken = timeit.repeat(463'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',464repeat=3,465number=3,466globals=globals(),467)468print(f"Time taken with float32 dtype: {min(time_taken) / 3:.10f}s")469470# Set the dtype policy in Keras471keras.mixed_precision.set_global_policy("mixed_float16")472473model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")474475time_taken = timeit.repeat(476'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis,np.newaxis, ...]}, verbose=False)',477repeat=3,478number=3,479globals=globals(),480)481print(f"Time taken with float16 dtype: {min(time_taken) / 3:.10f}s")482483"""484Here's a comparison of KerasHub's implementation with the original PyTorch485implementation!486487488489The script used to generate the benchmarks is present490[here](https://github.com/tirthasheshpatel/segment_anything_keras/blob/main/Segment_Anything_Benchmarks.ipynb).491"""492493"""494## Conclusion495496KerasHub's `SegmentAnythingModel` supports a variety of applications and, with the help of497Keras 3, enables running the model on TensorFlow, JAX, and PyTorch! With the help of XLA498in JAX and TensorFlow, the model runs several times faster than the original499implementation. Moreover, using Keras's mixed precision support helps optimize memory use500and computation time with just one line of code!501502For more advanced uses, check out the503[Automatic Mask Generator demo](https://github.com/tirthasheshpatel/segment_anything_keras/blob/main/Segment_Anything_Automatic_Mask_Generator_Demo.ipynb).504"""505506507