GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/ipynb/barlow_twins.ipynb
Kernel: Python 3

Barlow Twins for Contrastive SSL

Author: Abhiraam Eranti
Date created: 11/4/21
Last modified: 12/20/21
Description: A Keras implementation of Barlow Twins (contrastive SSL with redundancy reduction).

Introduction

Self-supervised learning (SSL) is a relatively novel technique in which a model learns from unlabeled data. It is often used when the data is corrupted or when there is very little of it. A practical use for SSL is to create intermediate embeddings that are learned from the data. These embeddings are based on the dataset itself: similar images have similar embeddings, and dissimilar images have dissimilar ones. The embeddings are then passed to the rest of the model, which uses them as its input information when learning and making predictions. Ideally, the embeddings should contain as much information and insight about the data as possible, so that the model can make better predictions. However, a common problem is that the model creates embeddings that are redundant. For example, if two images are similar, the model may create embeddings that are just a string of 1's, or some other value that repeats the same bits of information. This is no better than a one-hot encoding, or than having a single bit as the model's representation; it defeats the purpose of the embeddings, because they do not capture as much about the dataset as they could. In other approaches, the solution to this problem was to carefully configure the model so that it tries not to be redundant.

Barlow Twins is a new approach to this problem; while other solutions mainly tackle the first goal of invariance (similar images have similar embeddings), the Barlow Twins method also prioritizes the goal of reducing redundancy.

It also has the advantage of being much simpler than other methods, and its model architecture is symmetric, meaning that both twins in the model do the same thing. It is also near state-of-the-art on ImageNet, even exceeding methods like SimCLR.

One disadvantage of Barlow Twins is that it is heavily dependent on augmentations, and its accuracy drops substantially without them.

TL;DR: Barlow Twins creates representations that are:

  • Invariant.

  • Not redundant, and carry as much info about the dataset as possible.

Also, it is simpler than other methods.

This notebook can train a Barlow Twins model and reach up to 64% validation accuracy on the CIFAR-10 dataset.


High-Level Theory

The model takes two versions of the same image (with different augmentations) as input. It then makes a prediction for each of them, creating representations. These representations are then used to make a cross-correlation matrix.

Cross-correlation matrix:

(pred_1.T @ pred_2) / batch_size

The cross-correlation matrix measures the correlation between the output neurons of the two representations that the model predicts for the two augmented versions of the data. Ideally, the cross-correlation matrix should look like an identity matrix when the two inputs are versions of the same image.

When this happens, it means that the representations:

  1. Are invariant. The diagonal shows the correlation between each representation's neurons and its corresponding augmented one. Because the two versions come from the same image, the diagonal of the matrix should show that there is a strong correlation between them. If the images are different, there shouldn't be a diagonal.

  2. Do not show signs of redundancy. If the neurons show correlation with a non-diagonal neuron, it means that it is not correctly identifying similarities between the two augmented images. This means that it is redundant.
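The two properties above can be illustrated with a small NumPy sketch. It is not the notebook's TensorFlow code: the helper names `normalize` and `cross_corr` and the toy data are assumptions made for illustration. It shows that two identical views give a diagonal of ones, and that copying one feature into another produces an off-diagonal entry of one, which is what redundancy looks like in this matrix.

```python
import numpy as np


def normalize(z):
    # Standardize each feature (column) over the batch dimension.
    return (z - z.mean(axis=0)) / z.std(axis=0)


def cross_corr(z_a, z_b):
    # (features x batch) @ (batch x features), divided by the batch size.
    return (normalize(z_a).T @ normalize(z_b)) / z_a.shape[0]


rng = np.random.default_rng(0)
z = rng.normal(size=(512, 4))  # toy batch: 512 samples, 4 features

# Identical "views": each diagonal entry is a self-correlation, i.e. 1.
c_same = cross_corr(z, z)

# Redundant representation: feature 1 is a copy of feature 0,
# so the off-diagonal entry c[0][1] is also 1.
z_redundant = z.copy()
z_redundant[:, 1] = z_redundant[:, 0]
c_redundant = cross_corr(z_redundant, z_redundant)

print(np.round(c_same, 2))
print(np.round(c_redundant, 2))
```

In the first matrix the off-diagonal entries are small but not exactly zero, because independently sampled features still show some chance correlation in a finite batch.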

Here is a good way of understanding this in pseudocode (information from the original paper):

c[i][i] = 1
c[i][j] = 0 (for i != j)

where:
  c is the cross-correlation matrix
  i is the index of one representation's neuron
  j is the index of the second representation's neuron


Setup

!pip install tensorflow-addons
import os

# slightly faster improvements, on the first epoch 30 second decrease and a 1-2 second
# decrease in epoch time. Overall saves approx. 5 min of training time

# Allocates two threads for a gpu private which allows more operations to be
# done faster
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"

import tensorflow as tf  # framework
from tensorflow import keras  # for tf.keras
import tensorflow_addons as tfa  # LAMB optimizer and gaussian_filter2d function
import numpy as np  # np.random.random
import matplotlib.pyplot as plt  # graphs
import datetime  # tensorboard logs naming

# XLA optimization for faster performance (up to 10-15 minutes total time saved)
tf.config.optimizer.set_jit(True)

Load the CIFAR-10 dataset

[
    (train_features, train_labels),
    (test_features, test_labels),
] = keras.datasets.cifar10.load_data()

train_features = train_features / 255.0
test_features = test_features / 255.0

Necessary Hyperparameters

# Batch size of dataset
BATCH_SIZE = 512
# Width and height of image
IMAGE_SIZE = 32

Augmentation Utilities

The Barlow Twins algorithm is heavily reliant on augmentation. One distinctive feature of the method is that most augmentations are applied only with a certain probability.

Augmentations

  • RandomToGrayscale: randomly applies grayscale to image 20% of the time

  • RandomColorJitter: randomly applies color jitter 80% of the time

  • RandomFlip: randomly flips image horizontally 50% of the time

  • RandomResizedCrop: randomly crops an image to a random size then resizes. This happens 100% of the time

  • RandomSolarize: randomly applies solarization to an image 20% of the time

  • RandomBlur: randomly blurs an image 20% of the time
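The "applied X% of the time" behaviour in the list above can be sketched without TensorFlow. The names `random_apply` and `flip_horizontal` below are hypothetical helpers, not part of the notebook; the sketch only shows the idea of drawing one uniform number and comparing it to a probability, which is what the notebook's `random_execute` method does.

```python
import numpy as np


def random_apply(augment, image, prob, rng):
    # Draw one uniform number in [0, 1); apply the augmentation only when
    # it falls below `prob`, otherwise return the image untouched.
    if rng.uniform() < prob:
        return augment(image)
    return image


def flip_horizontal(image):
    # Reverse the width axis of a (height, width, channels) array.
    return image[:, ::-1, :]


rng = np.random.default_rng(0)
image = rng.uniform(size=(32, 32, 3))

n_trials = 10_000
n_flipped = sum(
    not np.array_equal(random_apply(flip_horizontal, image, 0.5, rng), image)
    for _ in range(n_trials)
)
print(n_flipped / n_trials)  # a fraction close to 0.5
```

Because each augmentation draws its own random number, the two views of one image usually end up with different combinations of augmentations.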

class Augmentation(keras.layers.Layer):
    """Base augmentation class.

    Base augmentation class. Contains the random_execute method.

    Methods:
        random_execute: method that returns true or false based
          on a probability. Used to determine whether an augmentation
          will be run.
    """

    def __init__(self):
        super().__init__()

    @tf.function
    def random_execute(self, prob: float) -> bool:
        """random_execute function.

        Arguments:
            prob: a float value from 0-1 that determines the
              probability.

        Returns:
            returns true or false based on the probability.
        """

        return tf.random.uniform([], minval=0, maxval=1) < prob


class RandomToGrayscale(Augmentation):
    """RandomToGrayscale class.

    RandomToGrayscale class. Randomly makes an image
    grayscaled based on the random_execute method. There
    is a 20% chance that an image will be grayscaled.

    Methods:
        call: method that grayscales an image 20% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a grayscaled version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            x = tf.image.rgb_to_grayscale(x)
            x = tf.tile(x, [1, 1, 3])
        return x


class RandomColorJitter(Augmentation):
    """RandomColorJitter class.

    RandomColorJitter class. Randomly adds color jitter to an image.
    Color jitter means to add random brightness, contrast,
    saturation, and hue to an image. There is a 80% chance that an
    image will be randomly color-jittered.

    Methods:
        call: method that color-jitters an image 80% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Adds color jitter to image, including:
          Brightness change by a max-delta of 0.8
          Contrast change by a max-delta of 0.8
          Saturation change by a max-delta of 0.8
          Hue change by a max-delta of 0.2
        Originally, the same deltas of the original paper
        were used, but a performance boost of almost 2% was found
        when doubling them.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a color-jittered version of the image 80% of the time
              and the original image 20% of the time.
        """

        if self.random_execute(0.8):
            x = tf.image.random_brightness(x, 0.8)
            x = tf.image.random_contrast(x, 0.4, 1.6)
            x = tf.image.random_saturation(x, 0.4, 1.6)
            x = tf.image.random_hue(x, 0.2)
        return x


class RandomFlip(Augmentation):
    """RandomFlip class.

    RandomFlip class. Randomly flips image horizontally. There is a 50%
    chance that an image will be randomly flipped.

    Methods:
        call: method that flips an image 50% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly flips the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a flipped version of the image 50% of the time
              and the original image 50% of the time.
        """

        if self.random_execute(0.5):
            x = tf.image.random_flip_left_right(x)
        return x


class RandomResizedCrop(Augmentation):
    """RandomResizedCrop class.

    RandomResizedCrop class. Randomly crop an image to a random size,
    then resize the image back to the original size.

    Attributes:
        image_size: The dimension of the image

    Methods:
        call: method that does random resize crop to the image.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Does random resize crop by randomly cropping an image to a random
        size 75% - 100% the size of the image. Then resizes it.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a randomly cropped image.
        """

        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.75 * self.image_size),
            maxval=1 * self.image_size,
            dtype=tf.int32,
        )

        crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
        crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
        return crop_resize


class RandomSolarize(Augmentation):
    """RandomSolarize class.

    RandomSolarize class. Randomly solarizes an image.
    Solarization is when pixels accidentally flip to an inverted state.

    Methods:
        call: method that does random solarization 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly solarizes the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a solarized version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            # flips abnormally low pixels to abnormally high pixels
            # NOTE: the constants 10 and 255 are on the 0-255 scale, while the
            # features in this notebook are divided by 255, so `x < 10` holds
            # for every pixel and this line leaves the image unchanged.
            x = tf.where(x < 10, x, 255 - x)
        return x


class RandomBlur(Augmentation):
    """RandomBlur class.

    RandomBlur class. Randomly blurs an image.

    Methods:
        call: method that does random blur 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly blurs the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a blurred version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            s = np.random.random()
            return tfa.image.gaussian_filter2d(image=x, sigma=s)
        return x


class RandomAugmentor(keras.Model):
    """RandomAugmentor class.

    RandomAugmentor class. Chains all the augmentations into
    one pipeline.

    Attributes:
        image_size: An integer representing the width and height
          of the image. Designed to be used for square images.
        random_resized_crop: Instance variable representing the
          RandomResizedCrop layer.
        random_flip: Instance variable representing the
          RandomFlip layer.
        random_color_jitter: Instance variable representing the
          RandomColorJitter layer.
        random_blur: Instance variable representing the
          RandomBlur layer
        random_to_grayscale: Instance variable representing the
          RandomToGrayscale layer
        random_solarize: Instance variable representing the
          RandomSolarize layer

    Methods:
        call: chains layers in pipeline together
    """

    def __init__(self, image_size: int):
        super().__init__()

        self.image_size = image_size
        self.random_resized_crop = RandomResizedCrop(image_size)
        self.random_flip = RandomFlip()
        self.random_color_jitter = RandomColorJitter()
        self.random_blur = RandomBlur()
        self.random_to_grayscale = RandomToGrayscale()
        self.random_solarize = RandomSolarize()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.random_resized_crop(x)
        x = self.random_flip(x)
        x = self.random_color_jitter(x)
        x = self.random_blur(x)
        x = self.random_to_grayscale(x)
        x = self.random_solarize(x)

        x = tf.clip_by_value(x, 0, 1)
        return x


bt_augmentor = RandomAugmentor(IMAGE_SIZE)
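One detail of the RandomSolarize layer above is worth a second look: it compares pixels against 10 and subtracts from 255, while the features loaded earlier were divided by 255. With values on that scale the condition `x < 10` appears to hold everywhere, so the layer would return its input unchanged. The sketch below shows what solarization could look like for inputs in [0, 1]. It is an illustration in NumPy rather than the notebook's code, and both the function name `solarize_unit_range` and the threshold of 0.5 are assumptions.

```python
import numpy as np


def solarize_unit_range(image, threshold=0.5):
    # Invert every pixel at or above `threshold`; leave darker pixels alone.
    # Both `image` and `threshold` are assumed to be on the [0, 1] scale.
    return np.where(image < threshold, image, 1.0 - image)


pixels = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
# Dark pixels (0.0, 0.25) are kept; 0.5, 0.75 and 1.0 become 0.5, 0.25 and 0.0.
print(solarize_unit_range(pixels))
```

Swapping such a function into the pipeline would change the augmentations, so the accuracy figures reported in this notebook would not necessarily carry over.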

Data Loading

A class that creates the Barlow Twins dataset.

The dataset consists of two copies of each image, with each copy receiving different augmentations.

class BTDatasetCreator:
    """Barlow twins dataset creator class.

    BTDatasetCreator class. Responsible for creating the
    barlow twins' dataset.

    Attributes:
        options: tf.data.Options needed to configure a setting
          that may improve performance.
        seed: random seed for shuffling. Used to synchronize two
          augmented versions.
        augmentor: augmentor used for augmentation.

    Methods:
        __call__: creates barlow dataset.
        augmented_version: creates 1 half of the dataset.
    """

    def __init__(self, augmentor: RandomAugmentor, seed: int = 1024):
        self.options = tf.data.Options()
        self.options.threading.max_intra_op_parallelism = 1
        self.seed = seed
        self.augmentor = augmentor

    def augmented_version(self, ds: list) -> tf.data.Dataset:
        return (
            tf.data.Dataset.from_tensor_slices(ds)
            .shuffle(1000, seed=self.seed)
            .map(self.augmentor, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE)
            .with_options(self.options)
        )

    def __call__(self, ds: list) -> tf.data.Dataset:
        a1 = self.augmented_version(ds)
        a2 = self.augmented_version(ds)

        return tf.data.Dataset.zip((a1, a2)).with_options(self.options)


augment_versions = BTDatasetCreator(bt_augmentor)(train_features)
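The shared seed is what keeps the two halves of the dataset aligned: both pipelines shuffle in the same order, so position k in each half comes from the same source image, and only the augmentations differ. The sketch below illustrates that idea with NumPy permutations as a stand-in for the tf.data shuffle. The helper `shuffled_indices` is hypothetical and says nothing about how tf.data implements shuffling internally.

```python
import numpy as np


def shuffled_indices(n_items, seed):
    # A fresh generator with the same seed always yields the same permutation.
    return np.random.default_rng(seed).permutation(n_items)


order_a = shuffled_indices(10, seed=1024)
order_b = shuffled_indices(10, seed=1024)

# Same seed: position k refers to the same source image in both halves.
print(order_a)
print(order_b)
print(np.array_equal(order_a, order_b))  # True
```

If the two halves were shuffled with different seeds, the zipped pairs would mix different images, and the diagonal of the cross-correlation matrix would no longer measure invariance.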

View examples of dataset.

sample_augment_versions = iter(augment_versions)


def plot_values(batch: tuple):
    fig, axs = plt.subplots(3, 3)
    fig1, axs1 = plt.subplots(3, 3)

    fig.suptitle("Augmentation 1")
    fig1.suptitle("Augmentation 2")

    a1, a2 = batch

    # plots the first nine images of each augmented batch on its own 3x3 grid
    for i in range(3):
        for j in range(3):
            axs[i][j].imshow(a1[3 * i + j])
            axs[i][j].axis("off")
            axs1[i][j].imshow(a2[3 * i + j])
            axs1[i][j].axis("off")

    plt.show()


plot_values(next(sample_augment_versions))

Pseudocode of loss and model

The following sections follow the original authors' pseudocode, which contains both the model and the loss function (see the diagram below). A reference of the variables used is also included.

pseudocode

Reference:

y_a: first augmented version of original image.
y_b: second augmented version of original image.
z_a: model representation (embeddings) of y_a.
z_b: model representation (embeddings) of y_b.
z_a_norm: normalized z_a.
z_b_norm: normalized z_b.
c: cross correlation matrix.
c_diff: diagonal portion of loss (invariance term).
off_diag: off-diagonal portion of loss (redundancy reduction term).

BarlowLoss: barlow twins model's loss function

Barlow Twins uses the cross correlation matrix for its loss. There are two parts to the loss function:

  • The invariance term (diagonal). This part is used to make the diagonals of the matrix into 1s. When this is the case, the matrix shows that the images are correlated (same).

      • The loss function subtracts 1 from the diagonal and squares the values.

  • The redundancy reduction term (off-diagonal). Here, the Barlow Twins loss function aims to make these values zero. As mentioned before, it is redundant if the representation neurons are correlated with values that are not on the diagonal.

      • Off-diagonals are squared and then weighted by lambda (5e-3 in this notebook).

After this the two parts are summed together.
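The two terms described above can be written out in a few lines of NumPy. This is a sketch for checking intuition on toy data, independent of the TensorFlow class in this notebook; the function name `barlow_loss` and the random inputs are assumptions, while the default of 5e-3 for lambda matches the value used here.

```python
import numpy as np


def barlow_loss(z_a, z_b, lambda_amt=5e-3):
    batch_size = z_a.shape[0]
    # Standardize each embedding dimension over the batch.
    z_a_norm = (z_a - z_a.mean(axis=0)) / z_a.std(axis=0)
    z_b_norm = (z_b - z_b.mean(axis=0)) / z_b.std(axis=0)
    # Cross-correlation matrix, shape (n_features, n_features).
    c = (z_a_norm.T @ z_b_norm) / batch_size
    # Invariance term: squared distance of the diagonal from 1.
    on_diag = np.sum((np.diag(c) - 1.0) ** 2)
    # Redundancy-reduction term: squared off-diagonal entries times lambda.
    off_diag = np.sum((c - np.diag(np.diag(c))) ** 2) * lambda_amt
    return on_diag + off_diag


rng = np.random.default_rng(0)
z_a = rng.normal(size=(256, 8))
z_noisy = z_a + 0.1 * rng.normal(size=z_a.shape)  # a slightly perturbed "view"
z_unrelated = rng.normal(size=z_a.shape)          # embeddings of unrelated inputs

loss_views = barlow_loss(z_a, z_noisy)          # small: diagonal is close to 1
loss_unrelated = barlow_loss(z_a, z_unrelated)  # large: diagonal is close to 0
print(loss_views, loss_unrelated)
```

The comparison shows why the loss rewards invariance: embeddings of two views of the same input score far lower than embeddings of unrelated inputs.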

class BarlowLoss(keras.losses.Loss):
    """BarlowLoss class.

    BarlowLoss class. Creates a loss function based on the cross-correlation
    matrix.

    Attributes:
        batch_size: the batch size of the dataset
        lambda_amt: the value for lambda(used in cross_corr_matrix_loss)

    Methods:
        __init__: gets instance variables
        call: gets the loss based on the cross-correlation matrix
        get_off_diag: Used in calculating off-diagonal section
          of loss function; makes diagonals zeros.
        cross_corr_matrix_loss: creates loss based on cross correlation
          matrix.
    """

    def __init__(self, batch_size: int):
        """__init__ method.

        Gets the instance variables

        Arguments:
            batch_size: An integer value representing the batch size of the
              dataset. Used for cross correlation matrix calculation.
        """

        super().__init__()
        self.lambda_amt = 5e-3
        self.batch_size = batch_size

    def get_off_diag(self, c: tf.Tensor) -> tf.Tensor:
        """get_off_diag method.

        Makes the diagonals of the cross correlation matrix zeros.
        This is used in the off-diagonal portion of the loss function,
        where we take the squares of the off-diagonal values and sum them.

        Arguments:
            c: A tf.tensor that represents the cross correlation
              matrix

        Returns:
            Returns a tf.tensor which represents the cross correlation
            matrix with its diagonals as zeros.
        """

        zero_diag = tf.zeros(c.shape[-1])
        return tf.linalg.set_diag(c, zero_diag)

    def cross_corr_matrix_loss(self, c: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix_loss method.

        Gets the loss based on the cross correlation matrix.
        We want the diagonals to be 1's and everything else to be
        zeros to show that the two augmented images are similar.

        Loss function procedure:
        take the diagonal of the cross-correlation matrix, subtract by 1,
        and square that value so no negatives.

        Take the off-diagonal of the cc-matrix(see get_off_diag()),
        square those values to get rid of negatives and increase the value,
        and multiply it by a lambda to weight it such that it is of equal
        value to the optimizer as the diagonal(there are more values off-diag
        than on-diag)

        Sum each of the two parts, then add the two sums together.

        Arguments:
            c: A tf.tensor that represents the cross correlation
              matrix

        Returns:
            Returns a (rank-0) tf.tensor which represents the loss computed
            from the cross correlation matrix.
        """

        # subtracts diagonals by one and squares them(first part)
        c_diff = tf.pow(tf.linalg.diag_part(c) - 1, 2)

        # takes off diagonal, squares it, multiplies with lambda(second part)
        off_diag = tf.pow(self.get_off_diag(c), 2) * self.lambda_amt

        # sum first and second parts together
        loss = tf.reduce_sum(c_diff) + tf.reduce_sum(off_diag)

        return loss

    def normalize(self, output: tf.Tensor) -> tf.Tensor:
        """normalize method.

        Normalizes the model prediction.

        Arguments:
            output: the model prediction.

        Returns:
            Returns a normalized version of the model prediction.
        """

        return (output - tf.reduce_mean(output, axis=0)) / tf.math.reduce_std(
            output, axis=0
        )

    def cross_corr_matrix(self, z_a_norm: tf.Tensor, z_b_norm: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix method.

        Creates a cross correlation matrix from the predictions.
        It transposes the first prediction and multiplies this with
        the second, creating a matrix with shape (n_dense_units, n_dense_units).
        See build_twin() for more info. Then it divides this with the
        batch size.

        Arguments:
            z_a_norm: A normalized version of the first prediction.
            z_b_norm: A normalized version of the second prediction.

        Returns:
            Returns a cross correlation matrix.
        """
        return (tf.transpose(z_a_norm) @ z_b_norm) / self.batch_size

    def call(self, z_a: tf.Tensor, z_b: tf.Tensor) -> tf.Tensor:
        """call method.

        Makes the cross-correlation loss. Uses the cross_corr_matrix()
        method to make the cross corr matrix, then finds the loss and
        returns it(see cross_corr_matrix_loss()).

        Arguments:
            z_a: The prediction of the first set of augmented data.
            z_b: the prediction of the second set of augmented data.

        Returns:
            Returns a (rank-0) tf.Tensor that represents the loss.
        """

        z_a_norm, z_b_norm = self.normalize(z_a), self.normalize(z_b)
        c = self.cross_corr_matrix(z_a_norm, z_b_norm)
        loss = self.cross_corr_matrix_loss(c)
        return loss
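The docstring of cross_corr_matrix_loss says that lambda is there because the matrix has many more off-diagonal entries than diagonal ones. The short calculation below makes that count concrete. It uses 5000 features, which is the projector width set later in this notebook (n_dense_neurons), together with the lambda of 5e-3 from BarlowLoss; the variable names are chosen for illustration.

```python
n_features = 5000  # projector width used in this notebook (n_dense_neurons)
lambda_amt = 5e-3  # weight on the off-diagonal term in BarlowLoss

n_on_diag = n_features
n_off_diag = n_features * n_features - n_features

print(n_on_diag)                # 5000
print(n_off_diag)               # 24995000
print(n_off_diag // n_on_diag)  # 4999 off-diagonal entries per diagonal entry
```

With almost five thousand off-diagonal entries for every diagonal one, an unweighted sum could be dominated by the redundancy term, which is the imbalance lambda is meant to offset.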

Barlow Twins' Model Architecture

The model has two parts:

  • The encoder network, which is a ResNet-34.

  • The projector network, which creates the model embeddings.

      • This consists of an MLP with two dense-batchnorm-ReLU blocks followed by a final dense layer.

Resnet encoder network implementation:

class ResNet34:
    """Resnet34 class.

    Responsible for the Resnet 34 architecture.
    Modified from
    https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.
    View their website for more information.
    """

    def identity_block(self, x, filter):
        # copy tensor to variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Add Residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def convolutional_block(self, x, filter):
        # copy tensor to variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same", strides=(2, 2))(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Processing Residue with conv(1,1)
        x_skip = tf.keras.layers.Conv2D(filter, (1, 1), strides=(2, 2))(x_skip)
        # Add Residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def __call__(self, shape=(32, 32, 3)):
        # Step 1 (Setup Input Layer)
        x_input = tf.keras.layers.Input(shape)
        x = tf.keras.layers.ZeroPadding2D((3, 3))(x_input)
        # Step 2 (Initial Conv layer along with maxPool)
        x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(x)
        # Define size of sub-blocks and initial filter size
        block_layers = [3, 4, 6, 3]
        filter_size = 64
        # Step 3 Add the Resnet Blocks
        for i in range(4):
            if i == 0:
                # For sub-block 1 Residual/Convolutional block not needed
                for j in range(block_layers[i]):
                    x = self.identity_block(x, filter_size)
            else:
                # One Residual/Convolutional Block followed by Identity blocks
                # The filter size will go on increasing by a factor of 2
                filter_size = filter_size * 2
                x = self.convolutional_block(x, filter_size)
                for j in range(block_layers[i] - 1):
                    x = self.identity_block(x, filter_size)
        # Step 4 End Dense Network
        x = tf.keras.layers.AveragePooling2D((2, 2), padding="same")(x)
        x = tf.keras.layers.Flatten()(x)
        model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet34")
        return model
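As a rough check of what the encoder hands to the projector, the arithmetic below traces the spatial size of a 32x32 input through the layers above. It relies on the fact that Keras layers with padding="same" produce an output of size ceil(input / stride); the helper name `same_padding_output` is made up for this sketch, and nothing here builds an actual model.

```python
import math


def same_padding_output(size, stride):
    # With padding="same", a conv or pooling layer outputs ceil(size / stride).
    return math.ceil(size / stride)


size = 32                                # CIFAR-10 height and width
size += 2 * 3                            # ZeroPadding2D((3, 3)) -> 38
size = same_padding_output(size, 2)      # 7x7 conv, stride 2 -> 19
size = same_padding_output(size, 2)      # max pool, stride 2 -> 10
for stage in range(1, 4):                # three stride-2 convolutional blocks
    size = same_padding_output(size, 2)  # -> 5, then 3, then 2
size = same_padding_output(size, 2)      # 2x2 average pool -> 1

channels = 64 * 2 ** 3                   # filters double in each of the 3 later stages
flattened = size * size * channels
print(size, channels, flattened)         # 1 512 512
```

Under these assumptions the flattened encoder output has 512 values per image, which is what the first projector layer receives.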

Projector network:

def build_twin() -> keras.Model:
    """build_twin method.

    Builds a barlow twins model consisting of an encoder(resnet-34)
    and a projector, which generates embeddings for the images

    Returns:
        returns a barlow twins model
    """

    # number of dense neurons in the projector
    n_dense_neurons = 5000

    # encoder network
    resnet = ResNet34()()
    last_layer = resnet.layers[-1].output

    # intermediate layers of the projector network
    n_layers = 2
    for i in range(n_layers):
        dense = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{i}")
        if i == 0:
            x = dense(last_layer)
        else:
            x = dense(x)
        x = tf.keras.layers.BatchNormalization(name=f"projector_bn_{i}")(x)
        x = tf.keras.layers.ReLU(name=f"projector_relu_{i}")(x)

    x = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{n_layers}")(x)

    model = keras.Model(resnet.input, x)
    return model

Training Loop Model

See pseudocode for reference.

class BarlowModel(keras.Model):
    """BarlowModel class.

    BarlowModel class. Responsible for making predictions and handling
    gradient descent with the optimizer.

    Attributes:
        model: the barlow model architecture.
        loss_tracker: the loss metric.

    Methods:
        train_step: one train step; do model predictions, loss, and
            optimizer step.
        metrics: Returns metrics.
    """

    def __init__(self):
        super().__init__()
        self.model = build_twin()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, batch: tuple) -> dict:
        """train_step method.

        Do one train step. Make model predictions, find loss, pass loss to
        optimizer, and make optimizer apply gradients.

        Arguments:
            batch: one batch of data (a pair of augmented versions) to be
                given to the loss function.

        Returns:
            Returns a dictionary with the loss metric.
        """

        # get the two augmentations from the batch
        y_a, y_b = batch

        with tf.GradientTape() as tape:
            # get two versions of predictions
            z_a, z_b = self.model(y_a, training=True), self.model(y_b, training=True)
            loss = self.loss(z_a, z_b)

        grads_model = tape.gradient(loss, self.model.trainable_variables)

        self.optimizer.apply_gradients(zip(grads_model, self.model.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}

Model Training

  • The LAMB optimizer is used instead of Adam or SGD.

      • It is similar to the LARS optimizer used in the paper, and lets the model converge much faster than other methods.

  • Expected training time: 1 hour 30 min. Go and eat a snack or take a nap or something.

# sets up model, optimizer, loss

bm = BarlowModel()
# chose the LAMB optimizer due to high batch sizes. Converged MUCH faster
# than ADAM or SGD
optimizer = tfa.optimizers.LAMB()
loss = BarlowLoss(BATCH_SIZE)

bm.compile(optimizer=optimizer, loss=loss)

# Expected training time: 1 hour 30 min

history = bm.fit(augment_versions, epochs=160)
plt.plot(history.history["loss"])
plt.show()

Evaluation

Linear evaluation: to evaluate the model's performance, we add a linear dense layer at the end and freeze the main model's weights, so that only the dense layer is trained. If the model actually learned something, the accuracy will be significantly higher than random chance.

Accuracy on CIFAR-10: 64% for this notebook. This is much better than the 10% we get from random guessing.
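The idea of linear evaluation can be sketched on toy data in NumPy. Everything in the block is an assumption made for illustration: the three synthetic clusters, the fixed random projection standing in for a frozen encoder, and the plain gradient-descent loop. For brevity it reports accuracy on the training points, whereas the notebook validates on the held-out test set.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: three well-separated 2-D clusters standing in for three classes.
centers = np.array([[-3.0, 0.0], [3.0, 0.0], [0.0, 3.0]])
labels = rng.integers(0, 3, size=600)
inputs = centers[labels] + 0.5 * rng.normal(size=(600, 2))

# Frozen "encoder": a fixed random projection with tanh, never updated below.
w_frozen = rng.normal(size=(2, 8))
w_frozen_before = w_frozen.copy()
features = np.tanh(inputs @ w_frozen)


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def cross_entropy(probs, targets):
    return -np.mean(np.log(probs[np.arange(len(targets)), targets]))


# Linear head: the only trainable parameters.
w_head = np.zeros((8, 3))
one_hot = np.eye(3)[labels]
initial_loss = cross_entropy(softmax(features @ w_head), labels)

for _ in range(500):
    probs = softmax(features @ w_head)
    grad = features.T @ (probs - one_hot) / len(labels)
    w_head -= 0.1 * grad  # gradient step on the head only

final_loss = cross_entropy(softmax(features @ w_head), labels)
accuracy = np.mean(np.argmax(features @ w_head, axis=1) == labels)
print(initial_loss, final_loss, accuracy)
```

Since the encoder weights never change, any accuracy above chance has to come from information already present in the frozen features, which is the reasoning behind using a linear probe as an evaluation.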

# Approx: 64% accuracy with this barlow twins model.

xy_ds = (
    tf.data.Dataset.from_tensor_slices((train_features, train_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((test_features, test_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

model = keras.models.Sequential(
    [
        bm.model,
        keras.layers.Dense(
            10, activation="softmax", kernel_regularizer=keras.regularizers.l2(0.02)
        ),
    ]
)

model.layers[0].trainable = False

linear_optimizer = tfa.optimizers.LAMB()
model.compile(
    optimizer=linear_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(xy_ds, epochs=35, validation_data=test_ds)

Conclusion

  • Barlow Twins is a simple and concise method for contrastive and self-supervised learning.

  • With this ResNet-34 model architecture, we were able to reach 62-64% validation accuracy.

Use-Cases of Barlow Twins (and contrastive learning in general)

  • Semi-supervised learning: this model reached 62-64% accuracy even though the encoder was never trained with the labels. It can be used when you have little labeled data but a lot of unlabeled data.

      • You do Barlow Twins training on the unlabeled data, and then you do secondary training with the labeled data.