GitHub Repository: keras-team/keras-io
Path: blob/master/guides/keras_hub/classification_with_keras_hub.py
"""
Title: Image Classification with KerasHub
Author: [Gowtham Paimagam](https://github.com/gowthamkpr), [lukewood](https://lukewood.xyz)
Date created: 09/24/2024
Last modified: 10/04/2024
Description: Use KerasHub to train powerful image classifiers.
Accelerator: GPU
"""

"""
Classification is the process of predicting a categorical label for a given
input image.
While classification is a relatively straightforward computer vision task,
modern approaches are still built from several complex components.
Luckily, Keras provides APIs to construct commonly used components.

This guide demonstrates KerasHub's modular approach to solving image
classification problems at three levels of complexity:

- Inference with a pretrained classifier
- Fine-tuning a pretrained backbone
- Training an image classifier from scratch

KerasHub uses Keras 3 to work with any of TensorFlow, PyTorch, or JAX. In the
guide below, we will use the `jax` backend. This guide also runs on the
TensorFlow and PyTorch backends with zero changes; simply update the
`KERAS_BACKEND` below.

We use Professor Keras, the official Keras mascot, as a
visual reference for the complexity of the material:

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_evolution.png)
"""

"""shell
!pip install -q git+https://github.com/keras-team/keras-hub.git
!pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

import math
import numpy as np
import matplotlib.pyplot as plt

import keras
from keras import losses
from keras import ops
from keras import optimizers
from keras.optimizers import schedules
from keras import metrics
from keras.applications.imagenet_utils import decode_predictions
import keras_hub

# Import TensorFlow for `tf.data` and its preprocessing functions.
import tensorflow as tf
import tensorflow_datasets as tfds

"""
## Inference with a pretrained classifier

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_beginner.png)

Let's get started with the simplest KerasHub API: a pretrained classifier.
In this example, we will construct a classifier that was
pretrained on the ImageNet dataset.
We'll use this model to solve the age-old "Cat or Dog" problem.

The highest-level module in KerasHub is a *task*. A *task* is a `keras.Model`
consisting of a (generally pretrained) backbone model and task-specific layers.
Here's an example using `keras_hub.models.ImageClassifier` with a
ResNet backbone.

ResNet is a great starting model when constructing an image
classification pipeline.
This architecture manages to achieve high accuracy while using a
compact parameter count.
If a ResNet is not powerful enough for the task you are hoping to
solve, be sure to check out
[KerasHub's other available backbones](https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/models)!
"""

classifier = keras_hub.models.ImageClassifier.from_preset("resnet_v2_50_imagenet")
86
87
"""
88
You may notice a small deviation from the old `keras.applications` API; where
89
you would construct the class with `Resnet50V2(weights="imagenet")`.
90
While the old API was great for classification, it did not scale effectively to
91
other use cases that required complex architectures, like object detection and
92
semantic segmentation.
93
94
We first create a utility function for plotting images throughout this tutorial:
95
"""
96
97
98
def plot_image_gallery(images, titles=None, num_cols=3, figsize=(6, 12)):
99
num_images = len(images)
100
images = np.asarray(images) / 255.0
101
images = np.minimum(np.maximum(images, 0.0), 1.0)
102
num_rows = (num_images + num_cols - 1) // num_cols
103
fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
104
axes = axes.flatten() # Flatten in case the axes is a 2D array
105
106
for i, ax in enumerate(axes):
107
if i < num_images:
108
# Plot the image
109
ax.imshow(images[i])
110
ax.axis("off") # Remove axis
111
if titles and len(titles) > i:
112
ax.set_title(titles[i], fontsize=12)
113
else:
114
# Turn off the axis for any empty subplot
115
ax.axis("off")
116
117
plt.show()
118
plt.close()
119
120

"""
Now that our classifier is built, let's apply it to this cute cat picture!
"""

filepath = keras.utils.get_file(
    origin="https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/5hR96puA_VA.jpg/1024px-5hR96puA_VA.jpg"
)
image = keras.utils.load_img(filepath)
image = np.array([image])
plot_image_gallery(image, num_cols=1, figsize=(3, 3))

"""
Next, let's get some predictions from our classifier:
"""

predictions = classifier.predict(image)

"""
Predictions come in the form of softmax-ed category rankings.
We can use Keras' `imagenet_utils.decode_predictions` function to map
them to class names:
"""

print(f"Top two classes are:\n{decode_predictions(predictions, top=2)}")
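
"""
Under the hood, `decode_predictions` simply looks up the `k` highest-scoring
class indices in the ImageNet vocabulary. As a rough sketch (using a made-up
4-class vocabulary rather than the real 1000-class ImageNet mapping), the
top-k selection can be reproduced with plain NumPy:
"""

```python
import numpy as np


def top_k_classes(probs, class_names, k=2):
    # Indices of the k largest probabilities, highest first.
    top_indices = np.argsort(probs)[::-1][:k]
    return [(class_names[i], float(probs[i])) for i in top_indices]


# Toy example with a hypothetical 4-class vocabulary.
probs = np.array([0.05, 0.7, 0.2, 0.05])
names = ["bath_towel", "tabby_cat", "tiger_cat", "doormat"]
print(top_k_classes(probs, names, k=2))
```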

"""
Great! Both of these appear to be correct!
However, one of the classes is "Bath towel".
We're trying to classify cats vs dogs.
We don't care about the towel!

Ideally, we'd have a classifier that only performs computation to determine if
an image is a cat or a dog, and has all of its resources dedicated to this task.
This can be solved by fine-tuning our own classifier.

## Fine-tuning a pretrained classifier

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_intermediate.png)

When labeled images specific to our task are available, fine-tuning a custom
classifier can improve performance.
If we want to train a Cats vs Dogs classifier, using explicitly labeled Cats vs
Dogs data should perform better than using the generic classifier!
For many tasks, no relevant pretrained model
will be available (e.g., categorizing images specific to your application).

First, let's get started by loading some data:
"""

BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
AUTOTUNE = tf.data.AUTOTUNE
tfds.disable_progress_bar()

data, dataset_info = tfds.load("cats_vs_dogs", with_info=True, as_supervised=True)
train_steps_per_epoch = dataset_info.splits["train"].num_examples // BATCH_SIZE
train_dataset = data["train"]

num_classes = dataset_info.features["label"].num_classes

resizing = keras.layers.Resizing(
    IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True
)


def preprocess_inputs(image, label):
    image = tf.cast(image, tf.float32)
    # Statically resize images, as we only iterate over the dataset once.
    return resizing(image), tf.one_hot(label, num_classes)


# Shuffle the dataset to increase the diversity of batches.
# 10*BATCH_SIZE follows the assumption that bigger machines can handle bigger
# shuffle buffers.
train_dataset = train_dataset.shuffle(
    10 * BATCH_SIZE, reshuffle_each_iteration=True
).map(preprocess_inputs, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)

images = next(iter(train_dataset.take(1)))[0]
plot_image_gallery(images)

"""
Meow!

Next, let's construct our model.
The use of `imagenet` in the preset name indicates that the backbone was
pretrained on the ImageNet dataset.
Pretrained backbones extract more information from our labeled examples by
leveraging patterns extracted from potentially much larger datasets.

Next, let's put together our classifier:
"""

model = keras_hub.models.ImageClassifier.from_preset(
    "resnet_v2_50_imagenet", num_classes=2
)
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    metrics=["accuracy"],
)

"""
Here our classifier is just a regular `keras.Model` (an `ImageClassifier` task).
All that is left to do is call `model.fit()`:
"""

model.fit(train_dataset)

"""
Let's look at how our model performs after fine-tuning:
"""

predictions = model.predict(image)

classes = {0: "cat", 1: "dog"}
print("Top class is:", classes[predictions[0].argmax()])

"""
Awesome - looks like the model correctly classified the image.
"""

"""
## Train a Classifier from Scratch

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_advanced.png)

Now that we've gotten our hands dirty with classification, let's take on one
last task: training a classification model from scratch!
A standard benchmark for image classification is the ImageNet dataset; however,
due to licensing constraints, we will use the Caltech 101 image classification
dataset in this tutorial.
While we use the simpler Caltech 101 dataset in this guide, the same training
template may be used on ImageNet to achieve near state-of-the-art scores.

Let's start out by tackling data loading:
"""

BATCH_SIZE = 32
NUM_CLASSES = 101
IMAGE_SIZE = (224, 224)

# Change epochs to ~100 to fully train.
EPOCHS = 1


def package_inputs(image, label):
    return {"images": image, "labels": tf.one_hot(label, NUM_CLASSES)}


train_ds, eval_ds = tfds.load(
    "caltech101", split=["train", "test"], as_supervised=True
)
train_ds = train_ds.map(package_inputs, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = eval_ds.map(package_inputs, num_parallel_calls=tf.data.AUTOTUNE)

train_ds = train_ds.shuffle(BATCH_SIZE * 16)
augmenters = []

"""
The Caltech 101 dataset has different sizes for every image, so we resize images
before batching them with the `batch()` API.
"""

resize = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)
train_ds = train_ds.map(resize)
eval_ds = eval_ds.map(resize)

train_ds = train_ds.batch(BATCH_SIZE)
eval_ds = eval_ds.batch(BATCH_SIZE)

batch = next(iter(train_ds.take(1)))
image_batch = batch["images"]
label_batch = batch["labels"]

plot_image_gallery(
    image_batch,
)
"""
### Data Augmentation

In our previous fine-tuning example, we performed a static resizing operation and
did not utilize any image augmentation.
This is because a single pass over the training set was sufficient to achieve
decent results.
When training to solve a more difficult task, you'll want to include data
augmentation in your data pipeline.

Data augmentation is a technique to make your model robust to changes in input
data such as lighting, cropping, and orientation.
Keras includes some of the most useful augmentations in the `keras.layers`
API.
Creating an optimal pipeline of augmentations is an art, but in this section of
the guide we'll offer some tips on best practices for classification.

One caveat to be aware of with image data augmentation is that you must be
careful not to shift your augmented data distribution too far from the original
data distribution.
The goal is to prevent overfitting and increase generalization,
but samples that lie completely out of the data distribution simply add noise to
the training process.

The first augmentation we'll use is `RandomFlip`.
This augmentation behaves more or less how you'd expect: it either flips the
image or not.
While this augmentation is useful on Caltech 101 and ImageNet, it should be noted
that it should not be used on tasks where the data distribution is not
mirror-invariant.
An example of a dataset where this occurs is MNIST handwritten digits.
Flipping a `6` over the
vertical axis will make the digit appear more like a `7` than a `6`, but the
label will still show a `6`.
"""

random_flip = keras.layers.RandomFlip()
augmenters += [random_flip]

image_batch = random_flip(image_batch)
plot_image_gallery(image_batch)

"""
Half of the images have been flipped!

The next augmentation we'll use is `RandomCrop`.
This operation selects a random subset of the image.
By using this augmentation, we force our classifier to become spatially invariant.

Let's add a `RandomCrop` to our set of augmentations:
"""

crop = keras.layers.RandomCrop(
    int(IMAGE_SIZE[0] * 0.9),
    int(IMAGE_SIZE[1] * 0.9),
)

augmenters += [crop]

image_batch = crop(image_batch)
plot_image_gallery(
    image_batch,
)

"""
We can also rotate images by a random angle using Keras' `RandomRotation` layer.
Let's apply a rotation by a randomly selected angle in the interval
[-45°, 45°]. The factor is expressed as a fraction of a full rotation, so
45° corresponds to `45 / 360`:
"""

rotate = keras.layers.RandomRotation((-45 / 360, 45 / 360))

augmenters += [rotate]

image_batch = rotate(image_batch)
plot_image_gallery(image_batch)

resize = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)
augmenters += [resize]

image_batch = resize(image_batch)
plot_image_gallery(image_batch)

"""
Now let's compose our augmenters into a single function and apply it to the
training data:
"""


def create_augmenter_fn(augmenters):
    def augmenter_fn(inputs):
        for augmenter in augmenters:
            inputs["images"] = augmenter(inputs["images"])
        return inputs

    return augmenter_fn


augmenter_fn = create_augmenter_fn(augmenters)
train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf.data.AUTOTUNE)

image_batch = next(iter(train_ds.take(1)))["images"]
plot_image_gallery(
    image_batch,
)

"""
We also need to resize our evaluation set to get dense batches of the image size
expected by our model. We directly use the deterministic `keras.layers.Resizing`
in this case to avoid adding noise to our evaluation metric due to applying
random augmentations.
"""

inference_resizing = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)


def do_resize(inputs):
    inputs["images"] = inference_resizing(inputs["images"])
    return inputs


eval_ds = eval_ds.map(do_resize, num_parallel_calls=tf.data.AUTOTUNE)

image_batch = next(iter(eval_ds.take(1)))["images"]
plot_image_gallery(
    image_batch,
)

"""
Finally, let's unpackage our datasets and prepare to pass them to `model.fit()`,
which accepts a tuple of `(images, labels)`.
"""


def unpackage_dict(inputs):
    return inputs["images"], inputs["labels"]


train_ds = train_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = eval_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)

"""
Data augmentation is by far the hardest piece of training a modern
classifier.
Congratulations on making it this far!

### Optimizer Tuning

To achieve optimal performance, we need to use a learning rate schedule instead
of a single learning rate. While we won't go into detail on the cosine decay
with warmup schedule used here,
[you can read more about it here](https://scorrea92.medium.com/cosine-learning-rate-decay-e8b50aa455b).
"""


def lr_warmup_cosine_decay(
    global_step,
    warmup_steps,
    hold=0,
    total_steps=0,
    start_lr=0.0,
    target_lr=1e-2,
):
    # Cosine decay
    learning_rate = (
        0.5
        * target_lr
        * (
            1
            + ops.cos(
                math.pi
                * ops.convert_to_tensor(
                    global_step - warmup_steps - hold, dtype="float32"
                )
                / ops.convert_to_tensor(
                    total_steps - warmup_steps - hold, dtype="float32"
                )
            )
        )
    )

    warmup_lr = target_lr * (global_step / warmup_steps)

    if hold > 0:
        learning_rate = ops.where(
            global_step > warmup_steps + hold, learning_rate, target_lr
        )

    learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate


class WarmUpCosineDecay(schedules.LearningRateSchedule):
    def __init__(self, warmup_steps, total_steps, hold, start_lr=0.0, target_lr=1e-2):
        super().__init__()
        self.start_lr = start_lr
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.hold = hold

    def __call__(self, step):
        lr = lr_warmup_cosine_decay(
            global_step=step,
            total_steps=self.total_steps,
            warmup_steps=self.warmup_steps,
            start_lr=self.start_lr,
            target_lr=self.target_lr,
            hold=self.hold,
        )
        return ops.where(step > self.total_steps, 0.0, lr)
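
"""
As a quick sanity check of the schedule's shape, here is a plain-Python
version of the same piecewise formula (a sketch that mirrors the logic above,
independent of the Keras ops, with illustrative step counts):
"""

```python
import math


def warmup_cosine_py(step, warmup_steps, hold, total_steps, target_lr=1e-2):
    # Linear warmup, optional hold at the peak, then cosine decay to zero.
    if step < warmup_steps:
        return target_lr * step / warmup_steps
    if step <= warmup_steps + hold:
        return target_lr
    progress = (step - warmup_steps - hold) / (total_steps - warmup_steps - hold)
    return 0.5 * target_lr * (1 + math.cos(math.pi * progress))


total, warmup, hold = 1000, 100, 450
lrs = [warmup_cosine_py(s, warmup, hold, total) for s in range(total + 1)]
print(max(lrs), lrs[-1])  # peaks at target_lr, decays to ~0
```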

"""
![WarmUpCosineDecay schedule](https://i.imgur.com/YCr5pII.png)

The schedule looks as we expect.

Next, let's construct this optimizer:
"""

total_images = 9000
total_steps = (total_images // BATCH_SIZE) * EPOCHS
warmup_steps = int(0.1 * total_steps)
hold_steps = int(0.45 * total_steps)
schedule = WarmUpCosineDecay(
    start_lr=0.05,
    target_lr=1e-2,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    hold=hold_steps,
)
optimizer = optimizers.SGD(
    weight_decay=5e-4,
    learning_rate=schedule,
    momentum=0.9,
)

"""
At long last, we can now build our model and call `fit()`!
Here, we directly instantiate our `ResNetBackbone`, specifying all architectural
parameters, which gives us full control to tweak the architecture.
"""

backbone = keras_hub.models.ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
)
model = keras.Sequential(
    [
        backbone,
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(101, activation="softmax"),
    ]
)

"""
We employ label smoothing to prevent the model from overfitting to artifacts of
our augmentation process.
"""

loss = losses.CategoricalCrossentropy(label_smoothing=0.1)
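
"""
Concretely, label smoothing with factor ε replaces a one-hot target `y` with
`y * (1 - ε) + ε / num_classes`, so the model is never asked to output exactly
0 or 1. A small NumPy illustration (using a toy 4-class target rather than our
101 classes):
"""

```python
import numpy as np


def smooth_labels(one_hot, epsilon=0.1):
    # Redistribute epsilon of the probability mass uniformly over all classes.
    num_classes = one_hot.shape[-1]
    return one_hot * (1.0 - epsilon) + epsilon / num_classes


y = np.array([0.0, 1.0, 0.0, 0.0])
print(smooth_labels(y))  # [0.025 0.925 0.025 0.025]
```

Note that the smoothed target still sums to 1, so it remains a valid
probability distribution.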

"""
Let's compile our model:
"""

model.compile(
    loss=loss,
    optimizer=optimizer,
    metrics=[
        metrics.CategoricalAccuracy(),
        metrics.TopKCategoricalAccuracy(k=5),
    ],
)

"""
And finally, we call `fit()`:
"""

model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=eval_ds,
)

"""
Congratulations! You now know how to train a powerful image classifier from
scratch using KerasHub.
Depending on the availability of labeled data for your application, training
from scratch may or may not be more powerful than using transfer learning in
addition to the data augmentations discussed above. For smaller datasets,
pretrained models generally produce higher accuracy and faster convergence.
"""

"""
## Conclusions

While image classification is perhaps the simplest problem in computer vision,
the modern landscape has numerous complex components.
Luckily, KerasHub offers robust, production-grade APIs that make assembling most
of these components possible in one line of code.
Through the use of KerasHub's `ImageClassifier` API, pretrained weights, and
Keras' data augmentations, you can assemble everything you need to train a
powerful classifier in a few hundred lines of code!

As a follow-up exercise, try fine-tuning a KerasHub classifier on your own dataset!
"""
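
"""
As a starting point for that exercise, here is a hedged sketch of fine-tuning
on a local image folder. The `"path/to/your_dataset"` directory is a
placeholder you would replace with your own data, arranged as one subfolder
per class; the preset and hyperparameters simply reuse those from the
fine-tuning section above.
"""

```python
import keras
import keras_hub

# Hypothetical directory of images, one subfolder per class.
train_data = keras.utils.image_dataset_from_directory(
    "path/to/your_dataset",  # placeholder: replace with your own folder
    image_size=(224, 224),
    batch_size=32,
    label_mode="categorical",
)

model = keras_hub.models.ImageClassifier.from_preset(
    "resnet_v2_50_imagenet",
    num_classes=len(train_data.class_names),
)
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    metrics=["accuracy"],
)
model.fit(train_data, epochs=3)
```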