"""
2
Title: Gradient Centralization for Better Training Performance
3
Author: [Rishit Dagli](https://github.com/Rishit-dagli)
4
Date created: 06/18/21
5
Last modified: 05/29/25
6
Description: Implement Gradient Centralization to improve training performance of DNNs.
7
Accelerator: GPU
8
Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)
9
Debugged by: [Alberto M. Esmorís](https://github.com/albertoesmp)
10
"""

"""
## Introduction

This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), a
new optimization technique for Deep Neural Networks by Yong et al., and demonstrates it
on Laurence Moroney's [Horses or Humans
Dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans). Gradient
Centralization can both speed up the training process and improve the final generalization
performance of DNNs. It operates directly on gradients by centralizing the gradient
vectors to have zero mean. Gradient Centralization moreover improves the Lipschitzness of
the loss function and its gradient, so that the training process becomes more efficient
and stable.
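
The centralization step itself is simple: for every gradient tensor with more than one
dimension, subtract its mean computed over all axes except the last one. As a minimal,
standalone sketch (our own NumPy illustration, not part of the original example):

```
import numpy as np

grad = np.random.randn(3, 3, 16, 32)  # e.g. the gradient of a Conv2D kernel
axes = tuple(range(grad.ndim - 1))    # all axes except the output axis
centralized = grad - grad.mean(axis=axes, keepdims=True)
print(centralized.mean(axis=axes))    # approximately zero everywhere
```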

This example requires `tensorflow_datasets`, which can be installed with this command:

```
pip install tensorflow-datasets
```
"""

"""
## Setup
"""

from time import time

import keras
from keras import layers
from keras.optimizers import RMSprop
from keras import ops

from tensorflow import data as tf_data
import tensorflow_datasets as tfds
"""
47
## Prepare the data
48
49
For this example, we will be using the [Horses or Humans
50
dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans).
51
"""
52
53
num_classes = 2
54
input_shape = (300, 300, 3)
55
dataset_name = "horses_or_humans"
56
batch_size = 128
57
AUTOTUNE = tf_data.AUTOTUNE
58
59
(train_ds, test_ds), metadata = tfds.load(
60
name=dataset_name,
61
split=[tfds.Split.TRAIN, tfds.Split.TEST],
62
with_info=True,
63
as_supervised=True,
64
)
65
66
print(f"Image shape: {metadata.features['image'].shape}")
67
print(f"Training images: {metadata.splits['train'].num_examples}")
68
print(f"Test images: {metadata.splits['test'].num_examples}")
69

"""
## Use Data Augmentation

We will rescale the data to `[0, 1]` and perform simple augmentations on our data.
"""

rescale = layers.Rescaling(1.0 / 255)

data_augmentation = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.3),
    layers.RandomZoom(0.2),
]


# Helper to apply augmentation
def apply_aug(x):
    for aug in data_augmentation:
        x = aug(x)
    return x


def prepare(ds, shuffle=False, augment=False):
    # Rescale dataset
    ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1024)

    # Batch dataset
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(
            lambda x, y: (apply_aug(x), y),
            num_parallel_calls=AUTOTUNE,
        )

    # Use buffered prefetching
    return ds.prefetch(buffer_size=AUTOTUNE)


"""
Rescale and augment the data
"""

train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)
"""
## Define a model

In this section we will define a convolutional neural network.
"""


def make_model():
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Conv2D(16, (3, 3), activation="relu"),
            layers.MaxPooling2D(2, 2),
            layers.Conv2D(32, (3, 3), activation="relu"),
            layers.Dropout(0.5),
            layers.MaxPooling2D(2, 2),
            layers.Conv2D(64, (3, 3), activation="relu"),
            layers.Dropout(0.5),
            layers.MaxPooling2D(2, 2),
            layers.Conv2D(64, (3, 3), activation="relu"),
            layers.MaxPooling2D(2, 2),
            layers.Conv2D(64, (3, 3), activation="relu"),
            layers.MaxPooling2D(2, 2),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(512, activation="relu"),
            layers.Dense(1, activation="sigmoid"),
        ]
    )


"""
## Implement Gradient Centralization

We will now subclass the `RMSprop` optimizer class, modifying the
`keras.optimizers.Optimizer.get_gradients()` method to implement Gradient
Centralization. At a high level, the idea is this: say we obtain the gradients of a Dense
or Convolution layer through backpropagation; we then compute the mean of the column
vectors of the gradient of the weight matrix, and remove that mean from each column vector.

The experiments in [this paper](https://arxiv.org/abs/2004.01461) on various
applications, including general image classification, fine-grained image classification,
detection and segmentation, and Person ReID, demonstrate that GC can consistently improve
the performance of DNN learning.

Also, for simplicity, we are not implementing gradient clipping functionality at the
moment; however, this is quite easy to add.

At the moment we are just creating a subclass for the `RMSprop` optimizer, but you could
easily reproduce this for any other optimizer or for a custom optimizer in the same way
(see the sketch after the `GCRMSprop` class below). We will be using this class in a later
section when we train a model with Gradient Centralization.
"""


class GCRMSprop(RMSprop):
    def get_gradients(self, loss, params):
        # We provide a modified get_gradients() method that computes the
        # centralized gradients.

        grads = []
        gradients = super().get_gradients(loss, params)
        for grad in gradients:
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)

        return grads


optimizer = GCRMSprop(learning_rate=1e-4)
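
"""
The same idea is not specific to `RMSprop`. As a minimal sketch (our own illustration, not
part of the original example), an equivalent subclass of `keras.optimizers.Adam` would
mirror the pattern above:
"""


class GCAdam(keras.optimizers.Adam):
    # Hypothetical sketch: centralize gradients exactly as in GCRMSprop above.
    def get_gradients(self, loss, params):
        grads = []
        for grad in super().get_gradients(loss, params):
            if len(grad.shape) > 1:
                # Subtract the mean over all axes except the last (output) axis.
                axis = list(range(len(grad.shape) - 1))
                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)
        return grads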


"""
## Training utilities

We will also create a callback which allows us to easily measure the total training time
and the time taken for each epoch, since we are interested in comparing the effect of
Gradient Centralization on the model we built above.
"""


class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time() - self.epoch_time_start)

"""
215
## Train the model without GC
216
217
We now train the model we built earlier without Gradient Centralization which we can
218
compare to the training performance of the model trained with Gradient Centralization.
219
"""
220
221
time_callback_no_gc = TimeHistory()
222
model = make_model()
223
model.compile(
224
loss="binary_crossentropy",
225
optimizer=RMSprop(learning_rate=1e-4),
226
metrics=["accuracy"],
227
)
228
229
model.summary()
230
231
"""
232
We also save the history since we later want to compare our model trained with and not
233
trained with Gradient Centralization
234
"""
235
236
history_no_gc = model.fit(
237
train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
238
)

"""
## Train the model with GC

We will now train the same model, this time using Gradient Centralization. Notice that
our optimizer is the one using Gradient Centralization this time.
"""

time_callback_gc = TimeHistory()
model = make_model()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])

model.summary()

history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])

"""
## Comparing performance
"""

print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)}")

print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)}")

"""
Readers are encouraged to try out Gradient Centralization on different datasets from
different domains and experiment with its effect. You are strongly advised to check out
the [original paper](https://arxiv.org/abs/2004.01461) as well - the authors present
several studies on Gradient Centralization showing how it can improve general
performance, generalization, and training time, and make training more efficient.

Many thanks to [Ali Mustufa Shaikh](https://github.com/ialimustufa) for reviewing this
implementation.
"""