Barlow Twins for Contrastive SSL
Author: Abhiraam Eranti
Date created: 11/4/21
Last modified: 12/20/21
Description: A Keras implementation of Barlow Twins (contrastive SSL with redundancy reduction).
Introduction
Self-supervised learning (SSL) is a relatively novel technique in which a model learns from unlabeled data. It is often used when the data is corrupted or when there is very little of it.

A practical use of SSL is to learn intermediate embeddings from the data itself: similar images should map to similar embeddings, and dissimilar images to dissimilar ones. These embeddings are then fed into the rest of the model, which uses them as information to learn and make predictions. Ideally, the embeddings should capture as much information about the data as possible, so that the downstream model can make better predictions.

A common problem, however, is that the model produces redundant embeddings. For example, given two similar images, the model may output embeddings that are just strings of 1s, or some other repeated pattern of values. Such embeddings are no better than a one-hot encoding, or even a single bit, as the model's representation; they defeat the purpose of the embeddings, which is to learn as much about the dataset as possible. Other approaches address this problem by carefully configuring the model so that it tries not to be redundant.
Barlow Twins is a new approach to this problem; while other solutions mainly tackle the first goal of invariance (similar images have similar embeddings), the Barlow Twins method also prioritizes the goal of reducing redundancy.
It also has the advantage of being much simpler than other methods, and its model architecture is symmetric, meaning that both twins in the model do the same thing. It is near state-of-the-art on ImageNet, even exceeding methods like SimCLR.
One disadvantage of Barlow Twins is that it is heavily dependent on augmentation: accuracy drops significantly without it.
TL;DR: Barlow Twins creates representations that are:
Invariant.
Not redundant, carrying as much information about the dataset as possible.
It is also simpler than other methods.
This notebook can train a Barlow Twins model and reach up to 64% validation accuracy on the CIFAR-10 dataset.
High-Level Theory
The model takes two versions of the same image (with different augmentations) as input. It then produces a representation of each version, and the two representations are used to build a cross-correlation matrix.
Cross-correlation matrix:
The cross-correlation matrix measures the correlation between the output neurons of the two representations the model produces for the two augmented versions of the data. Ideally, when the two inputs come from the same image, the cross-correlation matrix should look like an identity matrix (a sketch of how this matrix can be computed follows the list below).
When this happens, it means that the representations:
Are invariant. Each diagonal entry shows the correlation between a neuron in one representation and the same neuron in the other, augmented representation. Because the two versions come from the same image, these diagonal values should be close to 1, indicating a strong correlation. If the images were different, the diagonal would not stand out.
Do not show signs of redundancy. If a neuron is correlated with a different (off-diagonal) neuron, it is not focused on identifying similarities between the two augmented images; that correlation is redundant information.
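As a rough illustration of how such a matrix can be built, the sketch below normalizes each embedding dimension along the batch and takes a matrix product. The function name and the epsilon for numerical stability are assumptions, not the notebook's exact code:

```python
import tensorflow as tf


def cross_correlation(z_a, z_b):
    """Cross-correlation matrix between two batches of representations,
    each of shape (batch_size, dim). Returns a (dim, dim) matrix."""
    batch_size = tf.cast(tf.shape(z_a)[0], z_a.dtype)
    # Normalize each feature along the batch dimension (zero mean, unit std).
    z_a_norm = (z_a - tf.reduce_mean(z_a, axis=0)) / (tf.math.reduce_std(z_a, axis=0) + 1e-8)
    z_b_norm = (z_b - tf.reduce_mean(z_b, axis=0)) / (tf.math.reduce_std(z_b, axis=0) + 1e-8)
    # (dim, batch) @ (batch, dim) -> (dim, dim)
    return tf.matmul(z_a_norm, z_b_norm, transpose_a=True) / batch_size
```

For two augmented views of the same batch, the closer this matrix is to the identity, the better the representations.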
The pseudocode in the original paper, Barlow Twins: Self-Supervised Learning via Redundancy Reduction, gives a good high-level picture of the method.
References
Paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction
Original Implementation: facebookresearch/barlowtwins
Setup
Load the CIFAR-10 dataset
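As a minimal sketch (the notebook may preprocess the data differently), CIFAR-10 can be loaded directly from `keras.datasets`:

```python
from tensorflow import keras

# CIFAR-10: 50,000 training and 10,000 test images, 32x32x3, 10 classes.
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

# Scale pixel values to [0, 1] before the augmentation pipeline.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0
print(train_images.shape, test_images.shape)  # (50000, 32, 32, 3) (10000, 32, 32, 3)
```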
Necessary Hyperparameters
Augmentation Utilities
The Barlow Twins algorithm is heavily reliant on augmentation. One distinctive feature of the method is that some augmentations are applied only probabilistically (a sketch of this pattern follows the list below).
Augmentations
RandomToGrayscale: randomly applies grayscale to image 20% of the time
RandomColorJitter: randomly applies color jitter 80% of the time
RandomFlip: randomly flips image horizontally 50% of the time
RandomResizedCrop: randomly crops an image to a random size, then resizes it back to the original dimensions. This is applied 100% of the time
RandomSolarize: randomly applies solarization to an image 20% of the time
RandomBlur: randomly blurs an image 20% of the time
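To illustrate the probabilistic pattern, here is a sketch of a grayscale layer that fires 20% of the time. The class name and the per-batch (rather than per-image) coin flip are simplifying assumptions, not necessarily how the notebook implements it:

```python
import tensorflow as tf
from tensorflow import keras


class RandomToGrayscale(keras.layers.Layer):
    """Converts a batch of RGB images to grayscale with 20% probability (sketch)."""

    def call(self, images):
        # One coin flip per batch: apply the augmentation only 20% of the time.
        apply = tf.random.uniform([], 0.0, 1.0) < 0.2

        def to_gray():
            gray = tf.image.rgb_to_grayscale(images)  # (batch, h, w, 1)
            return tf.tile(gray, [1, 1, 1, 3])        # keep 3 channels

        return tf.cond(apply, to_gray, lambda: images)
```

The other probabilistic augmentations follow the same pattern: a random draw decides whether the transformation is applied at all.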
Data Loading
A class that creates the Barlow Twins dataset.
The dataset consists of two copies of each image, with each copy receiving different augmentations.
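A minimal sketch of this idea with `tf.data`, assuming some `augment_fn` that chains the augmentations above (the notebook's actual class may differ):

```python
import tensorflow as tf


def make_barlow_dataset(images, augment_fn, batch_size=128):
    """Yields (view_1, view_2) batches, where both views come from the same
    images but receive independent augmentations (sketch)."""
    ds = tf.data.Dataset.from_tensor_slices(images)
    ds = ds.shuffle(10_000).batch(batch_size, drop_remainder=True)
    # Map each batch to two independently augmented copies of itself.
    ds = ds.map(
        lambda x: (augment_fn(x), augment_fn(x)),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return ds.prefetch(tf.data.AUTOTUNE)
```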
View examples of dataset.
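One way to eyeball the pairs is a small matplotlib grid; this assumes a `dataset` built with the `make_barlow_dataset` sketch above and images scaled to [0, 1]:

```python
import matplotlib.pyplot as plt

view_1, view_2 = next(iter(dataset))  # one batch of augmented pairs
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i in range(5):
    axes[0, i].imshow(view_1[i])  # top row: first augmented view
    axes[1, i].imshow(view_2[i])  # bottom row: second view of the same image
    axes[0, i].axis("off")
    axes[1, i].axis("off")
plt.show()
```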
Pseudocode of loss and model
The following sections follow the pseudocode from the original paper, covering both the model and the loss function, along with a reference for the variables used.
BarlowLoss: the Barlow Twins model's loss function
Barlow Twins uses the cross-correlation matrix for its loss. There are two parts to the loss function:
The invariance term (diagonal). This part pushes the diagonal entries of the matrix toward 1. When that happens, the matrix shows that the two views are correlated, i.e. they come from the same image.
The loss function subtracts 1 from the diagonal values and squares the result.
The redundancy reduction term (off-diagonal). Here, the Barlow Twins loss aims to push these values to zero. As mentioned before, correlations between neurons that are not on the diagonal indicate redundancy.
The off-diagonal values are squared and scaled by a small weight λ (the paper uses 5e-3).
After this, the two parts are summed together.
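Putting the two terms together, here is a sketch of the loss as a Keras loss class; the λ weight of 5e-3 and the normalization epsilon follow the paper and common practice, and may differ from the notebook's exact values:

```python
import tensorflow as tf
from tensorflow import keras


class BarlowLoss(keras.losses.Loss):
    """Sketch of the Barlow Twins loss: invariance + redundancy reduction."""

    def __init__(self, lambda_amt=5e-3, **kwargs):
        super().__init__(**kwargs)
        self.lambda_amt = lambda_amt

    def call(self, z_a, z_b):
        batch_size = tf.cast(tf.shape(z_a)[0], z_a.dtype)
        # Normalize each embedding dimension along the batch.
        z_a = (z_a - tf.reduce_mean(z_a, axis=0)) / (tf.math.reduce_std(z_a, axis=0) + 1e-8)
        z_b = (z_b - tf.reduce_mean(z_b, axis=0)) / (tf.math.reduce_std(z_b, axis=0) + 1e-8)
        # (dim, dim) cross-correlation matrix.
        c = tf.matmul(z_a, z_b, transpose_a=True) / batch_size
        # Invariance term: push the diagonal toward 1.
        on_diag = tf.reduce_sum((tf.linalg.diag_part(c) - 1.0) ** 2)
        # Redundancy-reduction term: push the off-diagonal toward 0.
        off_diag = tf.reduce_sum(tf.square(c)) - tf.reduce_sum(tf.square(tf.linalg.diag_part(c)))
        return on_diag + self.lambda_amt * off_diag
```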
Barlow Twins' Model Architecture
The model has two parts:
The encoder network, which is a ResNet-34.
The projector network, which creates the model embeddings.
This consists of an MLP with 3 dense-batchnorm-relu layers.
ResNet encoder network implementation:
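The full ResNet-34 is too long to reproduce here; as a rough idea of its building block, a basic residual unit might look like the following sketch (filter counts and layer arrangement are illustrative, not the notebook's exact code):

```python
from tensorflow.keras import layers


def residual_block(x, filters, stride=1):
    """Basic 3x3-3x3 residual block, the unit a ResNet-34-style encoder stacks (sketch)."""
    shortcut = x
    x = layers.Conv2D(filters, 3, strides=stride, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    if stride != 1 or shortcut.shape[-1] != filters:
        # Project the shortcut when the spatial size or channel count changes.
        shortcut = layers.Conv2D(filters, 1, strides=stride, use_bias=False)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    return layers.ReLU()(layers.Add()([x, shortcut]))
```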
Projector network:
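A minimal sketch of such a projector head; the widths shown here are illustrative rather than the notebook's exact sizes:

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_projector(encoder_output_dim, projection_dim=2048):
    """Three dense -> batchnorm -> relu blocks mapping encoder features to embeddings (sketch)."""
    inputs = keras.Input(shape=(encoder_output_dim,))
    x = inputs
    for _ in range(3):
        x = layers.Dense(projection_dim, use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return keras.Model(inputs, x, name="projector")
```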
Training Loop Model
See pseudocode for reference.
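A sketch of the custom training step, assuming an encoder+projector model passed in as `model` and a loss object like the `BarlowLoss` sketched earlier (not the notebook's exact class):

```python
import tensorflow as tf
from tensorflow import keras


class BarlowModel(keras.Model):
    """Trains an encoder + projector with the Barlow Twins objective (sketch)."""

    def __init__(self, model, loss_fn, **kwargs):
        super().__init__(**kwargs)
        self.model = model      # encoder followed by the projector
        self.loss_fn = loss_fn  # e.g. BarlowLoss()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, batch):
        view_1, view_2 = batch  # two augmented views of the same images
        with tf.GradientTape() as tape:
            z_a = self.model(view_1, training=True)
            z_b = self.model(view_2, training=True)
            loss = self.loss_fn(z_a, z_b)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
```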
Model Training
We use the LAMB optimizer instead of Adam or SGD.
LAMB is similar to the LARS optimizer used in the paper, and helps the model converge much faster than the alternatives.
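For example, with the `tensorflow_addons` package this could look like the following sketch; the learning rate, epoch count, and the helper names (`BarlowModel`, `BarlowLoss`, `encoder_with_projector`, `dataset`) are assumptions carried over from the sketches above:

```python
import tensorflow_addons as tfa

# LAMB handles large batch sizes well; the learning rate here is illustrative.
optimizer = tfa.optimizers.LAMB(learning_rate=1e-3)

model = BarlowModel(encoder_with_projector, loss_fn=BarlowLoss())
model.compile(optimizer=optimizer)
model.fit(dataset, epochs=100)
```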
Expected training time: 1 hour 30 min. Go and eat a snack or take a nap or something.
Evaluation
Linear evaluation: to evaluate the model's performance, we add a linear dense layer at the end and freeze the main model's weights, only letting the dense layer be tuned. If the model actually learned something, the accuracy will be significantly higher than random chance.
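A sketch of this linear probe, assuming `encoder` is the trained backbone (ending in pooled, flat features) and the CIFAR-10 arrays loaded earlier; the optimizer and epoch count are illustrative:

```python
from tensorflow import keras
from tensorflow.keras import layers

# Freeze the pretrained backbone; only the linear head is trained.
encoder.trainable = False

linear_probe = keras.Sequential([
    encoder,
    layers.Dense(10, activation="softmax"),  # CIFAR-10 has 10 classes
])
linear_probe.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
linear_probe.fit(
    train_images, train_labels,
    epochs=10,
    validation_data=(test_images, test_labels),
)
```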
Accuracy on CIFAR-10: 64% for this notebook. This is much better than the 10% we would get from random guessing.
Conclusion
Barlow Twins is a simple and concise method for contrastive and self-supervised learning.
With this ResNet-34 model architecture, we were able to reach 62-64% validation accuracy.
Use-Cases of Barlow Twins (and contrastive learning in general)
Semi-supervised learning: You can see that this model reached 62-64% accuracy even though it was never trained with the labels. Contrastive learning is useful when you have little labeled data but a lot of unlabeled data.
You first run Barlow Twins training on the unlabeled data, and then do secondary training with the labeled data.
Helpful links
Thanks to Sayak Paul for his implementation. It helped me with debugging and with comparing accuracy and loss.
Thanks to Yashowardhan Shinde for writing the article.