MobileViT: A mobile-friendly Transformer-based model for image classification
Author: Sayak Paul
Date created: 2021/10/20
Last modified: 2024/02/11
Description: MobileViT for image classification with combined benefits of convolutions and Transformers.
Introduction
In this example, we implement the MobileViT architecture (Mehta et al.), which combines the benefits of Transformers (Vaswani et al.) and convolutions. With Transformers, we can capture long-range dependencies that result in global representations. With convolutions, we can capture spatial relationships that model locality.
Besides combining the properties of Transformers and convolutions, the authors introduce MobileViT as a general-purpose mobile-friendly backbone for different image recognition tasks. Their findings suggest that, performance-wise, MobileViT is better than other models with the same or higher complexity (MobileNetV3, for example), while being efficient on mobile devices.
Note: This example should be run with TensorFlow 2.13 and higher.
Imports
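The imports cell is not reproduced here; a minimal set consistent with this example (assuming the TensorFlow backend for Keras 3 and the tensorflow-datasets package) would look like this:

```python
import os

# This example assumes the TensorFlow backend for Keras 3.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers

import tensorflow_datasets as tfds

tfds.disable_progress_bar()
```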
Hyperparameters
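The values below are a representative sketch: the patch size and expansion factor follow the paper's XXS configuration, while the training-related values are assumptions you may want to tune.

```python
# Representative hyperparameters; treat the training-related values as tunable defaults.
patch_size = 4  # Each 2x2 patch is flattened into 4 values for the Transformer.
image_size = 256
expansion_factor = 2  # Expansion factor for the MobileNetV2-style blocks.

learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30

num_classes = 5  # tf_flowers has five classes.
batch_size = 64
```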
MobileViT utilities
The MobileViT architecture is comprised of the following blocks:
Strided 3x3 convolutions that process the input image.
MobileNetV2-style inverted residual blocks for downsampling the resolution of the intermediate feature maps.
MobileViT blocks that combine the benefits of Transformers and convolutions. This block is presented in the figure below (taken from the original paper):
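Before detailing the MobileViT block, here is a minimal sketch of the first two kinds of blocks listed above. The function names conv_block and inverted_residual_block are illustrative, and the padding details are simplified.

```python
def conv_block(x, filters=16, kernel_size=3, strides=2):
    # Convolution followed by the swish activation, used for the stem and fusion layers.
    return layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        activation=keras.activations.swish,
        padding="same",
    )(x)


def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    # MobileNetV2-style block: 1x1 expansion, 3x3 depthwise convolution,
    # 1x1 projection, with a residual connection when the shapes match.
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    m = layers.DepthwiseConv2D(3, strides=strides, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if strides == 1 and x.shape[-1] == output_channels:
        return layers.Add()([m, x])
    return m
```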
More on the MobileViT block:
First, the feature representations (A) go through convolution blocks that capture local relationships. The expected shape of a single entry here would be (h, w, num_channels).
Then they get unfolded into another vector with shape (p, n, num_channels), where p is the area of a small patch, and n is (h * w) / p. So, we end up with n non-overlapping patches.
This unfolded vector is then passed through a Transformer block that captures global relationships between the patches.
The output vector (B) is again folded into a vector of shape (h, w, num_channels) resembling a feature map coming out of convolutions.
Vectors A and B are then passed through two more convolutional layers to fuse the local and global representations. Notice how the spatial resolution of the final vector remains unchanged at this point. The authors also present an explanation of how the MobileViT block resembles a convolution block of a CNN. For more details, please refer to the original paper.
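Putting those steps together, a sketch of the block could look as follows. It reuses the conv_block sketch above and the patch_size hyperparameter; the helper names (mlp, transformer_block, mobilevit_block) are illustrative.

```python
def mlp(x, hidden_units, dropout_rate):
    # Feed-forward part of the Transformer block.
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    # Pre-norm Transformer encoder applied to the unfolded patches.
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, x])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=[projection_dim * 2, projection_dim], dropout_rate=0.1)
        x = layers.Add()([x3, x2])
    return x


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local feature extraction with convolutions (representation A).
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into (p, n, num_channels) patches and model global relationships.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(non_overlapping_patches, num_blocks, projection_dim)

    # Fold back into an (h, w, num_channels) feature map (representation B).
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Fuse the local and global representations with two more convolutions.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])
    return conv_block(local_global_features, filters=projection_dim, strides=strides)
```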
Next, we combine these blocks together and implement the MobileViT architecture (XXS variant). The following figure (taken from the original paper) presents a schematic representation of the architecture:
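A condensed sketch of the full model, following an XXS-style channel configuration; the exact widths and block counts here are assumptions based on the paper and may differ from the original cell.

```python
def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem followed by an MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(x, expanded_channels=16 * expansion_factor, output_channels=16)

    # Downsampling with MV2 blocks.
    x = inverted_residual_block(x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2)
    x = inverted_residual_block(x, expanded_channels=24 * expansion_factor, output_channels=24)
    x = inverted_residual_block(x, expanded_channels=24 * expansion_factor, output_channels=24)

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2)
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2)
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2)
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)
```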
Dataset preparation
We will be using the tf_flowers dataset to demonstrate the model. Unlike other Transformer-based architectures, MobileViT uses a simple augmentation pipeline primarily because it has the properties of a CNN.
The authors use a multi-scale data sampler to help the model learn representations of varied scales. In this example, we discard this part.
Load and prepare the dataset
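A sketch of the loading and preprocessing pipeline. The resize_bigger value and the 90/10 split are assumptions; tf_flowers only ships a train split, so the validation data is carved out of it.

```python
auto = tf.data.AUTOTUNE
resize_bigger = 280  # Resize to a larger resolution before taking random crops.


def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution, then take a random crop and flip.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)


train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
```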
Train a MobileViT (XXS) model
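A minimal training-loop sketch; the checkpoint path and callback settings are illustrative choices, not prescribed by the original cell.

```python
def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor),
        metrics=["accuracy"],
    )

    # Keep the best weights seen during training, judged on validation accuracy.
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    mobilevit_xxs.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )
    mobilevit_xxs.load_weights(checkpoint_filepath)
    return mobilevit_xxs


mobilevit_xxs = run_experiment()
```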
Results and TFLite conversion
With about one million parameters, getting to ~85% top-1 accuracy on 256x256 resolution is a strong result. This MobileViT model is fully compatible with TensorFlow Lite (TFLite) and can be converted with the following code:
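The conversion cell is not reproduced here; a minimal sketch, assuming the Keras 3 model is exported as a TF SavedModel and post-training dynamic-range quantization is applied, looks like this:

```python
# Export the trained Keras model as a TF SavedModel.
mobilevit_xxs.export("mobilevit_xxs")

# Convert with post-training dynamic-range quantization.
converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Regular TFLite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Fall back to TF ops where needed.
]
tflite_model = converter.convert()

with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)
```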
To learn more about different quantization recipes available in TFLite and running inference with TFLite models, check out this official resource.
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.