Class Attention Image Transformers with LayerScale
Author: Sayak Paul
Date created: 2022/09/19
Last modified: 2022/11/21
Description: Implementing an image transformer equipped with Class Attention and LayerScale.
Introduction
In this tutorial, we implement the CaiT (Class-Attention in Image Transformers) model proposed in Going deeper with Image Transformers by Touvron et al. Depth scaling, i.e., increasing the model depth to obtain better performance and generalization, has been quite successful for convolutional neural networks (Tan et al., Dollár et al., for example). But applying the same model scaling principles to Vision Transformers (Dosovitskiy et al.) doesn't translate equally well -- their performance saturates quickly as depth increases. Note that one assumption here is that the underlying pre-training dataset is always kept fixed when performing model scaling.
In the CaiT paper, the authors investigate this phenomenon and propose modifications to the vanilla ViT (Vision Transformers) architecture to mitigate this problem.
The tutorial is structured like so:
Implementation of the individual blocks of CaiT
Collating all the blocks to create the CaiT model
Loading a pre-trained CaiT model
Obtaining prediction results
Visualization of the different attention layers of CaiT
The readers are assumed to be familiar with Vision Transformers already. Here is an implementation of Vision Transformers in Keras: Image classification with Vision Transformer.
Imports
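This tutorial assumes a TensorFlow / tf.keras backend. The imports below should cover everything used in the code sketches that follow; treat them as a reasonable guess at what the notebook needs.

```python
import io
from urllib.request import urlopen

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras
from tensorflow.keras import layers
```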
The LayerScale layer
We begin by implementing a LayerScale layer which is one of the two modifications proposed in the CaiT paper.
When the depth of ViT models is increased, they run into optimization instability and eventually fail to converge. The residual connections within each Transformer block introduce an information bottleneck. As the depth grows, this bottleneck can quickly blow up and derail the optimization pathway of the underlying model.
The following equations denote where residual connections are added within a Transformer block:

x_l' = x_l + SA(eta(x_l))
x_{l+1} = x_l' + FFN(eta(x_l'))

where SA stands for self-attention, FFN stands for feed-forward network, and eta denotes the LayerNorm operator (Ba et al.).
LayerScale is formally implemented like so:

x_l' = x_l + diag(lambda_{l,1}, ..., lambda_{l,d}) * SA(eta(x_l))
x_{l+1} = x_l' + diag(lambda'_{l,1}, ..., lambda'_{l,d}) * FFN(eta(x_l'))

where the lambdas are learnable parameters initialized with a very small value ({0.1, 1e-5, 1e-6}, depending on the model depth), and diag represents a diagonal matrix.
Intuitively, LayerScale helps control the contribution of the residual branches. The learnable parameters of LayerScale are initialized to a small value to let the branches act like identity functions and then let them figure out the degrees of interactions during the training. The diagonal matrix additionally helps control the contributions of the individual dimensions of the residual inputs as it is applied on a per-channel basis.
The practical implementation of LayerScale is simpler than it might sound.
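Here is a minimal sketch of how such a layer could look in Keras. Since the diagonal matrix only scales each channel independently, it can be stored as a single learnable vector (called gamma below); the argument names are assumptions made for this sketch.

```python
class LayerScale(layers.Layer):
    """Scales the residual branch per channel with learnable weights."""

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # One learnable scale per channel, initialized to a small constant.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(projection_dim,),
            initializer=keras.initializers.Constant(init_values),
            trainable=True,
        )

    def call(self, x):
        # Element-wise scaling is equivalent to multiplying by diag(gamma).
        return x * self.gamma
```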
Stochastic depth layer
Since its introduction (Huang et al.), Stochastic Depth has become a favorite component in almost all modern neural network architectures. CaiT is no exception. Discussing Stochastic Depth is out of scope for this notebook. You can refer to this resource in case you need a refresher.
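For completeness, here is a common Keras-style sketch of Stochastic Depth: during training, the whole residual branch is dropped for a random subset of samples, and the surviving activations are rescaled accordingly.

```python
class StochasticDepth(layers.Layer):
    """Randomly drops the residual branch per sample during training."""

    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=False):
        if training and self.drop_prob > 0.0:
            keep_prob = 1.0 - self.drop_prob
            # One Bernoulli draw per sample; broadcast over all other axes.
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
```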
Class attention
The vanilla ViT uses self-attention (SA) layers for modelling how the image patches and the learnable CLS token interact with each other. The CaiT authors propose to decouple the attention layers responsible for attending to the image patches and the CLS tokens.
When using ViTs for any discriminative tasks (classification, for example), we usually take the representations belonging to the CLS token and then pass them to the task-specific heads. This is as opposed to using something like global average pooling as is typically done in convolutional neural networks.
The interactions between the CLS token and other image patches are processed uniformly through self-attention layers. As the CaiT authors point out, this setup has an entangled effect. On one hand, the self-attention layers are responsible for modelling the image patches. On the other hand, they're also responsible for summarizing the modelled information via the CLS token so that it's useful for the learning objective.
To help disentangle these two things, the authors propose to:
Introduce the CLS token at a later stage in the network.
Model the interaction between the CLS token and the representations related to the image patches through a separate set of attention layers. The authors call this Class Attention (CA).
The figure below (taken from the original paper) depicts this idea:

This is achieved by treating the CLS token embeddings as the queries in the CA layers. The CLS token embeddings and the image patch embeddings are fed as the keys as well as the values.
Note that "embeddings" and "representations" have been used interchangeably here.
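A hedged sketch of a Class Attention layer in Keras follows. The key difference from self-attention is that the query is computed from the CLS token alone, while the keys and values are computed from all tokens. The constructor arguments (projection_dim, num_heads, dropout_rate) and the decision to return the attention scores alongside the output are assumptions made for this sketch.

```python
class ClassAttention(layers.Layer):
    """Class Attention: only the CLS token acts as the query."""

    def __init__(self, projection_dim, num_heads, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        head_dim = projection_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = layers.Dense(projection_dim)
        self.k = layers.Dense(projection_dim)
        self.v = layers.Dense(projection_dim)
        self.attn_drop = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(projection_dim)
        self.proj_drop = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        batch_size = tf.shape(x)[0]
        num_tokens = tf.shape(x)[1]

        # Queries come from the CLS token only (assumed to be the first token).
        q = self.q(x[:, 0:1])
        q = tf.transpose(
            tf.reshape(q, (batch_size, 1, self.num_heads, -1)), perm=[0, 2, 1, 3]
        )  # (batch, heads, 1, head_dim)

        # Keys and values come from all tokens (CLS + image patches).
        k = tf.transpose(
            tf.reshape(self.k(x), (batch_size, num_tokens, self.num_heads, -1)),
            perm=[0, 2, 3, 1],
        )  # (batch, heads, head_dim, tokens)
        v = tf.transpose(
            tf.reshape(self.v(x), (batch_size, num_tokens, self.num_heads, -1)),
            perm=[0, 2, 1, 3],
        )  # (batch, heads, tokens, head_dim)

        # Scaled dot-product attention restricted to the CLS query.
        attn = tf.nn.softmax(tf.matmul(q * self.scale, k), axis=-1)
        attn = self.attn_drop(attn, training=training)

        # Weighted sum of the values, followed by the output projection.
        x_cls = tf.matmul(attn, v)  # (batch, heads, 1, head_dim)
        x_cls = tf.reshape(
            tf.transpose(x_cls, perm=[0, 2, 1, 3]), (batch_size, 1, -1)
        )
        x_cls = self.proj_drop(self.proj(x_cls), training=training)

        # Also return the attention scores for later visualization.
        return x_cls, attn
```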
Talking Head Attention
The CaiT authors use Talking Head attention (Shazeer et al.) instead of the vanilla scaled dot-product multi-head attention used in the original Transformer paper (Vaswani et al.). Talking Head attention introduces two linear projections, applied across the attention heads before and after the softmax operation, to obtain better results.
For a more rigorous treatment of the Talking Head attention and the vanilla attention mechanisms, please refer to their respective papers (linked above).
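As a rough illustration, here is one way Talking Head attention could be written in Keras: a standard multi-head self-attention computation with two extra Dense layers that mix information across the head axis, one before and one after the softmax. The argument names mirror the Class Attention sketch above and should be treated as assumptions.

```python
class TalkingHeadAttention(layers.Layer):
    """Multi-head self-attention with projections across the head axis."""

    def __init__(self, projection_dim, num_heads, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        head_dim = projection_dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = layers.Dense(projection_dim * 3)
        self.attn_drop = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(projection_dim)
        self.proj_drop = layers.Dropout(dropout_rate)

        # The two extra projections that let the heads "talk" to each other.
        self.proj_l = layers.Dense(num_heads)  # applied before the softmax
        self.proj_w = layers.Dense(num_heads)  # applied after the softmax

    def call(self, x, training=False):
        batch_size = tf.shape(x)[0]
        num_tokens = tf.shape(x)[1]

        # Project to queries, keys, and values, and split the heads.
        qkv = self.qkv(x)
        qkv = tf.reshape(qkv, (batch_size, num_tokens, 3, self.num_heads, -1))
        qkv = tf.transpose(qkv, perm=[2, 0, 3, 1, 4])  # (3, batch, heads, tokens, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Raw attention logits.
        attn = tf.matmul(q * self.scale, k, transpose_b=True)  # (batch, heads, tokens, tokens)

        # Talking heads: mix information across the head axis before ...
        attn = tf.transpose(attn, perm=[0, 2, 3, 1])  # move heads to the last axis
        attn = self.proj_l(attn)
        attn = tf.transpose(attn, perm=[0, 3, 1, 2])
        attn = tf.nn.softmax(attn, axis=-1)

        # ... and after the softmax.
        attn = tf.transpose(attn, perm=[0, 2, 3, 1])
        attn = self.proj_w(attn)
        attn = tf.transpose(attn, perm=[0, 3, 1, 2])
        attn = self.attn_drop(attn, training=training)

        # Aggregate the values and merge the heads back.
        x = tf.matmul(attn, v)  # (batch, heads, tokens, head_dim)
        x = tf.reshape(
            tf.transpose(x, perm=[0, 2, 1, 3]), (batch_size, num_tokens, -1)
        )
        x = self.proj_drop(self.proj(x), training=training)
        return x, attn
```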
Feed-forward Network
Next, we implement the feed-forward network which is one of the components within a Transformer block.
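A minimal sketch of such an FFN, written as a reusable helper for the functional blocks later (the GELU-activated two-layer MLP is the standard ViT choice; the argument names are assumptions):

```python
def mlp(x, dropout_rate, hidden_units):
    """FFN for a Transformer block: Dense -> GELU -> Dropout -> Dense -> Dropout."""
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
```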
Other blocks
In the next two cells, we implement the remaining blocks as standalone functions:
LayerScaleBlockClassAttention(), which returns a keras.Model. It is a Transformer block equipped with Class Attention, LayerScale, and Stochastic Depth. It operates on the CLS embeddings and the image patch embeddings.
LayerScaleBlock(), which returns a keras.Model. It is also a Transformer block that operates only on the embeddings of the image patches. It is equipped with LayerScale and Stochastic Depth.
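Below are hedged sketches of both functions, assembled from the layer sketches defined earlier. The argument names and the exact wiring are assumptions guided by the descriptions above.

```python
def LayerScaleBlockClassAttention(
    projection_dim, num_heads, layer_norm_eps, init_values,
    mlp_units, dropout_rate, sd_prob, name=None,
):
    """CA Transformer block: updates the CLS token from the patch tokens."""
    x = keras.Input((None, projection_dim))
    x_cls = keras.Input((None, projection_dim))
    inputs = layers.Concatenate(axis=1)([x_cls, x])

    # Class attention: the CLS token (query) attends over CLS + patch tokens.
    x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(inputs)
    attn_output, attn_scores = ClassAttention(projection_dim, num_heads, dropout_rate)(x1)
    attn_output = LayerScale(init_values, projection_dim)(attn_output)
    attn_output = StochasticDepth(sd_prob)(attn_output)
    x2 = layers.Add()([x_cls, attn_output])

    # FFN applied to the updated CLS token only.
    x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
    x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
    x4 = LayerScale(init_values, projection_dim)(x4)
    x4 = StochasticDepth(sd_prob)(x4)
    outputs = layers.Add()([x2, x4])

    return keras.Model([x, x_cls], [outputs, attn_scores], name=name)


def LayerScaleBlock(
    projection_dim, num_heads, layer_norm_eps, init_values,
    mlp_units, dropout_rate, sd_prob, name=None,
):
    """SA Transformer block: models interactions between the patch tokens."""
    encoded_patches = keras.Input((None, projection_dim))

    # Self-attention (talking heads) over the patch tokens.
    x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    attn_output, attn_scores = TalkingHeadAttention(
        projection_dim, num_heads, dropout_rate
    )(x1)
    attn_output = LayerScale(init_values, projection_dim)(attn_output)
    attn_output = StochasticDepth(sd_prob)(attn_output)
    x2 = layers.Add()([encoded_patches, attn_output])

    # FFN.
    x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
    x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
    x4 = LayerScale(init_values, projection_dim)(x4)
    x4 = StochasticDepth(sd_prob)(x4)
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, [outputs, attn_scores], name=name)
```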
Given all these blocks, we are now ready to collate them into the final CaiT model.
Putting the pieces together: The CaiT model
Having the SA and CA layers segregated this way helps the model focus on the underlying objectives more concretely:
model dependencies between the image patches
summarize the information from the image patches in a CLS token that can be used for the task at hand
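Below is a compact, hedged sketch of how the blocks above could be collated into a CaiT-style model. The constructor arguments mirror the get_config() helper defined in the next section, and details such as returning a dictionary of attention scores are assumptions made for this sketch rather than the reference implementation.

```python
class CaiT(keras.Model):
    """A compact sketch of the CaiT forward pass built from the blocks above."""

    def __init__(
        self, image_size, patch_size, projection_dim, sa_ffn_layers,
        ca_ffn_layers, num_heads, mlp_ratio, layer_norm_eps, init_values,
        dropout_rate, sd_prob, num_classes, **kwargs,
    ):
        super().__init__(**kwargs)
        num_patches = (image_size // patch_size) ** 2
        mlp_units = [projection_dim * mlp_ratio, projection_dim]

        # Patchify the image and linearly project the patches.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    projection_dim, kernel_size=patch_size,
                    strides=patch_size, padding="valid",
                ),
                layers.Reshape((num_patches, projection_dim)),
            ]
        )

        # Learnable CLS token and positional embeddings (patches only).
        self.cls_token = self.add_weight(
            name="cls_token", shape=(1, 1, projection_dim),
            initializer="zeros", trainable=True,
        )
        self.pos_embed = self.add_weight(
            name="pos_embed", shape=(1, num_patches, projection_dim),
            initializer="zeros", trainable=True,
        )
        self.pos_drop = layers.Dropout(dropout_rate)

        # Stage 1 blocks: self-attention over the patch tokens only.
        self.sa_blocks = [
            LayerScaleBlock(
                projection_dim, num_heads, layer_norm_eps, init_values,
                mlp_units, dropout_rate, sd_prob,
            )
            for _ in range(sa_ffn_layers)
        ]
        # Stage 2 blocks: class attention folds the patch information
        # into the CLS token.
        self.ca_blocks = [
            LayerScaleBlockClassAttention(
                projection_dim, num_heads, layer_norm_eps, init_values,
                mlp_units, dropout_rate, sd_prob,
            )
            for _ in range(ca_ffn_layers)
        ]

        self.norm = layers.LayerNormalization(epsilon=layer_norm_eps)
        self.head = layers.Dense(num_classes)

    def call(self, x, training=False):
        attention_scores = {}

        # Patch embeddings + positional information.
        x = self.projection(x) + self.pos_embed
        x = self.pos_drop(x, training=training)

        # Stage 1: model dependencies between the image patches.
        for idx, block in enumerate(self.sa_blocks):
            x, attn = block(x)
            attention_scores[f"sa_block_{idx}"] = attn

        # Stage 2: summarize the patch information into the CLS token,
        # keeping the patch embeddings frozen.
        cls = tf.tile(self.cls_token, (tf.shape(x)[0], 1, 1))
        for idx, block in enumerate(self.ca_blocks):
            cls, attn = block([x, cls])
            attention_scores[f"ca_block_{idx}"] = attn

        # Classify from the CLS token.
        logits = self.head(self.norm(cls)[:, 0])
        return logits, attention_scores
```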
Now that we have defined the CaiT model, it's time to test it. We will start by defining a model configuration that will be passed to our CaiT class for initialization.
Defining Model Configuration
Most of the configuration variables should sound familiar to you if you already know the ViT architecture. Pay particular attention to sa_ffn_layers and ca_ffn_layers, which control the number of SA Transformer blocks and CA Transformer blocks. You can easily amend this get_config() method to instantiate a CaiT model for your own dataset.
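A sketch of such a configuration helper is shown below. The defaults roughly follow the CaiT-XXS24 (224x224) variant (24 SA blocks, 2 CA blocks, a projection dimension of 192, and 4 heads), but the exact names and values should be treated as illustrative.

```python
def get_config(
    image_size: int = 224,
    patch_size: int = 16,
    projection_dim: int = 192,
    sa_ffn_layers: int = 24,
    ca_ffn_layers: int = 2,
    num_heads: int = 4,
    mlp_ratio: int = 4,
    layer_norm_eps: float = 1e-6,
    init_values: float = 1e-5,
    dropout_rate: float = 0.0,
    sd_prob: float = 0.0,
    num_classes: int = 1000,
) -> dict:
    """Collects the hyperparameters of a CaiT variant into a dictionary."""
    return locals()
```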
Model Instantiation
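Assuming the CaiT sketch from the previous section, instantiation and a quick sanity check could look like this:

```python
config = get_config()
cait_xxs24_224 = CaiT(**config)

# A dummy forward pass builds the model and sanity-checks the output shape.
dummy_inputs = tf.ones((2, 224, 224, 3))
logits, _ = cait_xxs24_224(dummy_inputs)
print(logits.shape)  # (2, 1000)
```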
We can successfully perform inference with the model. But what about implementation correctness? There are many ways to verify it:
Obtain the performance of the model (given it's been populated with the pre-trained parameters) on the ImageNet-1k validation set (as the pretraining dataset was ImageNet-1k).
Fine-tune the model on a different dataset.
In order to verify that, we will load another instance of the same model that has been already populated with the pre-trained parameters. Please refer to this repository (developed by the author of this notebook) for more details. Additionally, the repository provides code to verify model performance on the ImageNet-1k validation set as well as fine-tuning.
Load a pretrained model
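The snippet below is only a placeholder for this step; the linked repository hosts the converted pre-trained models, and the exact path depends on where you download the artifact.

```python
# Replace the path with the location of the converted pre-trained model
# obtained from the repository linked above.
pretrained_model = keras.models.load_model("path/to/pretrained/cait_xxs24_224")
```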
Inference utilities
In the next couple of cells, we develop preprocessing utilities needed to run inference with the pretrained model.
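Here is a minimal preprocessing sketch: resize, center-crop to the model's input resolution, and normalize with ImageNet-1k channel statistics. The normalization scheme is an assumption about the pretraining preprocessing.

```python
RESOLUTION = 224

crop_layer = layers.CenterCrop(RESOLUTION, RESOLUTION)
norm_layer = layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)


def preprocess_image(image, size=RESOLUTION):
    """Resizes, crops, and normalizes an image into a (1, size, size, 3) batch."""
    image = np.array(image)
    image = tf.expand_dims(image, 0)
    # Resize slightly larger, then take a center crop at the target resolution.
    image = tf.image.resize(image, (size + 32, size + 32), method="bicubic")
    image = crop_layer(image)
    return norm_layer(image).numpy()
```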
Now, we retrieve the ImageNet-1k labels, since the model we're loading was pretrained on the ImageNet-1k dataset.
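The WordNet lemma file below lists the 1,000 ImageNet-1k class names and is commonly used in keras.io examples; treat the exact URL as an assumption.

```python
imagenet_labels_path = keras.utils.get_file(
    origin="https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)

with open(imagenet_labels_path) as f:
    imagenet_labels = [line.strip() for line in f.readlines()]
```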
Load an Image
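A small helper to fetch an image over HTTP and preprocess it; the URL below is a placeholder, so point it at any test image you like.

```python
def load_image_from_url(url, size=RESOLUTION):
    """Loads an image from a URL and returns (PIL image, preprocessed batch)."""
    image_bytes = io.BytesIO(urlopen(url).read())
    image = Image.open(image_bytes).convert("RGB")
    return image, preprocess_image(image, size)


image_url = "https://example.com/some_test_image.jpg"  # placeholder URL
image, preprocessed_image = load_image_from_url(image_url)

plt.imshow(image)
plt.axis("off")
plt.show()
```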
Obtain Predictions
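Running the pretrained model and mapping the top logit back to a label could look like the following; the assumption is that the model returns the logits first, alongside the attention scores.

```python
outputs = pretrained_model.predict(preprocessed_image)
logits = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

predicted_label = imagenet_labels[int(np.argmax(logits))]
print(predicted_label)
```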
Now that we have obtained the predictions (which appear to be as expected), we can further extend our investigation. Following the CaiT authors, we can investigate the attention scores from the attention layers. This helps us to get deeper insights into the modifications introduced in the CaiT paper.
Visualizing the Attention Layers
We start by inspecting the shape of the attention weights returned by a Class Attention layer.
The shape tells us that we have attention weights for each of the individual attention heads. They quantify how the CLS token is related to itself and the rest of the image patches.
Next, we write a utility to:
Visualize what the individual attention heads in the Class Attention layers are focusing on. This helps us to get an idea of how the spatial-class relationship is induced in the CaiT model.
Obtain a saliency map from the first Class Attention layer that helps to understand how the CA layer aggregates information from the region(s) of interest in the images.
This utility is adapted from Figures 6 and 7 of the original CaiT paper. This is also a part of this notebook (developed by the author of this tutorial).
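A hedged sketch of the visualization part of this utility is given below. It assumes the attention scores of a CA layer arrive as a NumPy array of shape (1, num_heads, 1, num_tokens), with the CLS token first.

```python
def visualize_ca_heads(attn_scores, image_size=224, patch_size=16):
    """Plots what each head of a Class Attention layer attends to."""
    num_heads = attn_scores.shape[1]
    grid_size = image_size // patch_size

    fig, axes = plt.subplots(nrows=1, ncols=num_heads, figsize=(13, 13))
    for head_idx in range(num_heads):
        # Skip the CLS-to-CLS score and fold the patch scores into a 2D grid.
        scores = attn_scores[0, head_idx, 0, 1:].reshape(grid_size, grid_size)
        axes[head_idx].imshow(scores, cmap="inferno")
        axes[head_idx].set_title(f"Head {head_idx}")
        axes[head_idx].axis("off")
    plt.show()
```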
In the first CA layer, we notice that the model is focusing solely on the region of interest.
Whereas in the second CA layer, the model is trying to focus more on the context that contains discriminative signals.
Finally, we obtain the saliency map for the given image.
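One simple way to build such a saliency map (under the same shape assumption as above) is to average the CLS-to-patch scores of the first CA layer over the heads, normalize them, and upsample the grid to the image resolution.

```python
def saliency_map(attn_scores, image_size=224, patch_size=16):
    """Turns CA attention scores of shape (1, heads, 1, tokens) into a heatmap."""
    # Drop the CLS-to-CLS entry and average the patch scores over the heads.
    cls_to_patches = attn_scores[0, :, 0, 1:]  # (heads, num_patches)
    saliency = cls_to_patches.mean(axis=0)     # (num_patches,)

    grid_size = image_size // patch_size
    saliency = saliency.reshape(grid_size, grid_size)

    # Normalize to [0, 1] and upsample to the image resolution for overlaying.
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min())
    saliency = tf.image.resize(saliency[..., None], (image_size, image_size))
    return saliency.numpy()[..., 0]


# Example overlay on the original (resized) image, where `ca_scores` is the
# first CA layer's attention array:
# plt.imshow(image.resize((224, 224)))
# plt.imshow(saliency_map(ca_scores), alpha=0.6, cmap="cividis")
# plt.axis("off"); plt.show()
```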
Conclusion
In this notebook, we implemented the CaiT model. It shows how to mitigate the issues in ViTs when trying to scale their depth while keeping the pretraining dataset fixed. I hope the additional visualizations provided in the notebook spark excitement in the community and people develop interesting methods to probe what models like ViT learn.
Acknowledgement
Thanks to the ML Developer Programs team at Google for providing Google Cloud Platform support.