Image classification with modern MLP models
Author: Khalid Salama
Date created: 2021/05/30
Last modified: 2023/08/03
Description: Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification.
Introduction
This example implements three modern attention-free, multi-layer perceptron (MLP) based models for image classification, demonstrated on the CIFAR-100 dataset:
The MLP-Mixer model, by Ilya Tolstikhin et al., based on two types of MLPs.
The FNet model, by James Lee-Thorp et al., based on an unparameterized Fourier transform.
The gMLP model, by Hanxiao Liu et al., based on MLP with gating.
The purpose of the example is not to compare these models, since they might perform differently on different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their main building blocks.
Setup
Prepare the data
Configure the hyperparameters
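A representative set of hyperparameters is sketched below; the specific values are assumptions for illustration, not the tuned originals.

```python
weight_decay = 0.0001
batch_size = 128
num_epochs = 10            # kept small for illustration; train longer for better results
dropout_rate = 0.2
image_size = 64            # images are resized to this size before patching
patch_size = 8             # size of the square patches extracted from the images
num_patches = (image_size // patch_size) ** 2  # patches per image
embedding_dim = 256        # dimension of the per-patch embeddings
num_blocks = 4             # number of processing blocks in each model

print(f"Patches per image: {num_patches}")
```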
Build a classification model
We implement a method that builds a classifier given the processing blocks.
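A self-contained sketch of such a builder is shown below. For brevity it stands in a strided convolution for the patch extraction and embedding steps, and the function name and sizes are assumptions.

```python
import keras
from keras import layers

num_classes = 100
input_shape = (32, 32, 3)
image_size = 64
patch_size = 8
num_patches = (image_size // patch_size) ** 2
embedding_dim = 64

def build_classifier(blocks):
    inputs = layers.Input(shape=input_shape)
    # Resize, then embed non-overlapping patches with a strided convolution.
    x = layers.Resizing(image_size, image_size)(inputs)
    x = layers.Conv2D(embedding_dim, kernel_size=patch_size, strides=patch_size)(x)
    x = layers.Reshape((num_patches, embedding_dim))(x)
    # Apply the given processing blocks (MLP-Mixer, FNet, or gMLP layers).
    x = blocks(x)
    # Pool over patches and classify (logits output).
    representation = layers.GlobalAveragePooling1D()(x)
    logits = layers.Dense(num_classes)(representation)
    return keras.Model(inputs=inputs, outputs=logits)

# Example: a trivial "blocks" stack, just to show the interface.
model = build_classifier(keras.Sequential([layers.Dense(embedding_dim)]))
```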
Define an experiment
We implement a utility function to compile, train, and evaluate a given model.
Use data augmentation
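A minimal augmentation pipeline can be built from Keras preprocessing layers; the exact set of augmentations used here is an assumption.

```python
import keras
from keras import layers

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),                      # mirror left-right
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),  # random zoom
    ],
    name="data_augmentation",
)
```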
Implement patch extraction as a layer
Implement position embedding as a layer
The MLP-Mixer model
The MLP-Mixer is an architecture based exclusively on multi-layer perceptrons (MLPs), which contains two types of MLP layers:
One applied independently to each image patch, which mixes the per-location features.
One applied across patches (independently per channel), which mixes spatial information.
This is similar to a depthwise separable convolution based model such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization instead of batch normalization.
Implement the MLP-Mixer module
Build, train, and evaluate the MLP-Mixer model
Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.
The MLP-Mixer model tends to have far fewer parameters than convolutional and Transformer-based models, which reduces training and serving computational cost.
As mentioned in the MLP-Mixer paper, when pre-trained on large datasets, or with modern regularization schemes, the MLP-Mixer attains competitive scores to state-of-the-art models. You can obtain better results by increasing the embedding dimensions, increasing the number of mixer blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes.
The FNet model
The FNet uses a block similar to the Transformer block, but replaces the self-attention layer with a parameter-free 2D Fourier transformation layer:
One 1D Fourier Transform is applied along the patches.
One 1D Fourier Transform is applied along the channels.
Implement the FNet module
Build, train, and evaluate the FNet model
Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.
As shown in the FNet paper, better results can be achieved by increasing the embedding dimensions, increasing the number of FNet blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes. The FNet scales very efficiently to long inputs, runs much faster than attention-based Transformer models, and produces competitive accuracy results.
The gMLP model
The gMLP is an MLP architecture that features a Spatial Gating Unit (SGU). The SGU enables cross-patch interactions along the spatial (patch) dimension by:
Transforming the input spatially, by applying a linear projection across patches.
Applying element-wise multiplication of the input and its spatial transformation.
Implement the gMLP module
Build, train, and evaluate the gMLP model
Note that training the model with the current settings on a V100 GPU takes around 9 seconds per epoch.
As shown in the gMLP paper, better results can be achieved by increasing the embedding dimensions, increasing the number of gMLP blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes. Note that the paper used advanced regularization strategies, such as MixUp and CutMix, as well as AutoAugment.