"""
2
Title: Segment Anything in KerasHub!
3
Author: Tirth Patel, Ian Stenbit, Divyashree Sreepathihalli<br>
4
Date created: 2024/10/1<br>
5
Last modified: 2024/10/1<br>
6
Description: Segment anything using text, box, and points prompts in KerasHub.
7
Accelerator: GPU
8
"""

"""
## Overview

The Segment Anything Model (SAM) produces high quality object masks from input prompts
such as points or boxes, and it can be used to generate masks for all objects in an
image. It has been trained on a
[dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1
billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

In this guide, we will show how to use KerasHub's implementation of the
[Segment Anything Model](https://github.com/facebookresearch/segment-anything)
and demonstrate how large a performance boost TensorFlow and JAX deliver.

First, let's get all our dependencies and images for our demo.
"""

"""shell
!pip install -Uq git+https://github.com/keras-team/keras-hub.git
!pip install -Uq keras
"""

"""shell
!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
"""

"""
## Choose your backend

With Keras 3, you can choose to use your favorite backend!
"""

import os

# Set the backend before importing Keras: "jax", "tensorflow", or "torch".
os.environ["KERAS_BACKEND"] = "jax"

import timeit
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
import keras_hub
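
# Optional sanity check (not part of the original guide): confirm the
# requested backend is active.
print("Using backend:", keras.backend.backend())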

"""
## Helper functions

Let's define some helper functions for visualizing the images, prompts, and the
segmentation results.
"""


def show_mask(mask, ax, random_color=False):
    # Overlay a semi-transparent mask on the given matplotlib axes.
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    # Draw foreground points (label 1) in green and background points
    # (label 0) in red.
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    # Draw a bounding box given as two corner points (any shape that
    # flattens to xyxy order).
    box = box.reshape(-1)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


def inference_resizing(image, pad=True):
    # Compute the preprocess shape: scale the longer side to 1024 pixels
    image = ops.cast(image, dtype="float32")
    old_h, old_w = image.shape[0], image.shape[1]
    scale = 1024 * 1.0 / max(old_h, old_w)
    new_h = old_h * scale
    new_w = old_w * scale
    preprocess_shape = int(new_h + 0.5), int(new_w + 0.5)

    # Resize the image
    image = ops.image.resize(image[None, ...], preprocess_shape)[0]

    # Pad the shorter side to 1024 pixels
    if pad:
        pixel_mean = ops.array([123.675, 116.28, 103.53])
        pixel_std = ops.array([58.395, 57.12, 57.375])
        image = (image - pixel_mean) / pixel_std
        h, w = image.shape[0], image.shape[1]
        pad_h = 1024 - h
        pad_w = 1024 - w
        image = ops.pad(image, [(0, pad_h), (0, pad_w), (0, 0)])
        # KerasHub now rescales the images and normalizes them.
        # Just unnormalize such that when KerasHub normalizes them
        # again, the padded values map to 0.
        image = image * pixel_std + pixel_mean
    return image
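
# An illustrative check (not part of the original guide): an 800x1200 landscape
# image comes back scaled to 683x1024 and then padded to 1024x1024.
print(inference_resizing(np.zeros((800, 1200, 3))).shape)  # (1024, 1024, 3)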


"""
## Get the pretrained SAM model

We can initialize a trained SAM model using KerasHub's `from_preset` factory method. Here,
we use the huge ViT backbone trained on the SA-1B dataset (`sam_huge_sa1b`) for
high-quality segmentation masks. You can also use one of `sam_large_sa1b` or
`sam_base_sa1b` for faster inference (at the cost of lower-quality segmentation masks).
"""

model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")

"""
## Understanding Prompts

Segment Anything allows prompting an image using points, boxes, and masks:

1. Point prompts are the most basic of all: the model tries to guess the object given a
point on an image. The point can either be a foreground point (i.e. the desired
segmentation mask contains the point in it) or a background point (i.e. the point lies
outside the desired mask).
2. Another way to prompt the model is using boxes. Given a bounding box, the model tries
to segment the object contained in it.
3. Finally, the model can also be prompted using a mask itself. This is useful, for
instance, to refine the borders of a previously predicted or known segmentation mask.

What makes the model incredibly powerful is the ability to combine the prompts above.
Point, box, and mask prompts can be combined in several different ways to achieve the
best result.

Let's see the semantics of passing these prompts to the Segment Anything model in
KerasHub. Input to the SAM model is a dictionary with keys:

1. `"images"`: A batch of images to segment. Must be of shape `(B, 1024, 1024, 3)`.
2. `"points"`: A batch of point prompts. Each point is an `(x, y)` coordinate originating
from the top-left corner of the image. In other words, each point is of the form `(c, r)`
where `r` and `c` are the row and column of the pixel in the image. Must be of shape `(B,
N, 2)`.
3. `"labels"`: A batch of labels for the given points. `1` represents foreground points
and `0` represents background points. Must be of shape `(B, N)`.
4. `"boxes"`: A batch of boxes. Note that the model only accepts one box per batch.
Hence, the expected shape is `(B, 1, 2, 2)`. Each box is a collection of 2 points: the
top left corner and the bottom right corner of the box. The points here follow the same
semantics as the point prompts. Here the `1` in the second dimension represents the
presence of box prompts. If the box prompts are missing, a placeholder input of shape
`(B, 0, 2, 2)` must be passed.
5. `"masks"`: A batch of masks. Just like box prompts, only one mask prompt per image is
allowed. The shape of the input mask must be `(B, 1, 256, 256, 1)` if a mask prompt is
present and `(B, 0, 256, 256, 1)` if it is missing.

Placeholder prompts are only required when calling the model directly (i.e.
`model(...)`). When calling the `predict` method, missing prompts can be omitted from the
input dictionary.
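
For concreteness, here is a minimal sketch (the shapes come from the list above; the
point itself is arbitrary) of the placeholder inputs a direct `model(...)` call would
need when only a single point prompt is used:

```python
placeholder_inputs = {
    "images": ops.zeros((1, 1024, 1024, 3)),  # one (dummy) preprocessed image
    "points": ops.array([[[284.0, 213.5]]]),  # shape (B, N, 2) = (1, 1, 2)
    "labels": ops.array([[1]]),               # 1 = foreground point
    "boxes": ops.zeros((1, 0, 2, 2)),         # placeholder: no box prompt
    "masks": ops.zeros((1, 0, 256, 256, 1)),  # placeholder: no mask prompt
}
```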

## Point prompts

First, let's segment an image using point prompts. We load the image and resize it to
shape `(1024, 1024)`, the image size the pretrained SAM model expects.
"""

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
plt.axis("on")
plt.show()

"""
Next, we will define the point on the object we want to segment. Let's try to segment the
truck's window pane at coordinates `(284, 213.5)`.
"""

# Define the input point prompt
input_point = np.array([[284, 213.5]])
input_label = np.array([1])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_points(input_point, input_label, plt.gca())
plt.axis("on")
plt.show()

"""
Now let's call the `predict` method of our model to get the segmentation masks.

**Note**: We don't call the model directly (`model(...)`) since placeholder prompts are
required to do so. Missing prompts are handled automatically by the `predict` method, so
we call it instead. Also, when no box prompts are present, the points and labels need to
be padded with a zero point prompt and a `-1` label prompt respectively. The cell below
demonstrates how this works.
"""

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": np.concatenate(
            [input_point[np.newaxis, ...], np.zeros((1, 1, 2))], axis=1
        ),
        "labels": np.concatenate(
            [input_label[np.newaxis, ...], np.full((1, 1), fill_value=-1)], axis=1
        ),
    }
)

"""
`SAMImageSegmenter.predict` returns two outputs. The first is the segmentation logits
(masks) of shape `(1, 4, 256, 256)`, and the second is the IoU confidence scores of
shape `(1, 4)`, one for each predicted mask. The pretrained SAM model predicts four
masks: the first is the best mask the model could come up with for the given prompts,
and the other 3 are alternative masks which can be used in case the best prediction
doesn't contain the desired object. The user can choose whichever mask they prefer.
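
If you want to pick a mask programmatically, one simple heuristic (a sketch, not part of
the pipeline below) is to take the mask with the highest predicted IoU score:

```python
best_idx = int(np.argmax(outputs["iou_pred"][0]))  # index of the highest-scoring mask
best_mask = outputs["masks"][0][best_idx]          # shape (256, 256)
```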

Let's visualize the masks returned by the model!
"""

# Resize the mask to our image shape i.e. (1024, 1024)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
# Convert the logits to a NumPy array and threshold them
# into a boolean mask
mask = ops.convert_to_numpy(mask) > 0.0
iou_score = ops.convert_to_numpy(outputs["iou_pred"][0][0])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"IoU Score: {iou_score:.3f}", fontsize=18)
plt.axis("off")
plt.show()

"""
As expected, the model returns a segmentation mask for the truck's window pane. But our
point prompt can also mean a range of other things. For example, another possible mask
that contains our point is just the right side of the window pane, or the whole truck.

Let's also visualize the other masks the model has predicted.
"""

fig, ax = plt.subplots(1, 3, figsize=(20, 60))
masks, scores = outputs["masks"][0][1:], outputs["iou_pred"][0][1:]
for i, (mask, score) in enumerate(zip(masks, scores)):
    mask = inference_resizing(mask[..., None], pad=False)[..., 0]
    mask, score = map(ops.convert_to_numpy, (mask, score))
    mask = 1 * (mask > 0.0)
    ax[i].imshow(ops.convert_to_numpy(image) / 255.0)
    show_mask(mask, ax[i])
    show_points(input_point, input_label, ax[i])
    ax[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=12)
    ax[i].axis("off")
plt.show()

"""
Nice! SAM was able to capture the ambiguity of our point prompt and also returned other
possible segmentation masks.
"""

"""
## Box Prompts

Now, let's see how we can prompt the model using boxes. The box is specified using two
points: the top-left corner and the bottom-right corner of the bounding box, in xyxy
format. Let's prompt the model using a bounding box around the left front tyre of the
truck.
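
If your detector gives you a flat `[x0, y0, x1, y1]` box (as Grounding DINO will later in
this guide), a simple reshape (sketched below with an arbitrary box) converts it to the
two-corner format the model expects:

```python
xyxy_box = np.array([240, 340, 400, 500])  # arbitrary [x0, y0, x1, y1] box
corner_box = xyxy_box.reshape(2, 2)        # [[x0, y0], [x1, y1]]
```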
"""

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])

outputs = model.predict(
    {"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
plt.axis("off")
plt.show()

"""
Boom! The model perfectly segments out the left front tyre in our bounding box.

## Combining prompts

To unlock the true potential of the model, let's combine box and point prompts and see
what the model does.
"""

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])
# Let's specify the point and mark it background
input_point = np.array([[325, 425]])
input_label = np.array([0])

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": input_point[np.newaxis, ...],
        "labels": input_label[np.newaxis, ...],
        "boxes": input_box[np.newaxis, np.newaxis, ...],
    }
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis("off")
plt.show()

"""
Voila! The model understood that the object we wanted to exclude from our mask was the
rim of the tyre.

## Text prompts

Finally, let's see how text prompts can be used along with KerasHub's
`SAMImageSegmenter`.

For this demo, we will use the
[official Grounding DINO model](https://github.com/IDEA-Research/GroundingDINO).
Grounding DINO is a model that takes an `(image, text)` pair as input and generates a
bounding box around the object in the `image` described by the `text`. You can refer to
the [paper](https://arxiv.org/abs/2303.05499) for more details on the implementation of
the model.

For this part of the demo, we will need to install the `groundingdino` package from
source:

```
pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git
```

Then, we can download the pretrained model's weights and config:
"""

"""shell
!pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git
"""

"""shell
!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
!wget -q https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/v0.1.0-alpha2/groundingdino/config/GroundingDINO_SwinT_OGC.py
"""

from groundingdino.util.inference import Model as GroundingDINO

CONFIG_PATH = "GroundingDINO_SwinT_OGC.py"
WEIGHTS_PATH = "groundingdino_swint_ogc.pth"

grounding_dino = GroundingDINO(CONFIG_PATH, WEIGHTS_PATH)

"""
Let's load an image of a dog for this part!
"""

filepath = keras.utils.get_file(
    origin="https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg"
)
image = np.array(keras.utils.load_img(filepath))
image = ops.convert_to_numpy(inference_resizing(image))

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)
plt.axis("on")
plt.show()

"""
We first predict the bounding box of the object we want to segment using the Grounding
DINO model. Then, we prompt the SAM model using the bounding box to get the segmentation
mask.

Let's try to segment out the harness of the dog. Change the image and text below to
segment whatever you want in your image using text!
"""

# Let's predict the bounding box for the harness of the dog
boxes = grounding_dino.predict_with_caption(image.astype(np.uint8), "harness")
boxes = np.array(boxes[0].xyxy)

outputs = model.predict(
    {
        # One copy of the image per predicted box
        "images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),
        # Reshape the flat xyxy boxes into the (B, 1, 2, 2) two-corner format
        "boxes": boxes.reshape(-1, 1, 2, 2),
    },
    batch_size=1,
)

"""
And that's it! We got a segmentation mask for our text prompt using the combination of
Grounding DINO + SAM! Combining different models like this is a very powerful technique
for expanding their applications!

Let's visualize the results.
"""

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)

for mask in outputs["masks"]:
    mask = inference_resizing(mask[0][..., None], pad=False)[..., 0]
    mask = ops.convert_to_numpy(mask) > 0.0
    show_mask(mask, plt.gca())
    show_box(boxes, plt.gca())

plt.axis("off")
plt.show()

"""
## Optimizing SAM

You can use the `mixed_float16` or `mixed_bfloat16` dtype policies to gain huge speedups
and memory savings with relatively little loss in precision.
"""

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

# Specify the prompt
input_box = np.array([[240, 340], [400, 500]])

# Let's first see how fast the model is with float32 dtype
time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float32 dtype: {min(time_taken) / 3:.10f}s")

# Set the dtype policy in Keras
keras.mixed_precision.set_global_policy("mixed_float16")

# Reload the model so its weights and compute use the new policy
model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")

time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float16 dtype: {min(time_taken) / 3:.10f}s")

"""
Here's a comparison of KerasHub's implementation with the original PyTorch
implementation!

![benchmark](https://github.com/tirthasheshpatel/segment_anything_keras/blob/main/benchmark.png?raw=true)

The script used to generate the benchmarks is available
[here](https://github.com/tirthasheshpatel/segment_anything_keras/blob/main/Segment_Anything_Benchmarks.ipynb).
"""

"""
## Conclusion

KerasHub's `SAMImageSegmenter` supports a variety of applications and, with the help of
Keras 3, enables running the model on TensorFlow, JAX, and PyTorch! Thanks to XLA in JAX
and TensorFlow, the model runs several times faster than the original implementation.
Moreover, Keras's mixed precision support helps optimize memory use and computation time
with just one line of code!

For more advanced uses, check out the
[Automatic Mask Generator demo](https://github.com/tirthasheshpatel/segment_anything_keras/blob/main/Segment_Anything_Automatic_Mask_Generator_Demo.ipynb).
"""