SimSiam Training with TensorFlow Similarity and KerasCV
Author: lukewood, Ian Stenbit, Owen Vallis
Date created: 2023/01/22
Last modified: 2023/01/22
Description: Train a KerasCV model using unlabelled data with SimSiam.
Overview
TensorFlow Similarity makes it easy to train KerasCV models on unlabelled corpora of data using contrastive learning algorithms such as SimCLR, SimSiam, and Barlow Twins. In this guide, we will train a KerasCV model using the SimSiam implementation from TensorFlow Similarity.
Background
Self-supervised learning is an approach to pre-training models using unlabeled data. This approach drastically increases accuracy when you have very few labeled examples but a lot of unlabelled data. The key insight is that you can train a self-supervised model to learn data representations by contrasting multiple augmented views of the same example. These learned representations capture data invariants, e.g., object translation, color jitter, noise, etc. Training a simple linear classifier on top of the frozen representations is easier and requires fewer labels because the pre-trained model already produces meaningful and generally useful features.
Overall, self-supervised pre-training learns representations which are more generic and robust than other approaches to augmented training and pre-training. An overview of the general contrastive learning process is shown below:
In this tutorial, we will use the SimSiam algorithm for contrastive learning. As of 2022, SimSiam is a state-of-the-art algorithm for contrastive learning, achieving strong results on CIFAR-100 and other datasets.
You may need to install:
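A typical environment setup might look like the following (version pins omitted; the package names are the ones this guide relies on):

```shell
pip install -q tensorflow_similarity keras-cv tensorflow-datasets
```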
To get started, we will sort out some imports.
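The imports below are a sketch of what the rest of this guide needs; the exact list in the original notebook may differ slightly.

```python
import resource

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import keras_cv
import tensorflow_similarity as tfsim  # TensorFlow Similarity
from tensorflow.keras import layers
```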
Let's sort out some high-level configuration and define some constants. The resource limit increase is required to load STL-10, tfsim.utils.tf_cap_memory() prevents TensorFlow from hogging the GPU memory in a cluster, and tfds.disable_progress_bar() makes tfds less noisy.
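A sketch of that configuration; the specific constants below (batch size, epoch counts, learning rate, embedding dimension) are illustrative assumptions rather than the exact values from the original notebook.

```python
# Raise the soft open-file limit to the hard limit so that the 100k-image
# STL-10 archive can be read without "too many open files" errors.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))

tfsim.utils.tf_cap_memory()  # Keep TensorFlow from grabbing all GPU memory.
tfds.disable_progress_bar()  # Quieter TFDS download/prepare logs.

# Illustrative constants; tune for your hardware.
BATCH_SIZE = 512                                 # SimSiam favours large batches.
PRE_TRAIN_EPOCHS = 50
PRE_TRAIN_STEPS_PER_EPOCH = 100_000 // BATCH_SIZE  # ~one pass over the unlabelled set
INIT_LR = 3e-2 * BATCH_SIZE / 256                # Linear LR scaling rule.
DIM = 2048                                       # Projector output dimensionality.
```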
Data loading
Next, we will load the STL-10 dataset. STL-10 is a dataset consisting of 100k unlabelled images, 5k labelled training images, and 10k labelled test images. Due to this distribution, STL-10 is commonly used as a benchmark for contrastive learning models.
First, let's load our unlabelled data.
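A sketch of this step using TensorFlow Datasets; the shuffle buffer size is an arbitrary choice.

```python
train_ds = tfds.load("stl10", split="unlabelled")
train_ds = train_ds.map(
    lambda entry: tf.cast(entry["image"], tf.float32),
    num_parallel_calls=tf.data.AUTOTUNE,
)
train_ds = train_ds.shuffle(buffer_size=8 * BATCH_SIZE, reshuffle_each_iteration=True)
```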
Next, we need to prepare some labelled samples. This is done so that TensorFlow Similarity can probe the learned embedding to ensure that the model is learning appropriately.
In self-supervised learning, queries and indexes are labelled subsets of the data used to evaluate the quality of the learned latent embedding. The following code assembles these datasets:
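One way to assemble them, sketched below: materialize part of the labelled train split and carve it into a small query set and a larger index set. The split sizes (200 / 2,000) are arbitrary choices for illustration.

```python
val_ds = tfds.load("stl10", split="train")
val_ds = val_ds.map(
    lambda entry: (tf.cast(entry["image"], tf.float32), entry["label"]),
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Materialize fixed tensors so they can be handed to the evaluation callback later.
images, labels = [], []
for image, label in val_ds.take(2200).as_numpy_iterator():
    images.append(image)
    labels.append(label)
images, labels = np.stack(images), np.array(labels)

x_query, y_query = images[:200], labels[:200]  # queries probed each epoch
x_index, y_index = images[200:], labels[200:]  # index the queries are matched against
```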
Augmentations
Self-supervised networks require at least two augmented "views" of each example. These can be created using a dataset and an augmentation function. The dataset treats each example in the batch as its own class, and the augmentation function then produces two separate views for each example.
This means the resulting batch will yield tuples containing the two views, i.e., Tuple[(BATCH_SIZE, 32, 32, 3), (BATCH_SIZE, 32, 32, 3)].
Using KerasCV, it is trivial to construct an augmenter that matches the one described in the original SimSiam paper. Let's do that below.
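A sketch of such an augmenter; the layer names come from KerasCV as of the version this guide targets, and the specific factors (crop area, jitter strengths) only approximate the SimSiam recipe.

```python
augmenter = tf.keras.Sequential(
    [
        # Random crops covering 8%-100% of the image, resized back to 96x96.
        keras_cv.layers.RandomCropAndResize(
            target_size=(96, 96),
            crop_area_factor=(0.08, 1.0),
            aspect_ratio_factor=(3 / 4, 4 / 3),
        ),
        tf.keras.layers.RandomFlip("horizontal"),
        # Colour jitter roughly matching the SimSiam augmentation policy.
        keras_cv.layers.RandomColorJitter(
            value_range=(0, 255),
            brightness_factor=0.4,
            contrast_factor=0.4,
            saturation_factor=(0.3, 0.7),
            hue_factor=0.1,
        ),
    ],
    name="simsiam_augmenter",
)
```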
Next, let's pass our images through this pipeline. Note that KerasCV supports batched augmentation, so batching before augmentation dramatically improves performance.
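A sketch of the batched view-producing pipeline; calling the augmenter twice yields two independently augmented views of the same batch.

```python
def produce_views(images):
    # Two independent passes through the augmenter -> two views per image.
    return augmenter(images, training=True), augmenter(images, training=True)

train_ds = (
    train_ds.batch(BATCH_SIZE, drop_remainder=True)
    .map(produce_views, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
```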
Let's visualize our pairs using the tfsim.visualization utility package.
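A minimal sketch, assuming the visualize_views helper from tfsim.visualization with roughly this signature.

```python
display_views = next(train_ds.as_numpy_iterator())
tfsim.visualization.visualize_views(
    views=display_views,
    num_imgs=16,
    views_per_col=8,
    max_pixel_value=255.0,
)
```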
Model Creation
Now that our data and augmentation pipeline is set up, we can move on to constructing the contrastive learning pipeline. First, let's produce a backbone. For this task, we will use a KerasCV ResNet18 model as the backbone.
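A sketch of the backbone, assuming the keras_cv.models.ResNet18 constructor from the KerasCV version this guide was written against; the model is wrapped so it emits a flat, average-pooled feature vector.

```python
def get_backbone(input_shape=(96, 96, 3)):
    inputs = layers.Input(shape=input_shape)
    outputs = keras_cv.models.ResNet18(
        input_shape=input_shape,
        include_rescaling=True,  # let the model normalize raw 0-255 inputs
        include_top=False,
        pooling="avg",
    )(inputs)
    return tf.keras.Model(inputs, outputs, name="backbone")

backbone = get_backbone()
```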
Next, we will build the projector head. This MLP is common to all the self-supervised models and is typically a stack of 3 layers of the same size; however, SimSiam uses only 2 layers for the smaller CIFAR images. Having too much capacity in the models can make it difficult for the loss to stabilize and converge.
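A sketch of a projector along those lines; the layer sizes and the use of batch normalization follow the usual SimSiam recipe and are assumptions here.

```python
def get_projector(input_dim, dim=DIM, num_layers=2, activation="relu"):
    inputs = layers.Input((input_dim,), name="projector_input")
    x = inputs
    for _ in range(num_layers - 1):
        x = layers.Dense(dim, use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation(activation)(x)
    # Final projection layer: linear, followed by batch norm.
    x = layers.Dense(dim, use_bias=False)(x)
    outputs = layers.BatchNormalization()(x)
    return tf.keras.Model(inputs, outputs, name="projector")

projector = get_projector(input_dim=backbone.output.shape[-1])
```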
Note: this is the model output that is returned by ContrastiveModel.predict() and represents the distance-based embedding. This embedding can be used for KNN lookups and matching classification metrics. However, when using the pre-trained model for downstream tasks, only the ContrastiveModel.backbone is used.
Finally, we must construct the predictor. The predictor is used in SimSiam, and is a simple stack of two MLP layers, containing a bottleneck in the hidden layer.
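A sketch of the predictor; the 512-unit bottleneck mirrors the SimSiam paper and is an assumption here.

```python
def get_predictor(input_dim, hidden_dim=512, activation="relu"):
    inputs = layers.Input((input_dim,), name="predictor_input")
    x = layers.Dense(hidden_dim, use_bias=False)(inputs)  # bottleneck hidden layer
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    outputs = layers.Dense(input_dim)(x)  # project back up to the embedding size
    return tf.keras.Model(inputs, outputs, name="predictor")

predictor = get_predictor(input_dim=DIM)
```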
Training
First, we need to initialize our training model, loss, and optimizer.
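A sketch of this setup using TensorFlow Similarity's ContrastiveModel and SimSiamLoss; the cosine-decay schedule and the SGD hyperparameters are illustrative.

```python
contrastive_model = tfsim.models.ContrastiveModel(
    backbone=backbone,
    projector=projector,
    predictor=predictor,  # SimSiam requires a predictor head.
    algorithm="simsiam",
    name="simsiam",
)

loss = tfsim.losses.SimSiamLoss(projection_type="cosine_distance", name="simsiam")

# Cosine-decayed SGD with momentum, as in the SimSiam paper (values illustrative).
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=INIT_LR,
    decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,
)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
```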
Next we can compile the model the same way you compile any other Keras model.
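Using the objects defined above, this is just:

```python
contrastive_model.compile(optimizer=optimizer, loss=loss)
```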
We track the training using EvalCallback. EvalCallback creates an index at the end of each epoch and provides a proxy for the nearest-neighbor matching classification using binary_accuracy, which calculates how often the query label matches the derived lookup label.
Accuracy is technically (TP+TN)/(TP+FP+TN+FN), but here we filter out all queries above the distance threshold. In the case of binary matching, this means all the TPs and FPs fall below the distance threshold and all the TNs and FNs fall above it.
As we are only concerned with the matches below the distance threshold, the accuracy simplifies to TP/(TP+FP), which is the precision with respect to the unfiltered queries. However, we also want to consider the query coverage at the distance threshold, i.e., the percentage of queries that return a match, computed as (TP+FP)/(TP+FP+TN+FN). Taking the product of the precision and the query coverage yields a measure that captures the precision scaled by the query coverage. This simplifies to the binary accuracy presented here: TP/(TP+FP+TN+FN).
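A sketch of the callback setup, assuming the query/index tensors built earlier and the EvalCallback constructor from tfsim.callbacks; the TensorBoard callback is added purely for illustration.

```python
callbacks = [
    tfsim.callbacks.EvalCallback(
        x_query,
        y_query,
        x_index,
        y_index,
        metrics=["binary_accuracy"],
        k=1,
    ),
    tf.keras.callbacks.TensorBoard(log_dir="logs/simsiam", histogram_freq=1),
]
```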
All that is left to do is run fit()!
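A sketch of the training call, using the constants and callbacks assumed above:

```python
history = contrastive_model.fit(
    train_ds,
    epochs=PRE_TRAIN_EPOCHS,
    steps_per_epoch=PRE_TRAIN_STEPS_PER_EPOCH,
    callbacks=callbacks,
)
```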
Plotting and Evaluation
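The notebook plots the pre-training curves at this point; a minimal matplotlib equivalent might look like the following (the history keys, particularly the one logged by EvalCallback, are assumptions).

```python
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history["loss"])
plt.title("SimSiam loss")
plt.xlabel("epoch")

plt.subplot(1, 2, 2)
plt.plot(history.history["binary_accuracy"])  # assumed key logged by EvalCallback
plt.title("Query binary_accuracy")
plt.xlabel("epoch")

plt.show()
```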
Fine Tuning on the Labelled Data
As a final step, we will fine-tune a classifier on 10% of the training data. This will allow us to evaluate the quality of our learned representation. First, we handle data loading:
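A sketch of the data loading for fine-tuning; the 10% slice syntax, the evaluation batch size, and the absence of extra augmentation are assumptions here.

```python
def to_supervised(entry):
    return tf.cast(entry["image"], tf.float32), entry["label"]

eval_train_ds = (
    tfds.load("stl10", split="train[:10%]")
    .map(to_supervised, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1024)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)

eval_test_ds = (
    tfds.load("stl10", split="test")
    .map(to_supervised, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)
```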
Benchmark Against a Naive Model
Finally, let's set up a naive model that does not leverage the unlabelled data corpus.
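A sketch of the baseline: a freshly initialized ResNet18 backbone with a linear softmax head, trained only on the small labelled subset. The head, optimizer settings, and epoch count are illustrative.

```python
def get_eval_model(backbone, trainable=True, img_size=96, num_classes=10):
    backbone.trainable = trainable
    inputs = layers.Input((img_size, img_size, 3), name="eval_input")
    x = backbone(inputs, training=trainable)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

# Baseline: no contrastive pre-training at all.
no_pretrain_model = get_eval_model(backbone=get_backbone())
no_pretrain_model.fit(eval_train_ds, epochs=30, validation_data=eval_test_ds)
```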
Pretty bad results! Let's try fine-tuning our SimSiam pretrained model:
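The same evaluation head, but built on top of the pre-trained backbone exposed as contrastive_model.backbone; keeping the backbone trainable here is what makes this fine-tuning rather than linear probing.

```python
pretrained_model = get_eval_model(
    backbone=contrastive_model.backbone, trainable=True
)
pretrained_model.fit(eval_train_ds, epochs=30, validation_data=eval_test_ds)
```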
All that is left to do is evaluate the models:
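A sketch of the comparison:

```python
print("No pre-training:    ", no_pretrain_model.evaluate(eval_test_ds))
print("SimSiam pre-trained:", pretrained_model.evaluate(eval_test_ds))
```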
Awesome! Our pretrained model stomped the non-pretrained model. 71% accuracy is quite good for a ResNet18 on the STL-10 dataset. For better results, try using an EfficientNetV2B0 instead. Unfortunately, this will require a higher-end graphics card, as SimSiam has a minimum batch size of 512.
Conclusion
TensorFlow Similarity can be used to easily train KerasCV models using contrastive algorithms such as SimCLR, SimSiam, and Barlow Twins. This allows you to leverage large corpora of unlabelled data in your model training pipeline.
Some follow-up exercises to this tutorial:
Train a keras_cv.models.EfficientNetV2B0 on STL-10
Experiment with other data augmentation techniques in pretraining
Train a model using the Barlow Twins implementation in TensorFlow Similarity
Try pretraining on your own dataset