Distilling Vision Transformers
Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2022/04/08
Description: Distillation of Vision Transformers through attention.
Introduction
In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that to perform on par with Convolutional Neural Networks (CNNs), ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly due to the lack of inductive biases in the ViT architecture -- unlike CNNs, they don't have layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.
Many groups have proposed different ways to deal with the data-intensiveness of ViT training. One such way was shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique that is specific to transformer-based vision models. DeiT is among the first works to show that it's possible to train ViTs well without using larger datasets.
In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and write a custom training loop to implement the distillation recipe.
To run the example, you'll need TensorFlow Addons, which you can install with pip install -U tensorflow-addons.
To comfortably navigate through this example, you should know how a ViT and knowledge distillation work. The following are good resources in case you need a refresher:
Imports
Constants
You probably noticed that DROPOUT_RATE has been set to 0.0. Dropout is included in the implementation to keep it complete. For smaller models (like the one used in this example), you don't need it, but for bigger models, using dropout helps.
Load the tf_flowers dataset and prepare preprocessing utilities
The authors use an array of different augmentation techniques, including MixUp (Zhang et al.), RandAugment (Cubuk et al.), and so on. However, to keep the example simple to work through, we'll discard them.
Implementing the DeiT variants of ViT
Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.
First, we'll implement a layer for Stochastic Depth (Huang et al.) which is used in DeiT for regularization.
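As a rough, framework-free illustration of the idea (not the TensorFlow layer this example relies on), stochastic depth drops the residual branch for a random subset of samples during training and rescales the surviving ones so that the expected output matches inference. The function name `stochastic_depth`, its arguments, and the list-of-lists batch layout are hypothetical choices made for this sketch.

```python
import random


def stochastic_depth(branch_outputs, drop_prob, training, rng=None):
    """Per-sample stochastic depth over a batch of residual-branch outputs.

    branch_outputs: list of per-sample feature lists (hypothetical layout).
    drop_prob: probability of dropping a sample's residual branch.
    """
    if not training or drop_prob == 0.0:
        # At inference time the branch is always kept, unscaled.
        return [list(sample) for sample in branch_outputs]

    rng = rng or random.Random()
    keep_prob = 1.0 - drop_prob
    result = []
    for sample in branch_outputs:
        if rng.random() < keep_prob:
            # Kept: scale by 1 / keep_prob so the expectation matches inference.
            result.append([value / keep_prob for value in sample])
        else:
            # Dropped: only the skip connection contributes for this sample.
            result.append([0.0 for _ in sample])
    return result


batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(stochastic_depth(batch, drop_prob=0.5, training=True, rng=random.Random(0)))
print(stochastic_depth(batch, drop_prob=0.5, training=False))
```

In a Transformer block, the returned values would be added to the block's input, so a dropped sample simply passes through unchanged.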
Now, we'll implement the MLP and Transformer blocks.
We'll now implement a ViTClassifier class building on top of the components we just developed. Here we'll follow the original pooling strategy used in the ViT paper: prepend a class token to the patch tokens and use the feature representation corresponding to it for classification.
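The pooling strategy can be pictured with a small plain-Python sketch. It is only an illustration under stated assumptions: the helper names `prepend_class_token` and `pool_class_token` are hypothetical, the embedding sizes are toy values, and the Transformer blocks that would sit between the two steps are omitted.

```python
def prepend_class_token(patch_tokens, cls_token):
    """Return the token sequence [cls, patch_1, ..., patch_N] for one image."""
    return [list(cls_token)] + [list(token) for token in patch_tokens]


def pool_class_token(encoded_tokens):
    """Class-token pooling: keep only the representation at position 0."""
    return encoded_tokens[0]


# Toy values: 4 patches with 3-dimensional embeddings (hypothetical sizes).
patches = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]
cls = [0.0, 0.0, 0.0]  # a learnable vector in a real model

sequence = prepend_class_token(patches, cls)
# A real model would run `sequence` through the Transformer blocks here.
features = pool_class_token(sequence)
print(len(sequence), features)
```

The classification head then sees only `features`, so the class token has to gather information from the patches through self-attention.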
This class can be used standalone as ViT and is end-to-end trainable. Just remove the distilled phrase in MODEL_TYPE and it should work with vit_tiny = ViTClassifier(). Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken from the DeiT paper):
Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared to the true labels, and the logits corresponding to the distillation token are compared to the teacher's predictions.
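The two-token setup described above can be sketched as follows, in plain Python rather than TensorFlow. Several details are assumptions made for illustration: the token order (class token first, distillation token second), the helper names `linear_head` and `deit_outputs`, and the choice to average the two heads' logits at inference time.

```python
def linear_head(features, weights, bias):
    """Apply a dense layer: one output logit per row of `weights`."""
    return [
        sum(w * x for w, x in zip(row, features)) + b
        for row, b in zip(weights, bias)
    ]


def deit_outputs(encoded_tokens, cls_head, dist_head, training):
    """Sketch of a DeiT forward pass after the Transformer blocks.

    encoded_tokens: [cls, dist, patch_1, ..., patch_N] (assumed token order).
    cls_head / dist_head: (weights, bias) tuples for the two classifiers.
    """
    cls_logits = linear_head(encoded_tokens[0], *cls_head)
    dist_logits = linear_head(encoded_tokens[1], *dist_head)
    if training:
        # Each set of logits gets its own loss term during distillation.
        return cls_logits, dist_logits
    # Assumed inference behavior: average the two heads' logits.
    return [(c + d) / 2 for c, d in zip(cls_logits, dist_logits)]


tokens = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # cls, dist, one patch
cls_head = ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
dist_head = ([[0.0, 2.0], [2.0, 0.0]], [0.0, 0.0])

print(deit_outputs(tokens, cls_head, dist_head, training=True))
print(deit_outputs(tokens, cls_head, dist_head, training=False))
```

Keeping two separate heads lets the class token learn from the true labels while the distillation token learns from the teacher.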
Let's verify if the ViTDistilled class can be initialized and called as expected.
Implementing the trainer
Unlike standard knowledge distillation (Hinton et al.), which uses a temperature-scaled softmax together with KL divergence, the DeiT authors use the following loss function:
Here,
CE is cross-entropy
psi is the softmax function
Z_s denotes student predictions
y denotes true labels
y_t denotes teacher predictions
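Using the symbols defined above, and assuming the hard-label variant from the DeiT paper, the loss averages two cross-entropy terms: L = 1/2 * CE(psi(Z_s), y) + 1/2 * CE(psi(Z_s), y_t), where y_t is taken to be the teacher's argmax decision. The sketch below is a plain-Python illustration for a single sample, not the trainer used in this example. The function names are hypothetical, and feeding the class-token logits to the first term and the distillation-token logits to the second is an assumption that follows the description of the two tokens earlier.

```python
import math


def log_softmax(logits):
    """Numerically stable log of the softmax (log psi in the text)."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return [z - log_total for z in logits]


def cross_entropy(logits, label):
    """CE(psi(logits), label) for a single integer class label."""
    return -log_softmax(logits)[label]


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, label):
    """Average of the true-label loss and the hard-label teacher loss."""
    # y_t: the teacher's hard decision (argmax over its logits).
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    student_loss = cross_entropy(cls_logits, label)
    distill_loss = cross_entropy(dist_logits, teacher_label)
    return 0.5 * student_loss + 0.5 * distill_loss


loss = hard_distillation_loss(
    cls_logits=[2.0, 0.5, -1.0],
    dist_logits=[1.5, 0.2, -0.3],
    teacher_logits=[3.0, 1.0, 0.0],
    label=0,
)
print(loss)
```

Because the teacher contributes only its argmax, no temperature or KL term appears, which is the contrast with standard knowledge distillation drawn above.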
Load the teacher model
This model is based on the BiT family of ResNets (Kolesnikov et al.) fine-tuned on the tf_flowers dataset. You can refer to this notebook to see how the training was performed. The teacher model has about 212 million parameters, which is about 40x more than the student.
Training through distillation
If we had trained the same model (the ViTClassifier) from scratch with the exact same hyperparameters, the model would have reached about 59% accuracy. You can adapt the following code to reproduce this result:
Notes
Through the use of distillation, we're effectively transferring the inductive biases of a CNN-based teacher model.
Interestingly enough, as shown in the paper, this distillation strategy works better with a CNN as the teacher model than with a Transformer.
The use of regularization to train DeiT models is very important.
ViT models are initialized with a combination of different initializers including truncated normal, random normal, Glorot uniform, etc. If you're looking for end-to-end reproduction of the original results, don't forget to initialize the ViTs well.
If you want to explore the pre-trained DeiT models in TensorFlow and Keras with code for fine-tuning, check out these models on TF-Hub.
Acknowledgements
Ross Wightman for keeping timm updated with readable implementations. I referred to the implementations of ViT and DeiT a lot while implementing them in TensorFlow.
Aritra Roy Gosthipaty, who implemented some portions of the ViTClassifier in another project.
Google Developers Experts program for supporting me with GCP credits, which were used to run experiments for this example.
Example available on HuggingFace: