Path: blob/master/guides/keras_cv/custom_image_augmentations.py
"""1Title: Custom Image Augmentations with BaseImageAugmentationLayer2Author: [lukewood](https://twitter.com/luke_wood_ml)3Date created: 2022/04/264Last modified: 2023/11/295Description: Use BaseImageAugmentationLayer to implement custom data augmentations.6Accelerator: None7"""89"""10## Overview11Data augmentation is an integral part of training any robust computer vision model.12While KerasCV offers a plethora of prebuild high quality data augmentation techniques,13you may still want to implement your own custom technique.14KerasCV offers a helpful base class for writing data augmentation layers:15`BaseImageAugmentationLayer`.16Any augmentation layer built with `BaseImageAugmentationLayer` will automatically be17compatible with the KerasCV `RandomAugmentationPipeline` class.1819This guide will show you how to implement your own custom augmentation layers using20`BaseImageAugmentationLayer`. As an example, we will implement a layer that tints all21images blue.2223Currently, KerasCV's preprocessing layers only support the TensorFlow backend with Keras 3.24"""2526"""shell27pip install -q --upgrade keras-cv28pip install -q --upgrade keras # Upgrade to Keras 329"""3031import os3233os.environ["KERAS_BACKEND"] = "tensorflow"3435import keras36from keras import ops37from keras import layers38import keras_cv39import matplotlib.pyplot as plt4041"""42First, let's implement some helper functions for visualization and some transformations.43"""444546def imshow(img):47img = img.astype(int)48plt.axis("off")49plt.imshow(img)50plt.show()515253def gallery_show(images):54images = images.astype(int)55for i in range(9):56image = images[i]57plt.subplot(3, 3, i + 1)58plt.imshow(image.astype("uint8"))59plt.axis("off")60plt.show()616263def transform_value_range(images, original_range, target_range):64images = (images - original_range[0]) / (original_range[1] - original_range[0])65scale_factor = target_range[1] - target_range[0]66return (images * scale_factor) + target_range[0]676869def 
parse_factor(param, min_value=0.0, max_value=1.0, seed=None):70if isinstance(param, keras_cv.core.FactorSampler):71return param72if isinstance(param, float) or isinstance(param, int):73param = (min_value, param)74if param[0] == param[1]:75return keras_cv.core.ConstantFactorSampler(param[0])76return keras_cv.core.UniformFactorSampler(param[0], param[1], seed=seed)777879"""80## BaseImageAugmentationLayer Introduction8182Image augmentation should operate on a sample-wise basis; not batch-wise.83This is a common mistake many machine learning practitioners make when implementing84custom techniques.85`BaseImageAugmentation` offers a set of clean abstractions to make implementing image86augmentation techniques on a sample wise basis much easier.87This is done by allowing the end user to override an `augment_image()` method and then88performing automatic vectorization under the hood.8990Most augmentation techniques also must sample from one or more random distributions.91KerasCV offers an abstraction to make random sampling end user configurable: the92`FactorSampler` API.9394Finally, many augmentation techniques requires some information about the pixel values95present in the input images. KerasCV offers the `value_range` API to simplify the handling of this.9697In our example, we will use the `FactorSampler` API, the `value_range` API, and98`BaseImageAugmentationLayer` to implement a robust, configurable, and correct `RandomBlueTint` layer.99100## Overriding `augment_image()`101102Let's start off with the minimum:103"""104105106class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):107def augment_image(self, image, *args, transformation=None, **kwargs):108# image is of shape (height, width, channels)109[*others, blue] = ops.unstack(image, axis=-1)110blue = ops.clip(blue + 100, 0.0, 255.0)111return ops.stack([*others, blue], axis=-1)112113114"""115Our layer overrides `BaseImageAugmentationLayer.augment_image()`. 
This method is116used to augment images given to the layer. By default, using117`BaseImageAugmentationLayer` gives you a few nice features for free:118119- support for unbatched inputs (HWC Tensor)120- support for batched inputs (BHWC Tensor)121- automatic vectorization on batched inputs (more information on this in automatic122vectorization performance)123124Let's check out the result. First, let's download a sample image:125"""126127SIZE = (300, 300)128elephants = keras.utils.get_file(129"african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"130)131elephants = keras.utils.load_img(elephants, target_size=SIZE)132elephants = keras.utils.img_to_array(elephants)133imshow(elephants)134135"""136Next, let's augment it and visualize the result:137"""138139layer = RandomBlueTint()140augmented = layer(elephants)141imshow(ops.convert_to_numpy(augmented))142143"""144Looks great! We can also call our layer on batched inputs:145"""146147layer = RandomBlueTint()148augmented = layer(ops.expand_dims(elephants, axis=0))149imshow(ops.convert_to_numpy(augmented)[0])150151"""152## Adding Random Behavior with the `FactorSampler` API.153154Usually an image augmentation technique should not do the same thing on every155invocation of the layer's `__call__` method.156KerasCV offers the `FactorSampler` API to allow users to provide configurable random157distributions.158"""159160161class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):162"""RandomBlueTint randomly applies a blue tint to images.163164Args:165factor: A tuple of two floats, a single float or a166`keras_cv.FactorSampler`. `factor` controls the extent to which the167image is blue shifted. `factor=0.0` makes this layer perform a no-op168operation, while a value of 1.0 uses the degenerated result entirely.169Values between 0 and 1 result in linear interpolation between the original170image and a fully blue image.171Values should be between `0.0` and `1.0`. 
            If a tuple is used, a `factor` is
            sampled between the two values for every image augmented. If a single float
            is used, a value between `0.0` and the passed float is sampled. In order to
            ensure the value is always the same, please pass a tuple with two identical
            floats: `(0.5, 0.5)`.
    """

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = parse_factor(factor)

    def augment_image(self, image, *args, transformation=None, **kwargs):
        [*others, blue] = ops.unstack(image, axis=-1)
        blue_shift = self.factor() * 255
        blue = ops.clip(blue + blue_shift, 0.0, 255.0)
        return ops.stack([*others, blue], axis=-1)


"""
Now, we can configure the random behavior of our `RandomBlueTint` layer.
We can give it a range of values to sample from:
"""

many_elephants = ops.repeat(ops.expand_dims(elephants, axis=0), 9, axis=0)
layer = RandomBlueTint(factor=0.5)
augmented = layer(many_elephants)
gallery_show(ops.convert_to_numpy(augmented))

"""
Each image is augmented differently with a random factor sampled from the range
`(0, 0.5)`.

We can also configure the layer to draw from a normal distribution:
"""

many_elephants = ops.repeat(ops.expand_dims(elephants, axis=0), 9, axis=0)
factor = keras_cv.core.NormalFactorSampler(
    mean=0.3, stddev=0.1, min_value=0.0, max_value=1.0
)
layer = RandomBlueTint(factor=factor)
augmented = layer(many_elephants)
gallery_show(ops.convert_to_numpy(augmented))

"""
As you can see, the augmentations are now drawn from a normal distribution.
There are various types of `FactorSampler` including `UniformFactorSampler`,
`NormalFactorSampler`, and `ConstantFactorSampler`.
"""
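At its core, a `FactorSampler` is just a callable that returns a freshly sampled scalar on each call. To make that contract concrete, here is a rough plain-Python sketch of the constant and uniform variants (illustrative stand-ins, not the actual KerasCV implementations):

```python
import random


class ConstantSampler:
    """Always returns the same factor (mimics ConstantFactorSampler)."""

    def __init__(self, value):
        self.value = value

    def __call__(self):
        return self.value


class UniformSampler:
    """Returns a factor drawn uniformly from [lower, upper] on each call
    (mimics UniformFactorSampler)."""

    def __init__(self, lower, upper, seed=None):
        self.lower = lower
        self.upper = upper
        self.rng = random.Random(seed)

    def __call__(self):
        return self.rng.uniform(self.lower, self.upper)


constant = ConstantSampler(0.5)
uniform = UniformSampler(0.0, 0.5, seed=42)
samples = [uniform() for _ in range(3)]  # three different values in [0.0, 0.5]
```

Because a layer only ever calls `self.factor()`, anything with this calling convention, including the real `keras_cv.core` samplers, can be plugged in.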
"""
You can also implement your own.

## Overriding `get_random_transformation()`

Now, suppose that your layer impacts the prediction targets: whether they are bounding
boxes, classification labels, or regression targets.
Your layer will need to know what augmentations were applied to the image
when augmenting the label.
Luckily, `BaseImageAugmentationLayer` was designed with this in mind.

To handle this issue, `BaseImageAugmentationLayer` has an overridable
`get_random_transformation()` method alongside `augment_label()`,
`augment_target()` and `augment_bounding_boxes()`.
`augment_segmentation_map()` and others will be added in the future.

Let's add this to our layer.
"""


class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):
    """RandomBlueTint randomly applies a blue tint to images.

    Args:
        factor: A tuple of two floats, a single float or a
            `keras_cv.FactorSampler`. `factor` controls the extent to which the
            image is blue shifted. `factor=0.0` makes this layer perform a no-op
            operation, while a value of 1.0 uses the degenerate result entirely.
            Values between 0 and 1 result in linear interpolation between the original
            image and a fully blue image.
            Values should be between `0.0` and `1.0`. If a tuple is used, a `factor` is
            sampled between the two values for every image augmented. If a single float
            is used, a value between `0.0` and the passed float is sampled.
            In order to
            ensure the value is always the same, please pass a tuple with two identical
            floats: `(0.5, 0.5)`.
    """

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = parse_factor(factor)

    def get_random_transformation(self, **kwargs):
        # kwargs holds {"images": image, "labels": label, etc...}
        return self.factor() * 255

    def augment_image(self, image, transformation=None, **kwargs):
        [*others, blue] = ops.unstack(image, axis=-1)
        blue = ops.clip(blue + transformation, 0.0, 255.0)
        return ops.stack([*others, blue], axis=-1)

    def augment_label(self, label, transformation=None, **kwargs):
        # you can use transformation somehow if you want

        if transformation > 100:
            # i.e. maybe class 2 corresponds to blue images
            return 2.0

        return label

    def augment_bounding_boxes(self, bounding_boxes, transformation=None, **kwargs):
        # you can also perform no-op augmentations on label types to support them in
        # your pipeline.
        return bounding_boxes


"""
To make use of these new methods, you will need to feed your layer a
dictionary mapping images to targets.

As of now, KerasCV supports the following label types:

- labels via `augment_label()`.
- bounding_boxes via `augment_bounding_boxes()`.

In order to use augmentation layers alongside your prediction targets, you must package
your inputs as follows:
"""

labels = ops.array([[1, 0]])
inputs = {"images": ops.convert_to_tensor(elephants), "labels": labels}

"""
Now if we call our layer on the inputs:
"""

layer = RandomBlueTint(factor=(0.6, 0.6))
augmented = layer(inputs)
print(augmented["labels"])

"""
Both the inputs and labels are augmented.
Note how, when `transformation` is > 100, the label is modified to contain `2.0`, as
specified in the layer above.

## `value_range` support

Imagine you are using your new augmentation layer in many pipelines.
Some pipelines have values in the range `[0, 255]`, some pipelines have normalized their
images to the range `[-1, 1]`, and some use a value range of `[0, 1]`.

If a user calls your layer with an image in value range `[0, 1]`, the outputs will be
nonsense!
"""

layer = RandomBlueTint(factor=(0.1, 0.1))
elephants_0_1 = elephants / 255
print("min and max before augmentation:", elephants_0_1.min(), elephants_0_1.max())
augmented = layer(elephants_0_1)
print(
    "min and max after augmentation:",
    ops.convert_to_numpy(augmented).min(),
    ops.convert_to_numpy(augmented).max(),
)
imshow(ops.convert_to_numpy(augmented * 255).astype(int))

"""
Note that this is an incredibly weak augmentation!
Factor is only set to 0.1.

Let's resolve this issue with KerasCV's `value_range` API.
"""


class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):
    """RandomBlueTint randomly applies a blue tint to images.

    Args:
        value_range: a tuple or a list of two elements. The first value
            represents the lower bound for values in passed images, the second
            represents the upper bound. Images passed to the layer should have
            values within `value_range`.
        factor: A tuple of two floats, a single float or a
            `keras_cv.FactorSampler`. `factor` controls the extent to which the
            image is blue shifted. `factor=0.0` makes this layer perform a no-op
            operation, while a value of 1.0 uses the degenerate result entirely.
            Values between 0 and 1 result in linear interpolation between the original
            image and a fully blue image.
            Values should be between `0.0` and `1.0`. If a tuple is used, a `factor` is
            sampled between the two values for every image augmented. If a single float
            is used, a value between `0.0` and the passed float is sampled.
            In order to
            ensure the value is always the same, please pass a tuple with two identical
            floats: `(0.5, 0.5)`.
    """

    def __init__(self, value_range, factor, **kwargs):
        super().__init__(**kwargs)
        self.value_range = value_range
        self.factor = parse_factor(factor)

    def get_random_transformation(self, **kwargs):
        # kwargs holds {"images": image, "labels": label, etc...}
        return self.factor() * 255

    def augment_image(self, image, transformation=None, **kwargs):
        image = transform_value_range(image, self.value_range, (0, 255))
        [*others, blue] = ops.unstack(image, axis=-1)
        blue = ops.clip(blue + transformation, 0.0, 255.0)
        result = ops.stack([*others, blue], axis=-1)
        result = transform_value_range(result, (0, 255), self.value_range)
        return result

    def augment_label(self, label, transformation=None, **kwargs):
        # you can use transformation somehow if you want

        if transformation > 100:
            # i.e. maybe class 2 corresponds to blue images
            return 2.0

        return label

    def augment_bounding_boxes(self, bounding_boxes, transformation=None, **kwargs):
        # you can also perform no-op augmentations on label types to support them in
        # your pipeline.
        return bounding_boxes


layer = RandomBlueTint(value_range=(0, 1), factor=(0.1, 0.1))
elephants_0_1 = elephants / 255
print("min and max before augmentation:", elephants_0_1.min(), elephants_0_1.max())
augmented = layer(elephants_0_1)
print(
    "min and max after augmentation:",
    ops.convert_to_numpy(augmented).min(),
    ops.convert_to_numpy(augmented).max(),
)
imshow(ops.convert_to_numpy(augmented * 255).astype(int))

"""
Now our elephants are only slightly blue tinted. This is the expected behavior when
using a factor of `0.1`. Great!

Now users can configure the layer to support any value range they may need.
"""
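The round trip through `transform_value_range` is what makes this fix work: the layer maps incoming pixels into `[0, 255]`, applies the blue shift there, and then maps back to the user's range. Here is the same arithmetic checked on plain floats (the helper is re-defined only so the snippet is self-contained):

```python
def transform_value_range(images, original_range, target_range):
    # Same arithmetic as the helper defined earlier, applied here to plain
    # floats: rescale to [0, 1], then scale and shift into the target range.
    images = (images - original_range[0]) / (original_range[1] - original_range[0])
    scale_factor = target_range[1] - target_range[0]
    return (images * scale_factor) + target_range[0]


# A mid-gray pixel in a [0, 1] image maps to the middle of [0, 255] ...
x = transform_value_range(0.5, (0, 1), (0, 255))
# ... and mapping back recovers the original value, which is why the layer
# can safely work in [0, 255] internally and restore the user's range at the end.
y = transform_value_range(x, (0, 255), (0, 1))
# The same holds for a [-1, 1] input range.
z = transform_value_range(-1.0, (-1, 1), (0, 255))
```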
"""
Note that only layers that interact with color information should use the value
range API.
Many augmentation techniques, such as `RandomRotation`, will not need this.

## Auto vectorization performance

If you are wondering:

> Does implementing my augmentations on a sample-wise basis carry performance
implications?

You are not alone!

Luckily, I have performed extensive analysis on the performance of automatic
vectorization, manual vectorization, and unvectorized implementations.
In this benchmark, I implemented a `RandomCutout` layer using auto vectorization, no
auto vectorization, and manual vectorization.
All of these were benchmarked inside of an `@tf.function` annotation.
They were also each benchmarked with the `jit_compile` argument.

The following chart shows the results of this benchmark:

_The primary takeaway should be that the difference between manual vectorization and
automatic vectorization is marginal!_

Please note that Eager mode performance will be drastically different.

## Common gotchas

Some layers are not able to be automatically vectorized.
An example of this is [GridMask](https://tinyurl.com/ffb5zzf7).

If you receive an error when invoking your layer, try adding the following to your
constructor:
"""


class UnVectorizable(keras_cv.layers.BaseImageAugmentationLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # this disables BaseImageAugmentationLayer's Auto Vectorization
        self.auto_vectorize = False


"""
Additionally, be sure to accept `**kwargs` in your `augment_*` methods to ensure
forward compatibility.
KerasCV will add additional label types in the future, and
if you do not include a `**kwargs` argument your augmentation layers will not be
forward compatible.

## Conclusion and next steps

KerasCV offers a standard set of APIs to streamline the process of implementing your
own data augmentation techniques.
These include `BaseImageAugmentationLayer`, the `FactorSampler` API and the
`value_range` API.

We used these APIs to implement a highly configurable `RandomBlueTint` layer.
This layer can take inputs as standalone images, a dictionary with keys of `"images"`
and labels, inputs that are unbatched, or inputs that are batched. Inputs may be in any
value range, and the random distribution used to sample the tint values is end-user
configurable.

As follow-up exercises, you can:

- implement your own data augmentation technique using `BaseImageAugmentationLayer`
- [contribute an augmentation layer to KerasCV](https://github.com/keras-team/keras-cv)
- [read through the existing KerasCV augmentation layers](https://tinyurl.com/4txy4m3t)
"""
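To get started on the first exercise, the sample-wise contract can be prototyped outside KerasCV entirely. Below is a plain-NumPy sketch (illustrative helper names, not KerasCV APIs) of what an `augment_image()` body does per sample, with the explicit batch loop standing in for the automatic vectorization the base class provides:

```python
import numpy as np


def augment_image(image, shift, channel=0):
    """Sample-wise augmentation: shift one channel and clip to [0, 255]."""
    out = image.astype("float32")  # astype copies, so the input stays untouched
    out[..., channel] = np.clip(out[..., channel] + shift, 0.0, 255.0)
    return out


rng = np.random.default_rng(0)
batch = rng.uniform(0, 255, size=(4, 8, 8, 3)).astype("float32")
# Mimic auto-vectorization: a fresh factor is sampled per image, and the
# augmentation is applied to each sample independently.
augmented = np.stack([augment_image(img, rng.uniform(0.0, 100.0)) for img in batch])
```

In a real layer, the loop disappears: you write only the single-sample logic and `BaseImageAugmentationLayer` handles batching, sampling, and vectorization for you.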