"""
Title: MixUp augmentation for image classification
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/03/06
Last modified: 2023/07/24
Description: Data augmentation using the mixup technique for image classification.
Accelerator: GPU
"""

"""
## Introduction
"""

"""
_mixup_ is a *domain-agnostic* data augmentation technique proposed in [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
by Zhang et al. It is implemented with the following formulas:

![](https://i.ibb.co/DRyHYww/image.png)

(Note that the lambda values lie within the [0, 1] range and are sampled from the
[Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution).)
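
As a minimal illustration of the formulas above, the NumPy sketch below mixes two toy "images" and their one-hot labels with a single lambda drawn from Beta(alpha, alpha). The `mixup_pair` helper, the toy arrays, and `alpha = 0.2` are assumptions for illustration; this is not the TensorFlow implementation used later in this example.

```python
import numpy as np


def mixup_pair(x1, y1, x2, y2, lam):
    # Convex combination of two inputs and of their one-hot labels,
    # using the same lambda for both.
    x = lam * x1 + (1.0 - lam) * x2
    y = lam * y1 + (1.0 - lam) * y2
    return x, y


rng = np.random.default_rng(seed=0)
alpha = 0.2  # assumed value; lambda ~ Beta(alpha, alpha)
lam = rng.beta(alpha, alpha)

# Two hypothetical 2x2 single-channel "images" and 3-class one-hot labels.
x1, y1 = np.ones((2, 2, 1)), np.array([1.0, 0.0, 0.0])
x2, y2 = np.zeros((2, 2, 1)), np.array([0.0, 0.0, 1.0])

x_mix, y_mix = mixup_pair(x1, y1, x2, y2, lam)
print(lam, y_mix)  # y_mix is [lam, 0, 1 - lam]
```

Because `x1` is all ones and `x2` is all zeros, every pixel of `x_mix` equals lambda, and the mixed label is a soft target that still sums to 1.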

The technique's name is quite literal: we are mixing up the features and
their corresponding labels. Implementation-wise, it is simple. Neural networks are prone
to [memorizing corrupt labels](https://arxiv.org/abs/1611.03530). mixup mitigates this by
combining different features with one another (and doing the same for their labels) so that
a network does not become overconfident about the relationship between the features and
their labels.
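
To make the overconfidence argument concrete, the sketch below computes categorical cross-entropy against a mixed label, assuming lambda = 0.7 purely for illustration. Cross-entropy with a soft target is smallest when the prediction matches that target, so a near one-hot (overconfident) prediction is penalized. The `cross_entropy` helper and the probability values are hypothetical.

```python
import numpy as np


def cross_entropy(target, pred, eps=1e-12):
    # Categorical cross-entropy between a (possibly soft) target and a prediction.
    return float(-np.sum(target * np.log(pred + eps)))


lam = 0.7  # assumed mixing coefficient
y_mix = lam * np.array([1.0, 0.0]) + (1 - lam) * np.array([0.0, 1.0])  # [0.7, 0.3]

calibrated = np.array([0.7, 0.3])
overconfident = np.array([0.99, 0.01])

loss_calibrated = cross_entropy(y_mix, calibrated)
loss_overconfident = cross_entropy(y_mix, overconfident)
print(loss_calibrated, loss_overconfident)
```

The overconfident prediction receives the larger loss, which is the pressure mixup puts on the network.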

mixup is particularly useful when we are unsure which set of augmentation
transforms to choose for a given dataset (medical imaging datasets, for example). mixup can be
extended to a variety of data modalities such as computer vision, natural language
processing, speech, and so on.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras
import matplotlib.pyplot as plt

from keras import layers

# TF imports related to tf.data preprocessing
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow.random import gamma as tf_random_gamma

"""
## Prepare the dataset

In this example, we will be using the [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. The same recipe can
be used for other classification datasets as well.
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = keras.ops.one_hot(y_train, 10)

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = keras.ops.one_hot(y_test, 10)

"""
## Define hyperparameters
"""

AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10

"""
## Convert the data into TensorFlow `Dataset` objects
"""

# Put aside a few samples to create our validation set
val_samples = 2000
x_val, y_val = x_train[:val_samples], y_train[:val_samples]
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]

train_ds_one = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
train_ds_two = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
# Because we will be mixing up the images and their corresponding labels, we will be
# combining two shuffled datasets from the same training data.
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
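
"""
Zipping two independently shuffled copies of the same data means that each batch from
`train_ds_one` is paired, element by element, with a batch of randomly chosen partners from
`train_ds_two`. The NumPy sketch below mimics that pairing on a toy index array. The sizes and
the `paired_batches` helper are assumptions for illustration, and it uses full permutations,
whereas `shuffle(BATCH_SIZE * 100)` draws from a finite buffer of 6,400 elements.

```python
import numpy as np


def paired_batches(num_examples, batch_size, rng):
    # Shuffle the same index range twice, independently, then cut both
    # permutations into aligned batches (like zipping two shuffled datasets).
    order_one = rng.permutation(num_examples)
    order_two = rng.permutation(num_examples)
    for start in range(0, num_examples, batch_size):
        yield order_one[start : start + batch_size], order_two[start : start + batch_size]


rng = np.random.default_rng(seed=42)
batches = list(paired_batches(num_examples=10, batch_size=4, rng=rng))
for idx_one, idx_two in batches:
    print(idx_one, idx_two)  # the i-th example of idx_one is mixed with the i-th of idx_two
```

Every example appears exactly once on each side, and since the two shuffles are independent,
an example can occasionally be paired with itself.
"""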

"""
## Define the mixup technique function

To perform the mixup routine, we create new virtual datasets using the training data from
the same dataset. For each pair of examples, we apply a lambda value within the [0, 1] range sampled from a [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution),
such that, for example, `new_x = lambda * x1 + (1 - lambda) * x2` (where
`x1` and `x2` are images). The same equation is applied to the labels as well.
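
The batched version of this equation can be sketched in NumPy as follows, under assumed toy shapes. It draws one lambda per example with the Gamma-ratio construction (if X ~ Gamma(a) and Y ~ Gamma(b) are independent, X / (X + Y) ~ Beta(a, b)), then reshapes lambda to `(batch, 1, 1, 1)` for images and `(batch, 1)` for labels so that it broadcasts over pixels and classes, mirroring the TensorFlow functions below. The names `sample_beta_np` and `mix_up_np` are hypothetical.

```python
import numpy as np


def sample_beta_np(size, a, b, rng):
    # Beta(a, b) via two independent Gamma draws: X / (X + Y).
    x = rng.gamma(shape=a, size=size)
    y = rng.gamma(shape=b, size=size)
    return x / (x + y)


def mix_up_np(images_one, labels_one, images_two, labels_two, alpha, rng):
    batch_size = images_one.shape[0]
    lam = sample_beta_np(batch_size, alpha, alpha, rng)
    x_l = lam.reshape(batch_size, 1, 1, 1)  # broadcasts over height, width, channels
    y_l = lam.reshape(batch_size, 1)  # broadcasts over classes
    images = images_one * x_l + images_two * (1 - x_l)
    labels = labels_one * y_l + labels_two * (1 - y_l)
    return images, labels, lam


rng = np.random.default_rng(seed=0)
batch, num_classes = 8, 10  # assumed toy sizes
images_one = rng.random((batch, 28, 28, 1))
images_two = rng.random((batch, 28, 28, 1))
labels_one = np.eye(num_classes)[rng.integers(0, num_classes, batch)]
labels_two = np.eye(num_classes)[rng.integers(0, num_classes, batch)]

images, labels, lam = mix_up_np(images_one, labels_one, images_two, labels_two, 0.2, rng)
print(images.shape, labels.shape)  # (8, 28, 28, 1) (8, 10)
print(labels.sum(axis=1))  # each mixed label still sums to 1
```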
"""


def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
    # A Beta(a, b) sample can be drawn as X / (X + Y),
    # where X ~ Gamma(a) and Y ~ Gamma(b) are independent.
    gamma_1_sample = tf_random_gamma(shape=[size], alpha=concentration_1)
    gamma_2_sample = tf_random_gamma(shape=[size], alpha=concentration_0)
    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


def mix_up(ds_one, ds_two, alpha=0.2):
    # Unpack two datasets
    images_one, labels_one = ds_one
    images_two, labels_two = ds_two
    batch_size = keras.ops.shape(images_one)[0]

    # Sample one lambda per example and reshape it so that it broadcasts over
    # the image dimensions (batch, 1, 1, 1) and the label dimension (batch, 1)
    l = sample_beta_distribution(batch_size, alpha, alpha)
    x_l = keras.ops.reshape(l, (batch_size, 1, 1, 1))
    y_l = keras.ops.reshape(l, (batch_size, 1))

    # Perform mixup on both images and labels by combining a pair of images/labels
    # (one from each dataset) into one image/label
    images = images_one * x_l + images_two * (1 - x_l)
    labels = labels_one * y_l + labels_two * (1 - y_l)
    return (images, labels)


"""
**Note** that here, we are combining two images to create a single one. Theoretically,
we can combine as many as we want, but that comes at an increased computational cost. In
certain cases, it may not improve performance either.
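
As a hypothetical generalization (not part of this example's pipeline), mixing k examples can use a weight vector drawn from a symmetric Dirichlet(alpha, ..., alpha) distribution, which reduces to the Beta(alpha, alpha) case when k = 2. The sketch below, with an assumed `mix_k` helper, also makes the extra cost visible: every mixed example reads k inputs instead of two.

```python
import numpy as np


def mix_k(images, labels, alpha, rng):
    # images: (k, H, W, C) and labels: (k, num_classes) for the k examples to combine.
    k = images.shape[0]
    weights = rng.dirichlet([alpha] * k)  # non-negative, sums to 1
    mixed_image = np.tensordot(weights, images, axes=1)  # weighted sum over the k inputs
    mixed_label = weights @ labels
    return mixed_image, mixed_label, weights


rng = np.random.default_rng(seed=1)
k, num_classes = 3, 10  # assumed: combine three examples
images = rng.random((k, 28, 28, 1))
labels = np.eye(num_classes)[[0, 4, 7]]

mixed_image, mixed_label, weights = mix_k(images, labels, alpha=0.2, rng=rng)
print(weights, mixed_label.sum())
```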
"""

"""
## Visualize the new augmented dataset
"""

# First create the new dataset using our `mix_up` utility
train_ds_mu = train_ds.map(
    lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2),
    num_parallel_calls=AUTO,
)

# Let's preview 9 samples from the dataset
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze())
    print(label.numpy().tolist())
    plt.axis("off")

"""
## Model building
"""


def get_training_model():
    model = keras.Sequential(
        [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(16, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Dropout(0.2),
            layers.GlobalAveragePooling2D(),
            layers.Dense(128, activation="relu"),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model


"""
For the sake of reproducibility, we serialize the initial random weights of our shallow
network so that both training runs below start from identical weights.
"""

initial_model = get_training_model()
initial_model.save_weights("initial_weights.weights.h5")

"""
## 1. Train the model with the mixed-up dataset
"""

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))

"""
## 2. Train the model *without* the mixed-up dataset
"""

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# Notice that we are NOT using the mixed-up dataset here
model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))

"""
Readers are encouraged to try out mixup on different datasets from different domains and
to experiment with how lambda is sampled (here through the `alpha` argument of `mix_up`,
which sets both concentration parameters of the Beta distribution). You are strongly advised to check out the
[original paper](https://arxiv.org/abs/1710.09412) as well: the authors present several ablation studies on mixup
showing how it can improve generalization, and they also report results of combining
more than two images to create a single one.
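
One way to build intuition for `alpha` is to look at how it shapes the distribution of lambda. The NumPy sketch below (the sample size, the 0.1 margin, and the `extreme_fraction` helper are arbitrary choices for illustration) estimates how often lambda falls near 0 or 1. For a small `alpha`, such as the 0.2 used above, most mixed examples are dominated by one of the two inputs; `alpha = 1` gives a uniform lambda, and larger values concentrate lambda around 0.5.

```python
import numpy as np


def extreme_fraction(alpha, rng, num_samples=100_000, margin=0.1):
    # Fraction of lambda ~ Beta(alpha, alpha) that lands within `margin` of 0 or 1.
    lam = rng.beta(alpha, alpha, size=num_samples)
    return float(np.mean((lam < margin) | (lam > 1 - margin)))


rng = np.random.default_rng(seed=0)
fractions = {a: extreme_fraction(a, rng) for a in (0.2, 1.0, 5.0)}
for a, frac in fractions.items():
    print(f"alpha={a}: {frac:.1%} of lambda values are below 0.1 or above 0.9")
```

With `alpha = 1`, roughly 20% of the samples land in those two intervals, as expected for a uniform distribution.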
"""

"""
## Notes

* With mixup, you can create synthetic examples without incurring high computational
costs, which is especially helpful when you lack a large dataset.
* [Label smoothing](https://www.pyimagesearch.com/2019/12/30/label-smoothing-with-keras-tensorflow-and-deep-learning/) and mixup usually do not work well together because label smoothing
already modifies the hard labels by some factor.
* mixup does not work well when you are using [Supervised Contrastive
Learning](https://arxiv.org/abs/2004.11362) (SCL) since SCL expects the true labels
during its pre-training phase.
* A few other benefits of mixup (as described in the [paper](https://arxiv.org/abs/1710.09412)) include robustness to
adversarial examples and more stable GAN (Generative Adversarial Network) training.
* There are a number of data augmentation techniques that extend mixup, such as
[CutMix](https://arxiv.org/abs/1905.04899) and [AugMix](https://arxiv.org/abs/1912.02781).
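
To see why label smoothing and mixup overlap, compare the targets they produce. The sketch below assumes the smoothing rule applied by the `label_smoothing` argument of Keras' categorical cross-entropy, `y * (1 - eps) + eps / num_classes`, with an assumed `eps = 0.1`, 4 classes, and lambda = 0.8; the helper names are hypothetical. Smoothing a mixed label softens a target that mixup has already softened.

```python
import numpy as np


def smooth_labels(y, eps):
    # Label smoothing: move eps of the probability mass to a uniform distribution.
    num_classes = y.shape[-1]
    return y * (1.0 - eps) + eps / num_classes


def mix_labels(y1, y2, lam):
    # mixup on labels: convex combination of two one-hot vectors.
    return lam * y1 + (1.0 - lam) * y2


y1 = np.array([1.0, 0.0, 0.0, 0.0])
y2 = np.array([0.0, 0.0, 1.0, 0.0])

smoothed = smooth_labels(y1, eps=0.1)  # [0.925, 0.025, 0.025, 0.025]
mixed = mix_labels(y1, y2, lam=0.8)  # [0.8, 0.0, 0.2, 0.0]
both = smooth_labels(mixed, eps=0.1)  # softened twice
print(smoothed, mixed, both)
```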
"""