Path: blob/master/examples/vision/gradient_centralization.py
"""1Title: Gradient Centralization for Better Training Performance2Author: [Rishit Dagli](https://github.com/Rishit-dagli)3Date created: 06/18/214Last modified: 05/29/255Description: Implement Gradient Centralization to improve training performance of DNNs.6Accelerator: GPU7Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)8Debugged by: [Alberto M. EsmorÃs](https://github.com/albertoesmp)9"""1011"""12## Introduction1314This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), a15new optimization technique for Deep Neural Networks by Yong et al., and demonstrates it16on Laurence Moroney's [Horses or Humans17Dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans). Gradient18Centralization can both speedup training process and improve the final generalization19performance of DNNs. It operates directly on gradients by centralizing the gradient20vectors to have zero mean. Gradient Centralization morever improves the Lipschitzness of21the loss function and its gradient so that the training process becomes more efficient22and stable.2324This example requires `tensorflow_datasets` which can be installed with this command:2526```27pip install tensorflow-datasets28```29"""3031"""32## Setup33"""3435from time import time3637import keras38from keras import layers39from keras.optimizers import RMSprop40from keras import ops4142from tensorflow import data as tf_data43import tensorflow_datasets as tfds4445"""46## Prepare the data4748For this example, we will be using the [Horses or Humans49dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans).50"""5152num_classes = 253input_shape = (300, 300, 3)54dataset_name = "horses_or_humans"55batch_size = 12856AUTOTUNE = tf_data.AUTOTUNE5758(train_ds, test_ds), metadata = tfds.load(59name=dataset_name,60split=[tfds.Split.TRAIN, tfds.Split.TEST],61with_info=True,62as_supervised=True,63)6465print(f"Image shape: {metadata.features['image'].shape}")66print(f"Training images: {metadata.splits['train'].num_examples}")67print(f"Test images: {metadata.splits['test'].num_examples}")6869"""70## Use Data Augmentation7172We will rescale the data to `[0, 1]` and perform simple augmentations to our data.73"""7475rescale = layers.Rescaling(1.0 / 255)7677data_augmentation = [78layers.RandomFlip("horizontal_and_vertical"),79layers.RandomRotation(0.3),80layers.RandomZoom(0.2),81]828384# Helper to apply augmentation85def apply_aug(x):86for aug in data_augmentation:87x = aug(x)88return x899091def prepare(ds, shuffle=False, augment=False):92# Rescale dataset93ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)9495if shuffle:96ds = ds.shuffle(1024)9798# Batch dataset99ds = ds.batch(batch_size)100101# Use data augmentation only on the training set102if augment:103ds = ds.map(104lambda x, y: (apply_aug(x), y),105num_parallel_calls=AUTOTUNE,106)107108# Use buffered prefecting109return ds.prefetch(buffer_size=AUTOTUNE)110111112"""113Rescale and augment the data114"""115116train_ds = prepare(train_ds, shuffle=True, augment=True)117test_ds = prepare(test_ds)118"""119## Define a model120121In this section we will define a Convolutional neural network.122"""123124125def make_model():126return keras.Sequential(127[128layers.Input(shape=input_shape),129layers.Conv2D(16, (3, 3), activation="relu"),130layers.MaxPooling2D(2, 2),131layers.Conv2D(32, (3, 3), activation="relu"),132layers.Dropout(0.5),133layers.MaxPooling2D(2, 2),134layers.Conv2D(64, (3, 3), 
activation="relu"),135layers.Dropout(0.5),136layers.MaxPooling2D(2, 2),137layers.Conv2D(64, (3, 3), activation="relu"),138layers.MaxPooling2D(2, 2),139layers.Conv2D(64, (3, 3), activation="relu"),140layers.MaxPooling2D(2, 2),141layers.Flatten(),142layers.Dropout(0.5),143layers.Dense(512, activation="relu"),144layers.Dense(1, activation="sigmoid"),145]146)147148149"""150## Implement Gradient Centralization151152We will now153subclass the `RMSProp` optimizer class modifying the154`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient155Centralization. On a high level the idea is that let us say we obtain our gradients156through back propagation for a Dense or Convolution layer we then compute the mean of the157column vectors of the weight matrix, and then remove the mean from each column vector.158159The experiments in [this paper](https://arxiv.org/abs/2004.01461) on various160applications, including general image classification, fine-grained image classification,161detection and segmentation and Person ReID demonstrate that GC can consistently improve162the performance of DNN learning.163164Also, for simplicity at the moment we are not implementing gradient cliiping functionality,165however this quite easy to implement.166167At the moment we are just creating a subclass for the `RMSProp` optimizer168however you could easily reproduce this for any other optimizer or on a custom169optimizer in the same way. We will be using this class in the later section when170we train a model with Gradient Centralization.171"""172173174class GCRMSprop(RMSprop):175def get_gradients(self, loss, params):176# We here just provide a modified get_gradients() function since we are177# trying to just compute the centralized gradients.178179grads = []180gradients = super().get_gradients()181for grad in gradients:182grad_len = len(grad.shape)183if grad_len > 1:184axis = list(range(grad_len - 1))185grad -= ops.mean(grad, axis=axis, keep_dims=True)186grads.append(grad)187188return grads189190191optimizer = GCRMSprop(learning_rate=1e-4)192193"""194## Training utilities195196We will also create a callback which allows us to easily measure the total training time197and the time taken for each epoch since we are interested in comparing the effect of198Gradient Centralization on the model we built above.199"""200201202class TimeHistory(keras.callbacks.Callback):203def on_train_begin(self, logs={}):204self.times = []205206def on_epoch_begin(self, batch, logs={}):207self.epoch_time_start = time()208209def on_epoch_end(self, batch, logs={}):210self.times.append(time() - self.epoch_time_start)211212213"""214## Train the model without GC215216We now train the model we built earlier without Gradient Centralization which we can217compare to the training performance of the model trained with Gradient Centralization.218"""219220time_callback_no_gc = TimeHistory()221model = make_model()222model.compile(223loss="binary_crossentropy",224optimizer=RMSprop(learning_rate=1e-4),225metrics=["accuracy"],226)227228model.summary()229230"""231We also save the history since we later want to compare our model trained with and not232trained with Gradient Centralization233"""234235history_no_gc = model.fit(236train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]237)238239"""240## Train the model with GC241242We will now train the same model, this time using Gradient Centralization,243notice our optimizer is the one using Gradient Centralization this time.244"""245246time_callback_gc = TimeHistory()247model = 

"""
## Training utilities

We will also create a callback which allows us to easily measure the total training time
and the time taken for each epoch, since we are interested in comparing the effect of
Gradient Centralization on the model we built above.
"""


class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time() - self.epoch_time_start)


"""
## Train the model without GC

We now train the model we built earlier without Gradient Centralization, which we can
compare to the training performance of the model trained with Gradient Centralization.
"""

time_callback_no_gc = TimeHistory()
model = make_model()
model.compile(
    loss="binary_crossentropy",
    optimizer=RMSprop(learning_rate=1e-4),
    metrics=["accuracy"],
)

model.summary()

"""
We also save the history, since we later want to compare the model trained with and
without Gradient Centralization.
"""

history_no_gc = model.fit(
    train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)

"""
## Train the model with GC

We will now train the same model, this time using Gradient Centralization. Notice that
our optimizer is the one using Gradient Centralization this time.
"""

time_callback_gc = TimeHistory()
model = make_model()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])

model.summary()

history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])

"""
## Comparing performance
"""

print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)}")

print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)}")

"""
Readers are encouraged to try out Gradient Centralization on different datasets from
different domains and experiment with its effect. You are strongly advised to check out
the [original paper](https://arxiv.org/abs/2004.01461) as well; the authors present
several studies on Gradient Centralization showing how it can improve general
performance, generalization, and training time, making training more efficient.

Many thanks to [Ali Mustufa Shaikh](https://github.com/ialimustufa) for reviewing this
implementation.
"""