Masked image modeling with Autoencoders
Author: Aritra Roy Gosthipaty, Sayak Paul
Date created: 2021/12/20
Last modified: 2021/12/21
Description: Implementing Masked Autoencoders for self-supervised pretraining.
Introduction
In deep learning, models with growing capacity and capability can easily overfit even on large datasets such as ImageNet-1K. In the field of natural language processing, the appetite for data has been successfully addressed by self-supervised pretraining.
In the academic paper Masked Autoencoders Are Scalable Vision Learners by He et al., the authors propose a simple yet effective method to pretrain large vision models (here, ViT Huge). Inspired by the pretraining algorithm of BERT (Devlin et al.), they mask patches of an image and, through an autoencoder, predict the masked patches. In the spirit of "masked language modeling", this pretraining task could be referred to as "masked image modeling".
In this example, we implement Masked Autoencoders Are Scalable Vision Learners with the CIFAR-10 dataset. After pretraining a scaled-down version of ViT, we also implement the linear evaluation pipeline on CIFAR-10.
This implementation covers (MAE refers to Masked Autoencoder):
The masking algorithm
MAE encoder
MAE decoder
Evaluation with linear probing
As a reference, we reuse some of the code presented in this example.
Imports
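A representative set of imports for this example, assuming the TensorFlow backend used on keras.io (matplotlib is only needed for the visualization utilities):

```python
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Fixing the seeds makes the random masking visualizations reproducible.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)
```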
Hyperparameters for pretraining
Please feel free to change the hyperparameters and check your results. The best way to get an intuition about the architecture is to experiment with it. Our hyperparameters are heavily inspired by the design guidelines laid out by the authors in the original paper.
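A sketch of the kind of configuration used here; the values below are representative starting points rather than the exact settings of the keras.io example:

```python
# Pretraining configuration (treat these as starting points).
BUFFER_SIZE = 1024
BATCH_SIZE = 256
INPUT_SHAPE = (32, 32, 3)     # CIFAR-10 images
NUM_EPOCHS = 100
LEARNING_RATE = 5e-3
WEIGHT_DECAY = 1e-4

# Augmentation and patching.
IMAGE_SIZE = 48               # images are upscaled before patching
PATCH_SIZE = 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
MASK_PROPORTION = 0.75        # the paper masks a large fraction of patches

# Scaled-down ViT encoder / lightweight decoder.
ENC_PROJECTION_DIM = 128
DEC_PROJECTION_DIM = 64
ENC_NUM_HEADS = 4
ENC_LAYERS = 6
DEC_NUM_HEADS = 4
DEC_LAYERS = 2
```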
Load and prepare the CIFAR-10 dataset
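A minimal sketch of this step; the 40k/10k train/validation split is an assumption. Since pretraining is self-supervised, the pipelines only need the images, not the labels.

```python
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)

train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = tf.data.Dataset.from_tensor_slices(x_val).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_ds = tf.data.Dataset.from_tensor_slices(x_test).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
```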
Data augmentation
In previous self-supervised pretraining methodologies (such as SimCLR), we have noticed that the data augmentation pipeline plays an important role. The authors of this paper, on the other hand, point out that Masked Autoencoders do not rely on augmentations. They propose a simple augmentation pipeline of:
Resizing
Random cropping (fixed-size or random-size)
Random horizontal flipping
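A sketch of this pipeline built with Keras preprocessing layers (the padding margin added before the random crop is an assumption):

```python
def get_train_augmentation_model():
    # Rescale, resize with a small margin, random-crop back down, flip.
    return keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
            layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
            layers.RandomFlip("horizontal"),
        ],
        name="train_data_augmentation",
    )


def get_test_augmentation_model():
    # At test time we only rescale and resize.
    return keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
        ],
        name="test_data_augmentation",
    )
```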
A layer for extracting patches from images
This layer takes images as input and divides them into patches. The layer also includes two utility methods:
show_patched_image -- takes a batch of images and its corresponding patches to plot a random pair of image and patches.
reconstruct_from_patch -- takes a single instance of patches and stitches them together into the original image.
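A condensed sketch of such a layer; the plotting and reconstruction utilities listed above are omitted here, and RGB inputs are assumed:

```python
class Patches(layers.Layer):
    def __init__(self, patch_size=PATCH_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # tf.image.extract_patches slides a non-overlapping window over each
        # image and returns the flattened pixel values of every patch.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the spatial grid of patches into a sequence:
        # (batch, num_patches, patch_size * patch_size * channels).
        batch_size = tf.shape(images)[0]
        patch_dim = self.patch_size * self.patch_size * 3  # RGB assumed
        return tf.reshape(patches, (batch_size, -1, patch_dim))
```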
Let's visualize the image patches.
Patch encoding with masking
Quoting the paper:
Following ViT, we divide an image into regular non-overlapping patches. Then we sample a subset of patches and mask (i.e., remove) the remaining ones. Our sampling strategy is straightforward: we sample random patches without replacement, following a uniform distribution. We simply refer to this as “random sampling”.
This layer includes masking and encoding the patches.
The utility methods of the layer are:
get_random_indices -- provides the mask and unmask indices.
generate_masked_image -- takes patches and unmask indices, and produces a random masked image. This is an essential utility method for our training monitor callback (defined later).
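A condensed sketch of such a layer; the generate_masked_image plotting helper is omitted, and NUM_PATCHES, ENC_PROJECTION_DIM, and MASK_PROPORTION come from the hyperparameter sketch above. The downstream flag (no masking) is used later for linear probing.

```python
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches=NUM_PATCHES, projection_dim=ENC_PROJECTION_DIM,
                 mask_proportion=MASK_PROPORTION, downstream=False, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.mask_proportion = mask_proportion
        self.downstream = downstream  # True = no masking (downstream tasks)

    def build(self, input_shape):
        patch_area = input_shape[-1]
        self.projection = layers.Dense(self.projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches, output_dim=self.projection_dim
        )
        # A single learnable token that stands in for every masked patch.
        self.mask_token = self.add_weight(
            shape=(1, patch_area), initializer="random_normal", trainable=True
        )
        self.num_mask = int(self.mask_proportion * self.num_patches)

    def get_random_indices(self, batch_size):
        # Argsort over uniform noise gives an independent random permutation
        # per image: "random sampling" without replacement, as in the paper.
        rand_indices = tf.argsort(
            tf.random.uniform(shape=(batch_size, self.num_patches)), axis=-1
        )
        return rand_indices[:, : self.num_mask], rand_indices[:, self.num_mask :]

    def call(self, patches):
        batch_size = tf.shape(patches)[0]
        positions = tf.range(self.num_patches)
        pos_embeddings = tf.tile(
            self.position_embedding(positions[tf.newaxis, ...]), [batch_size, 1, 1]
        )
        patch_embeddings = self.projection(patches) + pos_embeddings
        if self.downstream:
            return patch_embeddings  # no masking for downstream tasks

        mask_indices, unmask_indices = self.get_random_indices(batch_size)
        unmasked_embeddings = tf.gather(patch_embeddings, unmask_indices, axis=1, batch_dims=1)
        unmasked_positions = tf.gather(pos_embeddings, unmask_indices, axis=1, batch_dims=1)
        masked_positions = tf.gather(pos_embeddings, mask_indices, axis=1, batch_dims=1)

        # Replace every masked patch with the shared learnable mask token.
        mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
        mask_tokens = tf.repeat(mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0)
        masked_embeddings = self.projection(mask_tokens) + masked_positions

        return (unmasked_embeddings, masked_embeddings, unmasked_positions,
                mask_indices, unmask_indices)
```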
Let's see the masking process in action on a sample image.
MLP
This serves as the fully connected feed-forward network of the transformer architecture.
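A typical sketch of this block (the hidden-unit sizes are passed in by the encoder and decoder builders):

```python
def mlp(x, hidden_units, dropout_rate):
    # A stack of Dense layers with GELU activations and dropout.
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
```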
MAE encoder
The MAE encoder is a standard ViT. The only point to note here is that the encoder applies layer normalization to its outputs.
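A sketch of the encoder under the configuration above; the hidden-unit multiplier and dropout rates are assumptions. The sequence length is left unspecified so the same model can encode only the unmasked patches during pretraining and all patches during linear probing.

```python
def create_encoder(num_heads=ENC_NUM_HEADS, num_layers=ENC_LAYERS):
    inputs = layers.Input((None, ENC_PROJECTION_DIM))
    x = inputs
    for _ in range(num_layers):
        # Pre-norm multi-head self-attention with a residual connection.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=ENC_PROJECTION_DIM, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention, x])
        # Pre-norm feed-forward block with a residual connection.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, [ENC_PROJECTION_DIM * 2, ENC_PROJECTION_DIM], 0.1)
        x = layers.Add()([x3, x2])
    # The encoder output is layer normalized.
    outputs = layers.LayerNormalization(epsilon=1e-6)(x)
    return keras.Model(inputs, outputs, name="mae_encoder")
```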
MAE decoder
The authors point out that they use an asymmetric autoencoder model. They use a lightweight decoder that takes "<10% computation per token vs. the encoder". We do not match the "<10% computation" figure exactly in our implementation, but we do use a smaller decoder (both in terms of depth and projection dimension).
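A sketch of the corresponding decoder: it projects the encoder's tokens down to the smaller decoder dimension, runs a shallower transformer stack, and ends in a pixel-reconstruction head that maps back to the full image.

```python
def create_decoder(num_heads=DEC_NUM_HEADS, num_layers=DEC_LAYERS):
    inputs = layers.Input((NUM_PATCHES, ENC_PROJECTION_DIM))
    x = layers.Dense(DEC_PROJECTION_DIM)(inputs)
    for _ in range(num_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=DEC_PROJECTION_DIM, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention, x])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, [DEC_PROJECTION_DIM * 2, DEC_PROJECTION_DIM], 0.1)
        x = layers.Add()([x3, x2])
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Flatten()(x)
    # Reconstruct the full image so the loss can be computed patch-wise.
    x = layers.Dense(IMAGE_SIZE * IMAGE_SIZE * 3, activation="sigmoid")(x)
    outputs = layers.Reshape((IMAGE_SIZE, IMAGE_SIZE, 3))(x)
    return keras.Model(inputs, outputs, name="mae_decoder")
```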
MAE trainer
This is the trainer module. We wrap the encoder and decoder inside a tf.keras.Model subclass, which allows us to customize what happens in the model.fit() loop.
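A condensed sketch of the trainer, assuming the component interfaces from the sketches above: augment, patchify, mask, encode only the unmasked patches, decode, and compute the loss on the masked patches only. The full keras.io version also carries a test-time augmentation model and a test_step, which are omitted here for brevity.

```python
class MaskedAutoencoder(keras.Model):
    def __init__(self, augmenter, patch_layer, patch_encoder, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.augmenter = augmenter
        self.patch_layer = patch_layer
        self.patch_encoder = patch_encoder
        self.encoder = encoder
        self.decoder = decoder

    def calculate_loss(self, images):
        augmented = self.augmenter(images)
        patches = self.patch_layer(augmented)
        (unmasked_emb, masked_emb, unmasked_pos,
         mask_indices, unmask_indices) = self.patch_encoder(patches)

        # Encode only the unmasked patches, then decode the full token set.
        encoder_outputs = self.encoder(unmasked_emb) + unmasked_pos
        decoder_inputs = tf.concat([encoder_outputs, masked_emb], axis=1)
        reconstructed = self.decoder(decoder_inputs)

        # Compute the loss only on the masked patches.
        decoder_patches = self.patch_layer(reconstructed)
        loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
        loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)
        return self.compiled_loss(loss_patch, loss_output), loss_patch, loss_output

    def train_step(self, images):
        with tf.GradientTape() as tape:
            total_loss, loss_patch, loss_output = self.calculate_loss(images)
        train_vars = (
            self.patch_encoder.trainable_variables
            + self.encoder.trainable_variables
            + self.decoder.trainable_variables
        )
        grads = tape.gradient(total_loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))
        self.compiled_metrics.update_state(loss_patch, loss_output)
        return {m.name: m.result() for m in self.metrics}
```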
Model initialization
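Wiring the sketched components together might look like this (the names are illustrative):

```python
train_augmenter = get_train_augmentation_model()
test_augmenter = get_test_augmentation_model()
patch_layer = Patches()
patch_encoder = PatchEncoder()
encoder = create_encoder()
decoder = create_decoder()

mae_model = MaskedAutoencoder(
    augmenter=train_augmenter,
    patch_layer=patch_layer,
    patch_encoder=patch_encoder,
    encoder=encoder,
    decoder=decoder,
)
```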
Training callbacks
Visualization callback
Learning rate scheduler
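A sketch of a linear-warmup plus cosine-decay schedule of the kind commonly used for transformer pretraining (the exact schedule and warmup fraction on keras.io may differ):

```python
import math


class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by cosine decay to zero."""

    def __init__(self, base_lr, total_steps, warmup_steps):
        super().__init__()
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_lr = self.base_lr * step / self.warmup_steps
        progress = (step - self.warmup_steps) / float(self.total_steps - self.warmup_steps)
        cosine_lr = 0.5 * self.base_lr * (1.0 + tf.cos(math.pi * progress))
        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)
```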
Model compilation and training
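Compilation and pretraining might then look like the following. The keras.io example uses AdamW (via TensorFlow Addons) for weight decay; plain Adam is substituted here to keep the sketch dependency-free, and the warmup fraction is an assumption.

```python
total_steps = int((len(x_train) / BATCH_SIZE) * NUM_EPOCHS)
warmup_steps = int(total_steps * 0.15)
scheduled_lrs = WarmUpCosine(LEARNING_RATE, total_steps, warmup_steps)

mae_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=scheduled_lrs),
    loss=keras.losses.MeanSquaredError(),
    metrics=["mae"],
)
history = mae_model.fit(train_ds, epochs=NUM_EPOCHS)
```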
Evaluation with linear probing
Extract the encoder model along with other layers
We are using average pooling to extract learned representations from the MAE encoder. Another approach would be to use a learnable dummy token inside the encoder during pretraining (resembling the [CLS] token), and then extract representations from that token during downstream tasks.
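A sketch of assembling the probe model under this pooling strategy, reusing the pretrained layers from the sketches above. The BatchNormalization layer before pooling follows the linear-probing recipe discussed in the MAE paper's appendix; its placement here is otherwise an assumption.

```python
# Switch the patch encoder to its no-masking mode for downstream use.
patch_encoder.downstream = True

downstream_model = keras.Sequential(
    [
        layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
        patch_layer,
        patch_encoder,
        encoder,
        layers.BatchNormalization(),          # helps linear probing
        layers.GlobalAveragePooling1D(),
        layers.Dense(10, activation="softmax"),
    ],
    name="linear_probe_model",
)

# Only the classification head should be trainable.
for layer in downstream_model.layers[:-1]:
    layer.trainable = False
```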
Prepare datasets for linear probing
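A sketch of the labelled pipelines: images are rescaled and resized with the test-time augmentation model so they match the probe model's input shape.

```python
def prepare_probe_dataset(images, labels, is_train=False):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_train:
        dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.batch(BATCH_SIZE).map(
        lambda x, y: (test_augmenter(x), y), num_parallel_calls=tf.data.AUTOTUNE
    )
    return dataset.prefetch(tf.data.AUTOTUNE)


train_probe_ds = prepare_probe_dataset(x_train, y_train, is_train=True)
val_probe_ds = prepare_probe_dataset(x_val, y_val)
test_probe_ds = prepare_probe_dataset(x_test, y_test)
```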
Perform linear probing
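The probing run itself might look like this; the optimizer settings and epoch count are assumptions.

```python
downstream_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.05, momentum=0.9),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
downstream_model.fit(train_probe_ds, validation_data=val_probe_ds, epochs=50)

loss, accuracy = downstream_model.evaluate(test_probe_ds)
print(f"Top-1 test accuracy: {accuracy * 100:.2f}%")
```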
We believe that with a more sophisticated hyperparameter tuning process and longer pretraining it is possible to improve this performance further. For comparison, we took the encoder architecture and trained it from scratch in a fully supervised manner. This gave us ~76% test top-1 accuracy. The authors of MAE demonstrate strong performance on the ImageNet-1k dataset as well as on other downstream tasks like object detection and semantic segmentation.
Final notes
We refer interested readers to the other self-supervised learning examples available on keras.io.
This idea of using BERT-flavored pretraining in computer vision was also explored in Selfie, but it could not demonstrate strong results. Another concurrent work that explores the idea of masked image modeling is SimMIM. Finally, as a fun fact, we, the authors of this example, also explored the idea of "reconstruction as a pretext task" in 2020, but we could not prevent the network from representation collapse, and hence we did not get strong downstream performance.
We would like to thank Xinlei Chen (one of the authors of MAE) for helpful discussions. We are grateful to JarvisLabs and Google Developers Experts program for helping with GPU credits.