# Source: GitHub repository keras-team/keras-io, path: blob/master/guides/keras_hub/classification_with_keras_hub.py
"""
Title: Image Classification with KerasHub
Author: [Gowtham Paimagam](https://github.com/gowthamkpr), [lukewood](https://lukewood.xyz)
Date created: 09/24/2024
Last modified: 10/04/2024
Description: Use KerasHub to train powerful image classifiers.
Accelerator: GPU
"""

"""
Classification is the process of predicting a categorical label for a given
input image.
While classification is a relatively straightforward computer vision task,
modern approaches are still built from several complex components.
Luckily, Keras provides APIs to construct commonly used components.

This guide demonstrates KerasHub's modular approach to solving image
classification problems at three levels of complexity:

- Inference with a pretrained classifier
- Fine-tuning a pretrained backbone
- Training an image classifier from scratch

KerasHub uses Keras 3 to work with any of TensorFlow, PyTorch, or JAX. In the
guide below, we will use the `jax` backend. This guide also runs on the
TensorFlow or PyTorch backends with zero changes; simply update the
`KERAS_BACKEND` below.

We use Professor Keras, the official Keras mascot, as a
visual reference for the complexity of the material:

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_evolution.png)
"""

"""shell
!pip install -q git+https://github.com/keras-team/keras-hub.git
!pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

import math
import numpy as np
import matplotlib.pyplot as plt

import keras
from keras import losses
from keras import ops
from keras import optimizers
from keras.optimizers import schedules
from keras import metrics
from keras.applications.imagenet_utils import decode_predictions
import keras_hub

# Import tensorflow for `tf.data` and its preprocessing functions
import tensorflow as tf
import tensorflow_datasets as tfds


"""
## Inference with a pretrained classifier

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_beginner.png)

Let's get started with the simplest KerasHub API: a pretrained classifier.
In this example, we will construct a classifier that was
pretrained on the ImageNet dataset.
We'll use this model to solve the age-old "Cat or Dog" problem.

The highest-level module in KerasHub is a *task*. A *task* is a `keras.Model`
consisting of a (generally pretrained) backbone model and task-specific layers.
Here's an example using `keras_hub.models.ImageClassifier` with a
ResNet backbone.

ResNet is a great starting model when constructing an image
classification pipeline.
This architecture achieves high accuracy while using a
compact parameter count.
If a ResNet is not powerful enough for the task you are hoping to
solve, be sure to check out
[KerasHub's other available Backbones](https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/models)!
"""

classifier = keras_hub.models.ImageClassifier.from_preset("resnet_v2_50_imagenet")

"""
You may notice a small deviation from the old `keras.applications` API, where
you would construct the class with `ResNet50V2(weights="imagenet")`.
While the old API was great for classification, it did not scale effectively to
other use cases that required complex architectures, like object detection and
semantic segmentation.

We first create a utility function for plotting images throughout this tutorial:
"""


def plot_image_gallery(images, titles=None, num_cols=3, figsize=(6, 12)):
    num_images = len(images)
    images = np.asarray(images) / 255.0
    images = np.minimum(np.maximum(images, 0.0), 1.0)
    num_rows = (num_images + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()  # Flatten in case the axes is a 2D array

    for i, ax in enumerate(axes):
        if i < num_images:
            # Plot the image
            ax.imshow(images[i])
            ax.axis("off")  # Remove axis
            if titles and len(titles) > i:
                ax.set_title(titles[i], fontsize=12)
        else:
            # Turn off the axis for any empty subplot
            ax.axis("off")

    plt.show()
    plt.close()


"""
Now that our classifier is built, let's apply it to this cute cat picture!
"""

filepath = keras.utils.get_file(
    origin="https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/5hR96puA_VA.jpg/1024px-5hR96puA_VA.jpg"
)
image = keras.utils.load_img(filepath)
image = np.array([image])
plot_image_gallery(image, num_cols=1, figsize=(3, 3))

"""
Next, let's get some predictions from our classifier:
"""

predictions = classifier.predict(image)

"""
Predictions come in the form of softmax-ed category rankings.
We can use Keras' `imagenet_utils.decode_predictions` function to map
them to class names:
"""

print(f"Top two classes are:\n{decode_predictions(predictions, top=2)}")
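
"""
To make the decoding step concrete, here is a minimal pure-Python sketch of how
raw scores can be turned into a top-k ranking. It is not the implementation of
`decode_predictions`: the labels and scores are made up for illustration, and
`softmax` and `top_k` are hypothetical helper names. Because softmax preserves
ordering, the ranking is the same whether a model outputs logits or
probabilities.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probabilities, labels, k=2):
    # Pair each label with its probability and keep the k most likely ones.
    ranked = sorted(
        zip(labels, probabilities), key=lambda pair: pair[1], reverse=True
    )
    return ranked[:k]


# Hypothetical labels and scores, for illustration only.
toy_labels = ["tabby_cat", "bath_towel", "golden_retriever"]
toy_probs = softmax([2.0, 1.0, -1.0])
print(top_k(toy_probs, toy_labels, k=2))
```

The real `decode_predictions` additionally maps each class index to its
ImageNet identifier and human-readable name.
"""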

"""
Great! Both of these appear to be correct!
However, one of the classes is "Bath towel".
We're trying to classify cats vs. dogs.
We don't care about the towel!

Ideally, we'd have a classifier that only performs computation to determine if
an image is a cat or a dog, and has all of its resources dedicated to this task.
This can be solved by fine-tuning our own classifier.

## Fine-tuning a pretrained classifier

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_intermediate.png)

When labeled images specific to our task are available, fine-tuning a custom
classifier can improve performance.
If we want to train a cats vs. dogs classifier, using explicitly labeled cat vs.
dog data should perform better than the generic classifier!
For many tasks, no relevant pretrained model
will be available (e.g., categorizing images specific to your application).

First, let's get started by loading some data:
"""

BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
AUTOTUNE = tf.data.AUTOTUNE
tfds.disable_progress_bar()

data, dataset_info = tfds.load("cats_vs_dogs", with_info=True, as_supervised=True)
train_steps_per_epoch = dataset_info.splits["train"].num_examples // BATCH_SIZE
train_dataset = data["train"]

num_classes = dataset_info.features["label"].num_classes

resizing = keras.layers.Resizing(
    IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True
)


def preprocess_inputs(image, label):
    image = tf.cast(image, tf.float32)
    # Statically resize images as we only iterate the dataset once.
    return resizing(image), tf.one_hot(label, num_classes)


# Shuffle the dataset to increase diversity of batches.
# 10*BATCH_SIZE follows the assumption that bigger machines can handle bigger
# shuffle buffers.
train_dataset = train_dataset.shuffle(
    10 * BATCH_SIZE, reshuffle_each_iteration=True
).map(preprocess_inputs, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)

images = next(iter(train_dataset.take(1)))[0]
plot_image_gallery(images)
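
"""
The `tf.one_hot(label, num_classes)` call in `preprocess_inputs` converts each
integer label into a vector, which is the label format that a categorical
cross-entropy loss expects. As a rough pure-Python illustration of the encoding
(the `one_hot` helper below is a hypothetical stand-in, not the TensorFlow
function):

```python
def one_hot(label, num_classes):
    # Build a vector of zeros with a single 1.0 at the label's index.
    return [1.0 if index == label else 0.0 for index in range(num_classes)]


# With two classes, label 0 and label 1 map to these vectors.
print(one_hot(0, 2))
print(one_hot(1, 2))
```
"""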

"""
Meow!

Next, let's construct our model.
The use of `imagenet` in the preset name indicates that the backbone was
pretrained on the ImageNet dataset.
Pretrained backbones extract more information from our labeled examples by
leveraging patterns extracted from potentially much larger datasets.

Next, let's put together our classifier:
"""

model = keras_hub.models.ImageClassifier.from_preset(
    "resnet_v2_50_imagenet", num_classes=2
)
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    metrics=["accuracy"],
)

"""
Here our classifier is a regular `keras.Model`.
All that is left to do is call `model.fit()`:
"""

model.fit(train_dataset)


"""
Let's look at how our model performs after fine-tuning:
"""

predictions = model.predict(image)

classes = {0: "cat", 1: "dog"}
print("Top class is:", classes[predictions[0].argmax()])

"""
Awesome - looks like the model correctly classified the image.
"""

"""
## Train a Classifier from Scratch

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_advanced.png)

Now that we've gotten our hands dirty with classification, let's take on one
last task: training a classification model from scratch!
A standard benchmark for image classification is the ImageNet dataset; however,
due to licensing constraints we will use the CalTech 101 image classification
dataset in this tutorial.
While we use the simpler CalTech 101 dataset in this guide, the same training
template may be used on ImageNet to achieve near state-of-the-art scores.

Let's start out by tackling data loading:
"""

BATCH_SIZE = 32
NUM_CLASSES = 101
IMAGE_SIZE = (224, 224)

# Change epochs to ~100 to fully train.
EPOCHS = 1


def package_inputs(image, label):
    return {"images": image, "labels": tf.one_hot(label, NUM_CLASSES)}


train_ds, eval_ds = tfds.load(
    "caltech101", split=["train", "test"], as_supervised=True
)
train_ds = train_ds.map(package_inputs, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = eval_ds.map(package_inputs, num_parallel_calls=tf.data.AUTOTUNE)

train_ds = train_ds.shuffle(BATCH_SIZE * 16)
augmenters = []

"""
The images in the CalTech101 dataset vary in size, so we resize them before
batching them with the `batch()` API.
"""

resize = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)
train_ds = train_ds.map(resize)
eval_ds = eval_ds.map(resize)

train_ds = train_ds.batch(BATCH_SIZE)
eval_ds = eval_ds.batch(BATCH_SIZE)

batch = next(iter(train_ds.take(1)))
image_batch = batch["images"]
label_batch = batch["labels"]

plot_image_gallery(
    image_batch,
)
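
"""
With `crop_to_aspect_ratio=True`, the resizing step first crops the image to
the target aspect ratio and only then rescales it, so objects are not
stretched. The sketch below illustrates that idea with plain arithmetic. It is
an assumption-laden approximation, not Keras' implementation: the
`centered_crop_box` helper is hypothetical, and Keras may round differently.

```python
def centered_crop_box(height, width, target_height, target_width):
    # Largest window with the target aspect ratio that fits inside the image.
    crop_height = min(height, int(width * target_height / target_width))
    crop_width = min(width, int(height * target_width / target_height))
    # Center the window inside the original image.
    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    return top, left, crop_height, crop_width


# A hypothetical 300x500 (height x width) image and a square 224x224 target.
print(centered_crop_box(300, 500, 224, 224))
```

For a landscape image and a square target, the window keeps the full height and
trims equal margins from the left and right before the rescale to 224x224.
"""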

"""
### Data Augmentation

In our previous fine-tuning example, we performed a static resizing operation and
did not utilize any image augmentation.
This is because a single pass over the training set was sufficient to achieve
decent results.
When training to solve a more difficult task, you'll want to include data
augmentation in your data pipeline.

Data augmentation is a technique to make your model robust to changes in input
data such as lighting, cropping, and orientation.
Keras includes some of the most useful augmentations in the `keras.layers`
API.
Creating an optimal pipeline of augmentations is an art, but in this section of
the guide we'll offer some tips on best practices for classification.

One caveat to be aware of with image data augmentation is that you must be careful
to not shift your augmented data distribution too far from the original data
distribution.
The goal is to prevent overfitting and increase generalization,
but samples that lie completely out of the data distribution simply add noise to
the training process.

The first augmentation we'll use is `RandomFlip`.
This augmentation behaves more or less how you'd expect: it either flips the
image or not.
While this augmentation is useful in CalTech101 and ImageNet, it should not be
used on tasks where the data distribution is not vertical
mirror invariant.
An example of a dataset where this occurs is MNIST handwritten digits.
Flipping a `6` over the
vertical axis will make the digit appear more like a `7` than a `6`, but the
label will still show a `6`.
"""

random_flip = keras.layers.RandomFlip()
augmenters += [random_flip]

image_batch = random_flip(image_batch)
plot_image_gallery(image_batch)
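
"""
To see what a flip does to the pixel grid, here is a tiny pure-Python sketch on
a nested list standing in for an image. The helpers `flip_horizontal` and
`flip_vertical` are hypothetical names used only for illustration. With its
default arguments, `RandomFlip` may apply flips in both directions, each chosen
at random per image.

```python
def flip_horizontal(image):
    # Mirror each row left to right.
    return [list(reversed(row)) for row in image]


def flip_vertical(image):
    # Reverse the order of the rows, top to bottom.
    return list(reversed(image))


toy_image = [
    [1, 2, 3],
    [4, 5, 6],
]
print(flip_horizontal(toy_image))
print(flip_vertical(toy_image))
```

Flipping twice in the same direction restores the original grid, which is why a
flip never destroys information; it only changes the orientation the model sees.
"""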

"""
Roughly half of the images have been flipped!

The next augmentation we'll use is `RandomCrop`.
This operation selects a random sub-region of the image.
By using this augmentation, we encourage our classifier to become spatially invariant.

Let's add a `RandomCrop` to our set of augmentations:
"""

crop = keras.layers.RandomCrop(
    int(IMAGE_SIZE[0] * 0.9),
    int(IMAGE_SIZE[1] * 0.9),
)

augmenters += [crop]

image_batch = crop(image_batch)
plot_image_gallery(
    image_batch,
)
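
"""
The crop size above is 90% of each image dimension, truncated to an integer.
The following sketch works through that arithmetic and shows one plausible way
a crop window could be positioned. The `sample_crop_offsets` helper is
hypothetical and only approximates what a random crop does; it is not the
`RandomCrop` implementation.

```python
import random

IMAGE_SIZE = (224, 224)
crop_height = int(IMAGE_SIZE[0] * 0.9)
crop_width = int(IMAGE_SIZE[1] * 0.9)


def sample_crop_offsets(height, width, crop_height, crop_width, rng):
    # Pick the top-left corner so the crop window stays inside the image.
    top = rng.randint(0, height - crop_height)
    left = rng.randint(0, width - crop_width)
    return top, left


rng = random.Random(0)
top, left = sample_crop_offsets(*IMAGE_SIZE, crop_height, crop_width, rng)
print(crop_height, crop_width, top, left)
```

Each side shrinks from 224 to 201 pixels, so the window can shift by at most 23
pixels in either direction. That keeps most of the object in frame while still
varying its position.
"""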

"""
We can also rotate images by a random angle using Keras' `RandomRotation` layer. Let's
apply a rotation by a randomly selected angle in the interval -45°...45°:
"""

rotate = keras.layers.RandomRotation((-45 / 360, 45 / 360))

augmenters += [rotate]

image_batch = rotate(image_batch)
plot_image_gallery(image_batch)

# Resize back to IMAGE_SIZE, since RandomCrop reduced the image dimensions.
resize = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)
augmenters += [resize]

image_batch = resize(image_batch)
plot_image_gallery(image_batch)
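
"""
The `RandomRotation` factor used above is expressed as a fraction of a full
turn, which is why the bounds are written as `-45 / 360` and `45 / 360`. A
short sketch of the conversion, using hypothetical helper names:

```python
import math


def factor_to_degrees(factor):
    # A factor is a fraction of a full turn (360 degrees).
    return factor * 360.0


def factor_to_radians(factor):
    # The same fraction expressed in radians (a full turn is 2 * pi).
    return factor * 2.0 * math.pi


lower, upper = -45 / 360, 45 / 360
print(lower, upper)
print(factor_to_degrees(lower), factor_to_degrees(upper))
```

A factor of `0.125` therefore corresponds to 45 degrees, or `pi / 4` radians.
"""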

"""
Now let's apply our final augmenter to the training data:
"""


def create_augmenter_fn(augmenters):
    def augmenter_fn(inputs):
        for augmenter in augmenters:
            inputs["images"] = augmenter(inputs["images"])
        return inputs

    return augmenter_fn


augmenter_fn = create_augmenter_fn(augmenters)
train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf.data.AUTOTUNE)

image_batch = next(iter(train_ds.take(1)))["images"]
plot_image_gallery(
    image_batch,
)
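
"""
The augmenters are applied in list order, and only the `"images"` entry of each
example is modified. The toy example below reuses the same closure pattern with
plain functions standing in for augmentation layers; `double` and `add_one` are
hypothetical stand-ins chosen so that the order of application is visible.

```python
def create_augmenter_fn(augmenters):
    def augmenter_fn(inputs):
        # Apply each augmenter in order, only ever touching the images.
        for augmenter in augmenters:
            inputs["images"] = augmenter(inputs["images"])
        return inputs

    return augmenter_fn


def double(values):
    return [2 * v for v in values]


def add_one(values):
    return [v + 1 for v in values]


toy_fn = create_augmenter_fn([double, add_one])
toy_result = toy_fn({"images": [1, 2, 3], "labels": [0, 1, 0]})
print(toy_result)
```

Swapping the two stand-ins would give different images, which is a reminder
that the order of the `augmenters` list matters. In this guide, the final
`Resizing` step comes last so that every batch ends up at `IMAGE_SIZE`.
"""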

"""
We also need to resize our evaluation set to get dense batches of the image size
expected by our model. We directly use the deterministic `keras.layers.Resizing` in
this case to avoid adding noise to our evaluation metric due to applying random
augmentations.
"""

inference_resizing = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)


def do_resize(inputs):
    inputs["images"] = inference_resizing(inputs["images"])
    return inputs


eval_ds = eval_ds.map(do_resize, num_parallel_calls=tf.data.AUTOTUNE)

image_batch = next(iter(eval_ds.take(1)))["images"]
plot_image_gallery(
    image_batch,
)

"""
Finally, let's unpackage our datasets and prepare to pass them to `model.fit()`,
which accepts a tuple of `(images, labels)`.
"""


def unpackage_dict(inputs):
    return inputs["images"], inputs["labels"]


train_ds = train_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = eval_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)

"""
Data augmentation is by far the hardest piece of training a modern
classifier.
Congratulations on making it this far!

### Optimizer Tuning

To achieve optimal performance, we need to use a learning rate schedule instead
of a single learning rate. While we won't go into detail on the cosine decay
with warmup schedule used here,
[you can read more about it here](https://scorrea92.medium.com/cosine-learning-rate-decay-e8b50aa455b).
"""


def lr_warmup_cosine_decay(
    global_step,
    warmup_steps,
    hold=0,
    total_steps=0,
    start_lr=0.0,
    target_lr=1e-2,
):
    # Cosine decay
    learning_rate = (
        0.5
        * target_lr
        * (
            1
            + ops.cos(
                math.pi
                * ops.convert_to_tensor(
                    global_step - warmup_steps - hold, dtype="float32"
                )
                / ops.convert_to_tensor(
                    total_steps - warmup_steps - hold, dtype="float32"
                )
            )
        )
    )

    # Linear warmup from 0 up to target_lr (start_lr is accepted but not used here).
    warmup_lr = target_lr * (global_step / warmup_steps)

    # Hold the learning rate at target_lr for `hold` steps after warmup.
    if hold > 0:
        learning_rate = ops.where(
            global_step > warmup_steps + hold, learning_rate, target_lr
        )

    learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate


class WarmUpCosineDecay(schedules.LearningRateSchedule):
    def __init__(self, warmup_steps, total_steps, hold, start_lr=0.0, target_lr=1e-2):
        super().__init__()
        self.start_lr = start_lr
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.hold = hold

    def __call__(self, step):
        lr = lr_warmup_cosine_decay(
            global_step=step,
            total_steps=self.total_steps,
            warmup_steps=self.warmup_steps,
            start_lr=self.start_lr,
            target_lr=self.target_lr,
            hold=self.hold,
        )
        # Past the end of training, the learning rate is clamped to zero.
        return ops.where(step > self.total_steps, 0.0, lr)


"""
![WarmUpCosineDecay schedule](https://i.imgur.com/YCr5pII.png)

The schedule looks as we expect.

Next, let's construct this optimizer:
"""

total_images = 9000
total_steps = (total_images // BATCH_SIZE) * EPOCHS
warmup_steps = int(0.1 * total_steps)
hold_steps = int(0.45 * total_steps)
schedule = WarmUpCosineDecay(
    start_lr=0.05,
    target_lr=1e-2,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    hold=hold_steps,
)
optimizer = optimizers.SGD(
    weight_decay=5e-4,
    learning_rate=schedule,
    momentum=0.9,
)
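
"""
To build some intuition for the numbers above, the sketch below recomputes the
step counts and evaluates a plain-Python version of the schedule at a few
steps. The `warmup_hold_cosine` function is a hypothetical scalar rewrite that
follows the same branching as `lr_warmup_cosine_decay` (linear warmup, hold,
cosine decay, then zero); it is meant for inspection only and is not used for
training.

```python
import math

BATCH_SIZE, EPOCHS, total_images = 32, 1, 9000
total_steps = (total_images // BATCH_SIZE) * EPOCHS
warmup_steps = int(0.1 * total_steps)
hold_steps = int(0.45 * total_steps)


def warmup_hold_cosine(step, warmup_steps, hold, total_steps, target_lr):
    if step < warmup_steps:
        # Linear warmup from 0 up to target_lr.
        return target_lr * step / warmup_steps
    if step <= warmup_steps + hold:
        # Hold the peak learning rate.
        return target_lr
    if step > total_steps:
        # Past the end of training.
        return 0.0
    # Cosine decay from target_lr down to 0 over the remaining steps.
    remaining = total_steps - warmup_steps - hold
    progress = (step - warmup_steps - hold) / remaining
    return 0.5 * target_lr * (1 + math.cos(math.pi * progress))


print(total_steps, warmup_steps, hold_steps)
for step in (0, 14, 100, 200, 281):
    print(step, warmup_hold_cosine(step, warmup_steps, hold_steps, total_steps, 1e-2))
```

With one epoch, training runs for 281 steps: 28 warmup steps, 126 hold steps,
and the cosine decay over the remainder. Increasing `EPOCHS` scales all three
phases proportionally.
"""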

"""
At long last, we can now build our model and call `fit()`!
Here, we directly instantiate our `ResNetBackbone`, specifying all architectural
parameters, which gives us full control to tweak the architecture.
"""

backbone = keras_hub.models.ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
)
model = keras.Sequential(
    [
        backbone,
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(101, activation="softmax"),
    ]
)
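
"""
The `GlobalMaxPooling2D` layer is what turns the backbone's spatial feature map
into a single feature vector for the dense head: it keeps the maximum value of
each channel across all spatial positions. A minimal pure-Python sketch of that
reduction, with a made-up 2x2 feature map and a hypothetical helper name:

```python
def global_max_pool_2d(feature_map):
    # feature_map[row][col][channel] -> one maximum per channel.
    num_channels = len(feature_map[0][0])
    return [
        max(pixel[channel] for row in feature_map for pixel in row)
        for channel in range(num_channels)
    ]


# A toy 2x2 feature map with 2 channels, for illustration only.
toy_features = [
    [[0.1, 5.0], [0.7, 1.0]],
    [[0.3, 2.0], [0.2, 9.0]],
]
print(global_max_pool_2d(toy_features))
```

The output length equals the number of channels, regardless of the spatial size
of the feature map.
"""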

"""
We employ label smoothing to prevent the model from overfitting to artifacts of
our augmentation process.
"""

loss = losses.CategoricalCrossentropy(label_smoothing=0.1)
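
"""
Label smoothing softens the one-hot targets by moving a small amount of
probability mass from the true class to all classes. The sketch below shows the
commonly used formula on a toy four-class label; `smooth_labels` is a
hypothetical helper, and we assume (rather than guarantee) that the loss above
applies the same transformation internally.

```python
def smooth_labels(one_hot, label_smoothing=0.1):
    num_classes = len(one_hot)
    # Scale the original target and spread label_smoothing evenly over all classes.
    return [
        y * (1.0 - label_smoothing) + label_smoothing / num_classes
        for y in one_hot
    ]


smoothed = smooth_labels([0.0, 1.0, 0.0, 0.0])
print(smoothed)
```

The smoothed targets still sum to one, but the model is no longer pushed toward
fully confident predictions.
"""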

"""
Let's compile our model:
"""

model.compile(
    loss=loss,
    optimizer=optimizer,
    metrics=[
        metrics.CategoricalAccuracy(),
        metrics.TopKCategoricalAccuracy(k=5),
    ],
)
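
"""
Top-k accuracy counts a prediction as correct when the true class appears among
the k highest-scoring classes, which is a useful complement to plain accuracy
when there are 101 classes. Here is a small pure-Python sketch of the metric on
made-up data; `top_k_accuracy` is a hypothetical helper, not the Keras metric
class.

```python
def top_k_accuracy(y_true, y_pred, k=5):
    hits = 0
    for true_vector, scores in zip(y_true, y_pred):
        true_class = true_vector.index(max(true_vector))
        # Indices of the k highest scores for this example.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        hits += true_class in ranked[:k]
    return hits / len(y_true)


# Two toy examples with three classes each.
toy_true = [[0, 0, 1], [1, 0, 0]]
toy_pred = [[0.5, 0.3, 0.2], [0.6, 0.3, 0.1]]
print(top_k_accuracy(toy_true, toy_pred, k=1))
print(top_k_accuracy(toy_true, toy_pred, k=3))
```

With `k=1` the metric reduces to ordinary categorical accuracy, and it can only
increase as `k` grows.
"""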

"""
Finally, we call `fit()`:
"""

model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=eval_ds,
)

"""
Congratulations! You now know how to train a powerful image classifier from
scratch using KerasHub.
Depending on the availability of labeled data for your application, training
from scratch may or may not be more powerful than using transfer learning in
addition to the data augmentations discussed above. For smaller datasets,
pretrained models generally produce high accuracy and faster convergence.
"""

"""
## Conclusions

While image classification is perhaps the simplest problem in computer vision,
the modern landscape has numerous complex components.
Luckily, KerasHub offers robust, production-grade APIs to make assembling most
of these components possible in one line of code.
Through the use of KerasHub's `ImageClassifier` API, pretrained weights, and
Keras' data augmentations, you can assemble everything you need to train a
powerful classifier in a few hundred lines of code!

As a follow-up exercise, try fine-tuning a KerasHub classifier on your own dataset!
"""