GitHub Repository: keras-team/keras-io
Path: blob/master/guides/keras_cv/custom_image_augmentations.py
"""
Title: Custom Image Augmentations with BaseImageAugmentationLayer
Author: [lukewood](https://twitter.com/luke_wood_ml)
Date created: 2022/04/26
Last modified: 2023/11/29
Description: Use BaseImageAugmentationLayer to implement custom data augmentations.
Accelerator: None
"""

"""
## Overview
Data augmentation is an integral part of training any robust computer vision model.
While KerasCV offers a plethora of prebuilt, high-quality data augmentation techniques,
you may still want to implement your own custom technique.
KerasCV offers a helpful base class for writing data augmentation layers:
`BaseImageAugmentationLayer`.
Any augmentation layer built with `BaseImageAugmentationLayer` will automatically be
compatible with the KerasCV `RandomAugmentationPipeline` class.

This guide will show you how to implement your own custom augmentation layers using
`BaseImageAugmentationLayer`. As an example, we will implement a layer that tints all
images blue.

Currently, KerasCV's preprocessing layers only support the TensorFlow backend with Keras 3.
"""

"""shell
pip install -q --upgrade keras-cv
pip install -q --upgrade keras  # Upgrade to Keras 3
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import ops
from keras import layers
import keras_cv
import matplotlib.pyplot as plt

"""
First, let's implement some helper functions for visualization and some transformations.
"""


def imshow(img):
    img = img.astype(int)
    plt.axis("off")
    plt.imshow(img)
    plt.show()


def gallery_show(images):
    images = images.astype(int)
    for i in range(9):
        image = images[i]
        plt.subplot(3, 3, i + 1)
        plt.imshow(image.astype("uint8"))
        plt.axis("off")
    plt.show()


def transform_value_range(images, original_range, target_range):
    images = (images - original_range[0]) / (original_range[1] - original_range[0])
    scale_factor = target_range[1] - target_range[0]
    return (images * scale_factor) + target_range[0]
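"""
As a quick sanity check, the same rescaling formula can be exercised with plain NumPy
(a hypothetical `rescale` stand-in, since the helper above normally receives Keras tensors):
"""

```python
import numpy as np


def rescale(images, original_range, target_range):
    # Same formula as transform_value_range above, applied to a NumPy array:
    # normalize to [0, 1], then scale and shift into the target range.
    images = (images - original_range[0]) / (original_range[1] - original_range[0])
    scale_factor = target_range[1] - target_range[0]
    return (images * scale_factor) + target_range[0]


pixels = np.array([0.0, 127.5, 255.0])
print(rescale(pixels, (0, 255), (0, 1)))   # [0.  0.5 1. ]
print(rescale(pixels, (0, 255), (-1, 1)))  # [-1.  0.  1.]
```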


def parse_factor(param, min_value=0.0, max_value=1.0, seed=None):
    if isinstance(param, keras_cv.core.FactorSampler):
        return param
    if isinstance(param, (float, int)):
        param = (min_value, param)
    if param[0] == param[1]:
        return keras_cv.core.ConstantFactorSampler(param[0])
    return keras_cv.core.UniformFactorSampler(param[0], param[1], seed=seed)


"""
## BaseImageAugmentationLayer Introduction

Image augmentation should operate on a sample-wise basis, not batch-wise.
This is a common mistake many machine learning practitioners make when implementing
custom techniques.
`BaseImageAugmentationLayer` offers a set of clean abstractions to make implementing image
augmentation techniques on a sample-wise basis much easier.
This is done by allowing the end user to override an `augment_image()` method and then
performing automatic vectorization under the hood.

Most augmentation techniques also must sample from one or more random distributions.
KerasCV offers an abstraction to make random sampling end-user configurable: the
`FactorSampler` API.

Finally, many augmentation techniques require some information about the pixel values
present in the input images. KerasCV offers the `value_range` API to simplify the handling of this.

In our example, we will use the `FactorSampler` API, the `value_range` API, and
`BaseImageAugmentationLayer` to implement a robust, configurable, and correct `RandomBlueTint` layer.

## Overriding `augment_image()`

Let's start off with the minimum:
"""


class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):
    def augment_image(self, image, *args, transformation=None, **kwargs):
        # image is of shape (height, width, channels)
        [*others, blue] = ops.unstack(image, axis=-1)
        blue = ops.clip(blue + 100, 0.0, 255.0)
        return ops.stack([*others, blue], axis=-1)
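"""
The channel arithmetic in `augment_image()` is easy to verify with plain NumPy as a
stand-in for `keras.ops`: unstacking splits off the last (blue) channel, and stacking
reattaches it after the shift.
"""

```python
import numpy as np

image = np.zeros((2, 2, 3))  # dummy HWC image, all zeros

# Split the channels along the last axis, shift the blue channel, reassemble.
*others, blue = np.moveaxis(image, -1, 0)
blue = np.clip(blue + 100, 0.0, 255.0)
tinted = np.stack([*others, blue], axis=-1)

print(tinted[..., 2])   # blue channel is now 100 everywhere
print(tinted[..., :2])  # other channels are untouched
```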


"""
Our layer overrides `BaseImageAugmentationLayer.augment_image()`. This method is
used to augment images given to the layer. By default, using
`BaseImageAugmentationLayer` gives you a few nice features for free:

- support for unbatched inputs (HWC Tensor)
- support for batched inputs (BHWC Tensor)
- automatic vectorization on batched inputs (more information on this in the
  "Auto vectorization performance" section below)

Let's check out the result. First, let's download a sample image:
"""

SIZE = (300, 300)
elephants = keras.utils.get_file(
    "african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
)
elephants = keras.utils.load_img(elephants, target_size=SIZE)
elephants = keras.utils.img_to_array(elephants)
imshow(elephants)

"""
Next, let's augment it and visualize the result:
"""

layer = RandomBlueTint()
augmented = layer(elephants)
imshow(ops.convert_to_numpy(augmented))

"""
Looks great! We can also call our layer on batched inputs:
"""

layer = RandomBlueTint()
augmented = layer(ops.expand_dims(elephants, axis=0))
imshow(ops.convert_to_numpy(augmented)[0])

"""
## Adding Random Behavior with the `FactorSampler` API.

Usually an image augmentation technique should not do the same thing on every
invocation of the layer's `__call__` method.
KerasCV offers the `FactorSampler` API to allow users to provide configurable random
distributions.
"""


class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):
    """RandomBlueTint randomly applies a blue tint to images.

    Args:
        factor: A tuple of two floats, a single float, or a
            `keras_cv.FactorSampler`. `factor` controls the extent to which the
            image is blue shifted. `factor=0.0` makes this layer perform a no-op
            operation, while a value of 1.0 uses the degenerated result entirely.
            Values between 0 and 1 result in linear interpolation between the original
            image and a fully blue image.
            Values should be between `0.0` and `1.0`. If a tuple is used, a `factor` is
            sampled between the two values for every image augmented. If a single float
            is used, a value between `0.0` and the passed float is sampled. In order to
            ensure the value is always the same, please pass a tuple with two identical
            floats: `(0.5, 0.5)`.
    """

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = parse_factor(factor)

    def augment_image(self, image, *args, transformation=None, **kwargs):
        [*others, blue] = ops.unstack(image, axis=-1)
        blue_shift = self.factor() * 255
        blue = ops.clip(blue + blue_shift, 0.0, 255.0)
        return ops.stack([*others, blue], axis=-1)


"""
Now, we can configure the random behavior of our `RandomBlueTint` layer.
We can give it a range of values to sample from:
"""

many_elephants = ops.repeat(ops.expand_dims(elephants, axis=0), 9, axis=0)
layer = RandomBlueTint(factor=0.5)
augmented = layer(many_elephants)
gallery_show(ops.convert_to_numpy(augmented))

"""
Each image is augmented differently with a random factor sampled from the range
`(0, 0.5)`.

We can also configure the layer to draw from a normal distribution:
"""

many_elephants = ops.repeat(ops.expand_dims(elephants, axis=0), 9, axis=0)
factor = keras_cv.core.NormalFactorSampler(
    mean=0.3, stddev=0.1, min_value=0.0, max_value=1.0
)
layer = RandomBlueTint(factor=factor)
augmented = layer(many_elephants)
gallery_show(ops.convert_to_numpy(augmented))

"""
As you can see, the augmentations now are drawn from a normal distribution.
There are various types of `FactorSampler`s, including `UniformFactorSampler`,
`NormalFactorSampler`, and `ConstantFactorSampler`. You can also implement your own.
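A `FactorSampler` is, at its core, a zero-argument callable that returns a new factor on
each invocation. In a real layer you would subclass `keras_cv.core.FactorSampler`; the
sketch below is a hypothetical, KerasCV-free stand-in that mimics only the calling
convention:
"""

```python
import random


class ExponentialFactorSampler:
    """Hypothetical sampler: exponentially distributed factors clipped to
    [0, max_value], so weak augmentations are sampled more often than strong ones."""

    def __init__(self, rate=5.0, max_value=1.0, seed=None):
        self.rate = rate
        self.max_value = max_value
        self.rng = random.Random(seed)

    def __call__(self):
        return min(self.rng.expovariate(self.rate), self.max_value)


sampler = ExponentialFactorSampler(rate=5.0, seed=1234)
samples = [sampler() for _ in range(1000)]
print(min(samples) >= 0.0, max(samples) <= 1.0)  # True True
```

"""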

## Overriding `get_random_transformation()`

Now, suppose that your layer impacts the prediction targets: whether they are bounding
boxes, classification labels, or regression targets.
Your layer will need to have information about what augmentations are applied to the image
when augmenting the label.
Luckily, `BaseImageAugmentationLayer` was designed with this in mind.

To handle this issue, `BaseImageAugmentationLayer` has an overridable
`get_random_transformation()` method alongside `augment_label()`,
`augment_target()`, and `augment_bounding_boxes()`.
`augment_segmentation_map()` and others will be added in the future.

Let's add this to our layer.
"""


class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):
    """RandomBlueTint randomly applies a blue tint to images.

    Args:
        factor: A tuple of two floats, a single float, or a
            `keras_cv.FactorSampler`. `factor` controls the extent to which the
            image is blue shifted. `factor=0.0` makes this layer perform a no-op
            operation, while a value of 1.0 uses the degenerated result entirely.
            Values between 0 and 1 result in linear interpolation between the original
            image and a fully blue image.
            Values should be between `0.0` and `1.0`. If a tuple is used, a `factor` is
            sampled between the two values for every image augmented. If a single float
            is used, a value between `0.0` and the passed float is sampled. In order to
            ensure the value is always the same, please pass a tuple with two identical
            floats: `(0.5, 0.5)`.
    """

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = parse_factor(factor)

    def get_random_transformation(self, **kwargs):
        # kwargs holds {"images": image, "labels": label, etc...}
        return self.factor() * 255

    def augment_image(self, image, transformation=None, **kwargs):
        [*others, blue] = ops.unstack(image, axis=-1)
        blue = ops.clip(blue + transformation, 0.0, 255.0)
        return ops.stack([*others, blue], axis=-1)

    def augment_label(self, label, transformation=None, **kwargs):
        # you can use transformation somehow if you want
        if transformation > 100:
            # i.e. maybe class 2 corresponds to blue images
            return 2.0
        return label

    def augment_bounding_boxes(self, bounding_boxes, transformation=None, **kwargs):
        # you can also perform no-op augmentations on label types to support them in
        # your pipeline.
        return bounding_boxes


"""
To make use of these new methods, you will need to feed your inputs in with a
dictionary maintaining a mapping from images to targets.

As of now, KerasCV supports the following label types:

- labels via `augment_label()`.
- bounding_boxes via `augment_bounding_boxes()`.

In order to use augmentation layers alongside your prediction targets, you must package
your inputs as follows:
"""

labels = ops.array([[1, 0]])
inputs = {"images": ops.convert_to_tensor(elephants), "labels": labels}

"""
Now if we call our layer on the inputs:
"""

layer = RandomBlueTint(factor=(0.6, 0.6))
augmented = layer(inputs)
print(augmented["labels"])

"""
Both the inputs and labels are augmented.
Note how when `transformation` is > 100 the label is modified to contain 2.0, as
specified in the layer above.

## `value_range` support

Imagine you are using your new augmentation layer in many pipelines.
Some pipelines have values in the range `[0, 255]`, some pipelines have normalized their
images to the range `[-1, 1]`, and some use a value range of `[0, 1]`.

If a user calls your layer with an image in value range `[0, 1]`, the outputs will be
nonsense!
"""

layer = RandomBlueTint(factor=(0.1, 0.1))
elephants_0_1 = elephants / 255
print("min and max before augmentation:", elephants_0_1.min(), elephants_0_1.max())
augmented = layer(elephants_0_1)
print(
    "min and max after augmentation:",
    ops.convert_to_numpy(augmented).min(),
    ops.convert_to_numpy(augmented).max(),
)
imshow(ops.convert_to_numpy(augmented * 255).astype(int))

"""
Note that this is an incredibly weak augmentation!
Factor is only set to 0.1.

Let's resolve this issue with KerasCV's `value_range` API.
"""


class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):
    """RandomBlueTint randomly applies a blue tint to images.

    Args:
        value_range: a tuple or a list of two elements. The first value
            represents the lower bound for values in passed images, the second represents
            the upper bound. Images passed to the layer should have values within
            `value_range`.
        factor: A tuple of two floats, a single float, or a
            `keras_cv.FactorSampler`. `factor` controls the extent to which the
            image is blue shifted. `factor=0.0` makes this layer perform a no-op
            operation, while a value of 1.0 uses the degenerated result entirely.
            Values between 0 and 1 result in linear interpolation between the original
            image and a fully blue image.
            Values should be between `0.0` and `1.0`. If a tuple is used, a `factor` is
            sampled between the two values for every image augmented. If a single float
            is used, a value between `0.0` and the passed float is sampled. In order to
            ensure the value is always the same, please pass a tuple with two identical
            floats: `(0.5, 0.5)`.
    """

    def __init__(self, value_range, factor, **kwargs):
        super().__init__(**kwargs)
        self.value_range = value_range
        self.factor = parse_factor(factor)

    def get_random_transformation(self, **kwargs):
        # kwargs holds {"images": image, "labels": label, etc...}
        return self.factor() * 255

    def augment_image(self, image, transformation=None, **kwargs):
        image = transform_value_range(image, self.value_range, (0, 255))
        [*others, blue] = ops.unstack(image, axis=-1)
        blue = ops.clip(blue + transformation, 0.0, 255.0)
        result = ops.stack([*others, blue], axis=-1)
        result = transform_value_range(result, (0, 255), self.value_range)
        return result

    def augment_label(self, label, transformation=None, **kwargs):
        # you can use transformation somehow if you want
        if transformation > 100:
            # i.e. maybe class 2 corresponds to blue images
            return 2.0
        return label

    def augment_bounding_boxes(self, bounding_boxes, transformation=None, **kwargs):
        # you can also perform no-op augmentations on label types to support them in
        # your pipeline.
        return bounding_boxes


layer = RandomBlueTint(value_range=(0, 1), factor=(0.1, 0.1))
elephants_0_1 = elephants / 255
print("min and max before augmentation:", elephants_0_1.min(), elephants_0_1.max())
augmented = layer(elephants_0_1)
print(
    "min and max after augmentation:",
    ops.convert_to_numpy(augmented).min(),
    ops.convert_to_numpy(augmented).max(),
)
imshow(ops.convert_to_numpy(augmented * 255).astype(int))

"""
Now our elephants are only slightly blue tinted. This is the expected behavior when
using a factor of `0.1`. Great!

Now users can configure the layer to support any value range they may need. Note that
only layers that interact with color information should use the value range API.
Many augmentation techniques, such as `RandomRotation`, will not need this.

## Auto vectorization performance

If you are wondering:

> Does implementing my augmentations on a sample-wise basis carry performance
implications?

You are not alone!

Luckily, I have performed extensive analysis on the performance of automatic
vectorization, manual vectorization, and unvectorized implementations.
In this benchmark, I implemented a RandomCutout layer using auto vectorization, no auto
vectorization, and manual vectorization.
All of these were benchmarked inside of an `@tf.function` annotation.
They were also each benchmarked with the `jit_compile` argument.

The following chart shows the results of this benchmark:

![Auto Vectorization Performance Chart](https://i.imgur.com/NeNhDoi.png)

_The primary takeaway should be that the difference between manual vectorization and
automatic vectorization is marginal!_

Please note that Eager mode performance will be drastically different.

## Common gotchas

Some layers are not able to be automatically vectorized.
An example of this is [GridMask](https://tinyurl.com/ffb5zzf7).

If you receive an error when invoking your layer, try adding the following to your
constructor:
"""


class UnVectorizable(keras_cv.layers.BaseImageAugmentationLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # this disables BaseImageAugmentationLayer's Auto Vectorization
        self.auto_vectorize = False


"""
Additionally, be sure to accept `**kwargs` in your `augment_*` methods to ensure
forwards compatibility. KerasCV will add additional label types in the future, and
if you do not include a `**kwargs` argument your augmentation layers will not be
forward compatible.
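A minimal, KerasCV-free sketch of why this matters (the `segmentation_maps` keyword
below is a hypothetical future argument, not an existing KerasCV parameter):
"""

```python
class ForwardCompatible:
    # Stand-in for an augmentation layer method: **kwargs silently absorbs
    # keyword arguments introduced by future library versions.
    def augment_image(self, image, transformation=None, **kwargs):
        return image


class NotForwardCompatible:
    def augment_image(self, image, transformation=None):
        return image


img = [1, 2, 3]
# A future library version might pass extra keyword arguments:
ForwardCompatible().augment_image(img, transformation=0.5, segmentation_maps=None)
try:
    NotForwardCompatible().augment_image(img, transformation=0.5, segmentation_maps=None)
except TypeError as err:
    print("breaks without **kwargs:", err)
```

"""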

## Conclusion and next steps

KerasCV offers a standard set of APIs to streamline the process of implementing your
own data augmentation techniques.
These include `BaseImageAugmentationLayer`, the `FactorSampler` API, and the
`value_range` API.

We used these APIs to implement a highly configurable `RandomBlueTint` layer.
This layer can take inputs as standalone images, a dictionary with keys of `"images"`
and labels, inputs that are unbatched, or inputs that are batched. Inputs may be in any
value range, and the random distribution used to sample the tint values is end-user
configurable.

As follow-up exercises, you can:

- implement your own data augmentation technique using `BaseImageAugmentationLayer`
- [contribute an augmentation layer to KerasCV](https://github.com/keras-team/keras-cv)
- [read through the existing KerasCV augmentation layers](https://tinyurl.com/4txy4m3t)
"""