"""
Title: Image Classification with KerasHub
Author: [Gowtham Paimagam](https://github.com/gowthamkpr), [lukewood](https://lukewood.xyz)
Date created: 09/24/2024
Last modified: 10/04/2024
Description: Use KerasHub to train powerful image classifiers.
Accelerator: GPU
"""

"""
Classification is the process of predicting a categorical label for a given
input image.
While classification is a relatively straightforward computer vision task,
modern approaches are still built from several complex components.
Luckily, Keras provides APIs to construct commonly used components.

This guide demonstrates KerasHub's modular approach to solving image
classification problems at three levels of complexity:

- Inference with a pretrained classifier
- Fine-tuning a pretrained backbone
- Training an image classifier from scratch

KerasHub uses Keras 3 to work with any of TensorFlow, PyTorch or JAX. In the
guide below, we will use the `jax` backend. This guide runs on the
TensorFlow or PyTorch backends with zero changes; simply update the
`KERAS_BACKEND` below.

We use Professor Keras, the official Keras mascot, as a
visual reference for the complexity of the material:

"""

"""shell
!pip install -q git+https://github.com/keras-team/keras-hub.git
!pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

import math
import numpy as np
import matplotlib.pyplot as plt

import keras
from keras import losses
from keras import ops
from keras import optimizers
from keras.optimizers import schedules
from keras import metrics
from keras.applications.imagenet_utils import decode_predictions
import keras_hub

# Import tensorflow for `tf.data` and its preprocessing functions
import tensorflow as tf
import tensorflow_datasets as tfds


"""
## Inference with a pretrained classifier

Let's get started with the simplest KerasHub API: a pretrained classifier.
In this example, we will construct a classifier that was
pretrained on the ImageNet dataset.
We'll use this model to solve the age-old "Cat or Dog" problem.

The highest-level module in KerasHub is a *task*. A *task* is a `keras.Model`
consisting of a (generally pretrained) backbone model and task-specific layers.
Here's an example using `keras_hub.models.ImageClassifier` with a
ResNet backbone.

ResNet is a great starting model when constructing an image
classification pipeline.
This architecture achieves high accuracy while using a
compact parameter count.
If a ResNet is not powerful enough for the task you are hoping to
solve, be sure to check out
[KerasHub's other available Backbones](https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/models)!
"""

classifier = keras_hub.models.ImageClassifier.from_preset("resnet_v2_50_imagenet")

"""
You may notice a small deviation from the old `keras.applications` API, where
you would construct the class with `ResNet50V2(weights="imagenet")`.
While the old API was great for classification, it did not scale effectively to
other use cases that required complex architectures, like object detection and
semantic segmentation.

We first create a utility function for plotting images throughout this tutorial:
"""


def plot_image_gallery(images, titles=None, num_cols=3, figsize=(6, 12)):
    num_images = len(images)
    # Scale pixel values from [0, 255] to [0, 1] and clip them for `imshow`.
    images = np.asarray(images) / 255.0
    images = np.minimum(np.maximum(images, 0.0), 1.0)
    num_rows = (num_images + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()  # `squeeze=False` always returns a 2D array; flatten it

    for i, ax in enumerate(axes):
        if i < num_images:
            # Plot the image
            ax.imshow(images[i])
            ax.axis("off")  # Remove axis
            if titles and len(titles) > i:
                ax.set_title(titles[i], fontsize=12)
        else:
            # Turn off the axis for any empty subplot
            ax.axis("off")

    plt.show()
    plt.close()


"""
Now that our classifier is built, let's apply it to this cute cat picture!
"""

filepath = keras.utils.get_file(
    origin="https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/5hR96puA_VA.jpg/1024px-5hR96puA_VA.jpg"
)
image = keras.utils.load_img(filepath)
image = np.array([image])
plot_image_gallery(image, num_cols=1, figsize=(3, 3))

"""
Next, let's get some predictions from our classifier:
"""

predictions = classifier.predict(image)

"""
Predictions come in the form of softmax-ed probabilities over the ImageNet
classes.
We can use Keras' `imagenet_utils.decode_predictions` function to map
them to class names:
"""

print(f"Top two classes are:\n{decode_predictions(predictions, top=2)}")

"""
Great! Both of these appear to be correct!
However, one of the classes is "Bath towel".
We're trying to classify Cats vs Dogs.
We don't care about the towel!

Ideally, we'd have a classifier that only performs computation to determine if
an image is a cat or a dog, and has all of its resources dedicated to this task.
This can be solved by fine-tuning our own classifier.

## Fine-tuning a pretrained classifier

When labeled images specific to our task are available, fine-tuning a custom
classifier can improve performance.
If we want to train a Cats vs Dogs classifier, using explicitly labeled Cat vs
Dog data should perform better than the generic classifier!
For many tasks, no relevant pretrained model
will be available (e.g., categorizing images specific to your application).

First, let's get started by loading some data:
"""

BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
AUTOTUNE = tf.data.AUTOTUNE
tfds.disable_progress_bar()

data, dataset_info = tfds.load("cats_vs_dogs", with_info=True, as_supervised=True)
train_steps_per_epoch = dataset_info.splits["train"].num_examples // BATCH_SIZE
train_dataset = data["train"]

num_classes = dataset_info.features["label"].num_classes

resizing = keras.layers.Resizing(
    IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True
)


def preprocess_inputs(image, label):
    image = tf.cast(image, tf.float32)
    # Statically resize images as we only iterate the dataset once.
    return resizing(image), tf.one_hot(label, num_classes)
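

"""
`Resizing(..., crop_to_aspect_ratio=True)` avoids stretching images whose aspect
ratio differs from our square 224x224 target: it crops before it resizes. The
plain-Python sketch below computes such a crop window. It assumes the window is
the largest one with the target aspect ratio (as the Keras documentation
describes) and that it is centered in the image; `center_crop_box` is a
hypothetical helper for illustration, not a Keras API.

```python
def center_crop_box(height, width, target_height, target_width):
    # Hypothetical helper: returns (top, left, crop_height, crop_width) of the
    # largest centered window whose aspect ratio matches the target size.
    # Integer arithmetic avoids floating-point rounding in the crop size.
    crop_height = min(height, width * target_height // target_width)
    crop_width = min(width, height * target_width // target_height)
    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    return top, left, crop_height, crop_width


# A 300x500 landscape photo with a square target keeps a centered 300x300
# window, dropping 100 columns on each side.
print(center_crop_box(300, 500, 224, 224))  # (0, 100, 300, 300)
```

The window is then resized to `IMAGE_SIZE`. With the default
`crop_to_aspect_ratio=False`, the whole image would be resized instead, so its
aspect ratio may not be preserved.
"""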


# Shuffle the dataset to increase diversity of batches.
# 10*BATCH_SIZE follows the assumption that bigger machines can handle bigger
# shuffle buffers.
train_dataset = train_dataset.shuffle(
    10 * BATCH_SIZE, reshuffle_each_iteration=True
).map(preprocess_inputs, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)

images = next(iter(train_dataset.take(1)))[0]
plot_image_gallery(images)

"""
Meow!

Next, let's construct our model.
The `imagenet` in the preset name indicates that the backbone was
pretrained on the ImageNet dataset.
Pretrained backbones extract more information from our labeled examples by
leveraging patterns extracted from potentially much larger datasets.

We load the same preset as before, but pass `num_classes=2` so that the
classifier predicts our two labels, cat and dog:
"""

model = keras_hub.models.ImageClassifier.from_preset(
    "resnet_v2_50_imagenet", num_classes=2
)
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    metrics=["accuracy"],
)

"""
Our classifier is a regular `keras.Model`, so all that is left to do is call
`model.fit()`:
"""

model.fit(train_dataset)


"""
Let's look at how our model performs after fine-tuning:
"""

predictions = model.predict(image)

classes = {0: "cat", 1: "dog"}
print("Top class is:", classes[predictions[0].argmax()])

"""
Awesome - looks like the model correctly classified the image.
"""
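
"""
Both `decode_predictions` and the `argmax` lookup above reduce a vector of class
probabilities to the most likely labels. The NumPy sketch below shows that top-k
step on its own; `top_k_labels`, the three class names and the probabilities
are hypothetical. The real `decode_predictions` additionally maps indices to
ImageNet class IDs and descriptions.

```python
import numpy as np


def top_k_labels(probabilities, class_names, k=2):
    # Indices of the k largest probabilities, highest first.
    top_indices = np.argsort(probabilities)[::-1][:k]
    return [(class_names[i], float(probabilities[i])) for i in top_indices]


# Hypothetical softmax output for one image over three classes.
probabilities = np.array([0.7, 0.2, 0.1])
class_names = ["cat", "dog", "bath towel"]
print(top_k_labels(probabilities, class_names, k=2))  # [('cat', 0.7), ('dog', 0.2)]
```

With `k=1`, this reduces to the `classes[predictions[0].argmax()]` lookup used
above.
"""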
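
"""
## Train a Classifier from Scratch

Now that we've gotten our hands dirty with classification, let's take on one
last task: training a classification model from scratch!
A standard benchmark for image classification is the ImageNet dataset; however,
due to licensing constraints we will use the CalTech 101 image classification
dataset in this tutorial.
While we use the simpler CalTech 101 dataset in this guide, the same training
template may be used on ImageNet to achieve near state-of-the-art scores.

Let's start out by tackling data loading:
"""

BATCH_SIZE = 32
NUM_CLASSES = 101
IMAGE_SIZE = (224, 224)

# Change epochs to ~100 to fully train.
EPOCHS = 1


def package_inputs(image, label):
    return {"images": image, "labels": tf.one_hot(label, NUM_CLASSES)}


train_ds, eval_ds = tfds.load(
    "caltech101", split=["train", "test"], as_supervised=True
)
train_ds = train_ds.map(package_inputs, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = eval_ds.map(package_inputs, num_parallel_calls=tf.data.AUTOTUNE)

train_ds = train_ds.shuffle(BATCH_SIZE * 16)
augmenters = []

"""
Images in the CalTech101 dataset come in different sizes, so we resize them
before batching them with the `batch()` API.
"""

resize = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)
train_ds = train_ds.map(resize)
eval_ds = eval_ds.map(resize)

train_ds = train_ds.batch(BATCH_SIZE)
eval_ds = eval_ds.batch(BATCH_SIZE)

batch = next(iter(train_ds.take(1)))
image_batch = batch["images"]
label_batch = batch["labels"]

plot_image_gallery(
    image_batch,
)

"""
### Data Augmentation

In our previous fine-tuning example, we performed a static resizing operation and
did not utilize any image augmentation.
This is because a single pass over the training set was sufficient to achieve
decent results.
When training to solve a more difficult task, you'll want to include data
augmentation in your data pipeline.

Data augmentation is a technique to make your model robust to changes in input
data such as lighting, cropping, and orientation.
Keras includes some of the most useful augmentations in the `keras.layers`
API.
Creating an optimal pipeline of augmentations is an art, but in this section of
the guide we'll offer some tips on best practices for classification.

One caveat to be aware of with image data augmentation is that you must be careful
not to shift your augmented data distribution too far from the original data
distribution.
The goal is to prevent overfitting and increase generalization,
but samples that lie completely outside the data distribution simply add noise to
the training process.

The first augmentation we'll use is `RandomFlip`.
With `mode="horizontal"`, this augmentation behaves more or less how you'd
expect: it either mirrors the image left-to-right or leaves it unchanged.
While this augmentation is useful on CalTech101 and ImageNet, it should not be
used on tasks where the data distribution is not invariant to mirroring across
the vertical axis.
An example of a dataset where this occurs is MNIST handwritten digits.
Flipping a `6` over the
vertical axis will make the digit appear more like a `7` than a `6`, but the
label will still show a `6`.
"""

random_flip = keras.layers.RandomFlip(mode="horizontal")
augmenters += [random_flip]

image_batch = random_flip(image_batch)
plot_image_gallery(image_batch)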
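
"""
Mirroring across the vertical axis simply reverses the width axis of an image.
The NumPy sketch below imitates a per-image random horizontal flip on a
`(batch, height, width, channels)` array. `random_horizontal_flip` is a
hypothetical helper for illustration; it is not how `keras.layers.RandomFlip`
is implemented internally.

```python
import numpy as np


def random_horizontal_flip(images, rng, probability=0.5):
    # Decide independently for each image whether to mirror it.
    flip_mask = rng.random(images.shape[0]) < probability
    # Reversing the width axis (axis 2) mirrors each image left-to-right,
    # i.e. across the vertical axis.
    flipped = images[:, :, ::-1, :]
    # Broadcast the per-image mask over height, width and channels.
    return np.where(flip_mask[:, None, None, None], flipped, images), flip_mask


rng = np.random.default_rng(seed=0)
batch = rng.random((8, 4, 6, 3))  # 8 tiny random images of shape 4x6x3
augmented, flip_mask = random_horizontal_flip(batch, rng)

# Mirrored images match the reversed originals; the rest are untouched.
assert np.array_equal(augmented[flip_mask], batch[flip_mask][:, :, ::-1, :])
assert np.array_equal(augmented[~flip_mask], batch[~flip_mask])
```

With `probability=0.5`, about half of the images in a batch end up mirrored.
"""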

"""
Roughly half of the images have been flipped!

The next augmentation we'll use is `RandomCrop`.
This operation selects a random window of the given size (here, 90% of each
side) from the image.
By using this augmentation, we encourage our classifier to become spatially
invariant, i.e. robust to where the object appears in the frame.

Let's add a `RandomCrop` to our set of augmentations:
"""

crop = keras.layers.RandomCrop(
    int(IMAGE_SIZE[0] * 0.9),
    int(IMAGE_SIZE[1] * 0.9),
)

augmenters += [crop]

image_batch = crop(image_batch)
plot_image_gallery(
    image_batch,
)

"""
We can also rotate images by a random angle using Keras' `RandomRotation` layer.
Let's apply a rotation by a randomly selected angle in the interval -45°...45°:
"""

rotate = keras.layers.RandomRotation((-45 / 360, 45 / 360))

augmenters += [rotate]

image_batch = rotate(image_batch)
plot_image_gallery(image_batch)
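
"""
`RandomRotation` expresses its range as a fraction of a full turn (2π radians),
which is why the code above divides 45 degrees by 360. The helpers below,
hypothetical convenience functions rather than part of Keras, convert a maximum
angle in degrees into such a `factor` tuple and a fraction back into radians.

```python
import math


def rotation_factor_from_degrees(max_degrees):
    # RandomRotation's `factor` is a fraction of a full turn (2 * pi radians),
    # so an angle in degrees is divided by 360.
    return (-max_degrees / 360.0, max_degrees / 360.0)


def factor_to_radians(fraction):
    # Convert a fraction of a full turn back into radians.
    return fraction * 2.0 * math.pi


lower, upper = rotation_factor_from_degrees(45)
print(lower, upper)  # -0.125 0.125
print(math.isclose(factor_to_radians(upper), math.pi / 4))  # True
```

Under this convention, `keras.layers.RandomRotation(rotation_factor_from_degrees(30))`
would restrict rotations to the interval -30°...30°.
"""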

"""
Since `RandomCrop` shrank our images to 90% of `IMAGE_SIZE`, we resize them back
to the size expected by our model:
"""

resize = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)
augmenters += [resize]

image_batch = resize(image_batch)
plot_image_gallery(image_batch)

"""
Now let's apply our final augmenter to the training data:
"""


def create_augmenter_fn(augmenters):
    def augmenter_fn(inputs):
        # Apply each augmentation layer, in order, to the image batch.
        for augmenter in augmenters:
            inputs["images"] = augmenter(inputs["images"])
        return inputs

    return augmenter_fn


augmenter_fn = create_augmenter_fn(augmenters)
train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf.data.AUTOTUNE)

image_batch = next(iter(train_ds.take(1)))["images"]
plot_image_gallery(
    image_batch,
)

"""
We also need to resize our evaluation set to get dense batches of the image size
expected by our model.
We directly use the deterministic `keras.layers.Resizing` in
this case to avoid adding noise to our evaluation metric due to applying random
augmentations.
"""

inference_resizing = keras.layers.Resizing(*IMAGE_SIZE, crop_to_aspect_ratio=True)


def do_resize(inputs):
    inputs["images"] = inference_resizing(inputs["images"])
    return inputs


eval_ds = eval_ds.map(do_resize, num_parallel_calls=tf.data.AUTOTUNE)

image_batch = next(iter(eval_ds.take(1)))["images"]
plot_image_gallery(
    image_batch,
)

"""
Finally, let's unpack our datasets and prepare to pass them to `model.fit()`,
which expects each dataset element to be an `(images, labels)` tuple.
"""


def unpackage_dict(inputs):
    return inputs["images"], inputs["labels"]


train_ds = train_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)
eval_ds = eval_ds.map(unpackage_dict, num_parallel_calls=tf.data.AUTOTUNE)

"""
Data augmentation is often the hardest piece of training a modern
classifier.
Congratulations on making it this far!

### Optimizer Tuning

To achieve optimal performance, we need to use a learning rate schedule instead
of a single learning rate. While we won't go into detail on the cosine decay
with warmup schedule used here,
[you can read more about it here](https://scorrea92.medium.com/cosine-learning-rate-decay-e8b50aa455b).
"""
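
"""
The `lr_warmup_cosine_decay` function and `WarmUpCosineDecay` class that follow
implement a piecewise rule: linear warmup, a constant hold, then cosine decay to
zero. As a reading aid, here is a hypothetical plain-Python restatement of that
rule for a single integer step, without Keras ops; `warmup_hold_cosine_lr` is
not part of the guide's code. Like the original function, it ignores `start_lr`
and warms up from zero.

```python
import math


def warmup_hold_cosine_lr(step, warmup_steps, hold, total_steps, target_lr=1e-2):
    # After total_steps the learning rate is zero.
    if step > total_steps:
        return 0.0
    # Linear warmup from 0 to target_lr.
    if step < warmup_steps:
        return target_lr * step / warmup_steps
    # Plateau at target_lr for `hold` steps after warmup.
    if step <= warmup_steps + hold:
        return target_lr
    # Cosine decay from target_lr at the end of the hold to 0 at total_steps.
    progress = (step - warmup_steps - hold) / (total_steps - warmup_steps - hold)
    return 0.5 * target_lr * (1 + math.cos(math.pi * progress))


# Illustrative schedule: 10 warmup steps, a 20-step hold, 100 steps in total.
for step in [0, 5, 10, 30, 65, 100, 101]:
    print(step, round(warmup_hold_cosine_lr(step, 10, 20, 100), 6))
```

Halfway through the decay phase (step 65 here), the cosine term is zero and the
learning rate is exactly half of `target_lr`.
"""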


def lr_warmup_cosine_decay(
    global_step,
    warmup_steps,
    hold=0,
    total_steps=0,
    start_lr=0.0,
    target_lr=1e-2,
):
    # Note: `start_lr` is accepted but not used; the linear warmup below always
    # ramps from 0 to `target_lr`.
    # Cosine decay
    learning_rate = (
        0.5
        * target_lr
        * (
            1
            + ops.cos(
                math.pi
                * ops.convert_to_tensor(
                    global_step - warmup_steps - hold, dtype="float32"
                )
                / ops.convert_to_tensor(
                    total_steps - warmup_steps - hold, dtype="float32"
                )
            )
        )
    )

    # Linear warmup from 0 to `target_lr`
    warmup_lr = target_lr * (global_step / warmup_steps)

    # Hold `target_lr` constant for `hold` steps after the warmup
    if hold > 0:
        learning_rate = ops.where(
            global_step > warmup_steps + hold, learning_rate, target_lr
        )

    learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate


class WarmUpCosineDecay(schedules.LearningRateSchedule):
    def __init__(self, warmup_steps, total_steps, hold, start_lr=0.0, target_lr=1e-2):
        super().__init__()
        self.start_lr = start_lr
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.hold = hold

    def __call__(self, step):
        lr = lr_warmup_cosine_decay(
            global_step=step,
            total_steps=self.total_steps,
            warmup_steps=self.warmup_steps,
            start_lr=self.start_lr,
            target_lr=self.target_lr,
            hold=self.hold,
        )
        # Past `total_steps`, the learning rate drops to zero.
        return ops.where(step > self.total_steps, 0.0, lr)


"""
The schedule ramps the learning rate up linearly during warmup, holds it at
`target_lr`, and then decays it to zero along a cosine curve.

Next, let's construct this optimizer:
"""

total_images = 9000
total_steps = (total_images // BATCH_SIZE) * EPOCHS
warmup_steps = int(0.1 * total_steps)
hold_steps = int(0.45 * total_steps)
schedule = WarmUpCosineDecay(
    start_lr=0.05,
    target_lr=1e-2,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    hold=hold_steps,
)
optimizer = optimizers.SGD(
    weight_decay=5e-4,
    learning_rate=schedule,
    momentum=0.9,
)
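
"""
For the settings above, the phase boundaries of the schedule can be checked by
hand. This plain-Python sketch repeats the arithmetic of the preceding cell;
`schedule_phases` is a hypothetical helper introduced only for illustration.

```python
def schedule_phases(total_images, batch_size, epochs):
    # Same arithmetic as the optimizer setup above.
    total_steps = (total_images // batch_size) * epochs
    warmup_steps = int(0.1 * total_steps)
    hold_steps = int(0.45 * total_steps)
    return {
        "total_steps": total_steps,
        "warmup_ends": warmup_steps,
        "decay_starts": warmup_steps + hold_steps,
    }


print(schedule_phases(total_images=9000, batch_size=32, epochs=1))
# {'total_steps': 281, 'warmup_ends': 28, 'decay_starts': 154}
```

With `EPOCHS = 1`, the learning rate therefore warms up over the first 28 steps,
stays at `target_lr` until step 154 and reaches zero at step 281. Because
`total_images` is a fixed constant rather than being read from the dataset, it
is worth checking that it matches the number of training images you actually
use; otherwise the schedule will not line up with the real number of steps.
"""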

"""
At long last, we can now build our model and call `fit()`!
Here, we directly instantiate our `ResNetBackbone`, specifying all architectural
parameters, which gives us full control to tweak the architecture.
"""

backbone = keras_hub.models.ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
)
model = keras.Sequential(
    [
        backbone,
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ]
)

"""
We employ label smoothing to prevent the model from overfitting to artifacts of
our augmentation process.
"""

loss = losses.CategoricalCrossentropy(label_smoothing=0.1)
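
"""
Concretely, Keras' `label_smoothing` mixes each one-hot target with a uniform
distribution: with a smoothing factor of 0.1 and K classes, the true class gets
a target of 0.9 + 0.1 / K and every other class gets 0.1 / K. The NumPy sketch
below reproduces that target transformation for our 101 classes; `smooth_labels`
is a hypothetical helper, not a Keras function.

```python
import numpy as np


def smooth_labels(one_hot, smoothing=0.1):
    # Blend the one-hot target with a uniform distribution over all classes.
    num_classes = one_hot.shape[-1]
    return one_hot * (1.0 - smoothing) + smoothing / num_classes


one_hot = np.zeros(101)
one_hot[3] = 1.0  # hypothetical label: class index 3
smoothed = smooth_labels(one_hot)
print(smoothed[3], smoothed[0], smoothed.sum())
```

Because the smoothed target never asks for a probability of exactly 1, the
model is discouraged from becoming overconfident on any single training example.
"""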
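
"""
Let's compile our model:
"""

model.compile(
    loss=loss,
    optimizer=optimizer,
    metrics=[
        metrics.CategoricalAccuracy(),
        metrics.TopKCategoricalAccuracy(k=5),
    ],
)

"""
And finally, call `fit()`:
"""

model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=eval_ds,
)

"""
Congratulations! You now know how to train a powerful image classifier from
scratch using KerasHub.
Depending on the availability of labeled data for your application, training
from scratch may or may not be more powerful than using transfer learning in
addition to the data augmentations discussed above. For smaller datasets,
pretrained models generally produce high accuracy and faster convergence.
"""

"""
## Conclusions

While image classification is perhaps the simplest problem in computer vision,
the modern landscape has numerous complex components.
Luckily, KerasHub offers robust, production-grade APIs to make assembling most
of these components possible in one line of code.
Through the use of KerasHub's `ImageClassifier` API, pretrained weights, and
Keras' data augmentations, you can assemble everything you need to train a
powerful classifier in a few hundred lines of code!

As a follow-up exercise, try fine-tuning a KerasHub classifier on your own dataset!
"""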