Path: blob/master/guides/keras_cv/simsiam_with_kerascv.py
3282 views
"""1Title: SimSiam Training with TensorFlow Similarity and KerasCV2Author: [lukewood](https://lukewood.xyz), Ian Stenbit, Owen Vallis3Date created: 2023/01/224Last modified: 2023/01/225Description: Train a KerasCV model using unlabelled data with SimSiam.6"""78"""9## Overview1011[TensorFlow similarity](https://github.com/tensorflow/similarity) makes it easy to train12KerasCV models on unlabelled corpuses of data using contrastive learning algorithms such13as SimCLR, SimSiam, and Barlow Twins. In this guide, we will train a KerasCV model14using the SimSiam implementation from TensorFlow Similarity.1516## Background1718Self-supervised learning is an approach to pre-training models using unlabeled data.19This approach drastically increases accuracy when you have very few labeled examples but20a lot of unlabelled data.21The key insight is that you can train a self-supervised model to learn data22representations by contrasting multiple augmented views of the same example.23These learned representations capture data invariants, e.g., object translation, color24jitter, noise, etc. Training a simple linear classifier on top of the frozen25representations is easier and requires fewer labels because the pre-trained model26already produces meaningful and generally useful features.2728Overall, self-supervised pre-training learns representations which are [more generic and29robust than other approaches to augmented training and pre-training](https://arxiv.org/abs/2002.05709).30An overview of the general contrastive learning process is shown below:31323334In this tutorial, we will use the [SimSiam](https://arxiv.org/abs/2011.10566) algorithm35for contrastive learning. As of 2022, SimSiam is the state of the art algorithm for36contrastive learning; allowing for unprecedented scores on CIFAR-100 and other datasets.3738You may need to install:3940```41pip -q install tensorflow_similarity42pip -q install keras-cv43```4445To get started, we will sort out some imports.46"""47import resource48import gc49import os50import random51import time52import tensorflow_addons as tfa53import keras_cv54from pathlib import Path55import matplotlib.pyplot as plt56import numpy as np57from tensorflow import keras58from tensorflow.keras import layers59from tabulate import tabulate60import tensorflow_similarity as tfsim # main package61import tensorflow as tf62from keras_cv import layers as cv_layers6364import tensorflow_datasets as tfds6566"""67Lets sort out some high level config issues and define some constants.68The resource limit increase is required to load STL-10, `tfsim.utils.tf_cap_memory()`69prevents TensorFlow from hogging the GPU memory in a cluster, and70`tfds.disable_progress_bar()` makes tfds less noisy.71"""7273low, high = resource.getrlimit(resource.RLIMIT_NOFILE)74resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))75tfsim.utils.tf_cap_memory() # Avoid GPU memory blow up76tfds.disable_progress_bar()7778BATCH_SIZE = 51279PRE_TRAIN_EPOCHS = 5080VAL_STEPS_PER_EPOCH = 2081WEIGHT_DECAY = 5e-482INIT_LR = 3e-2 * int(BATCH_SIZE / 256)83WARMUP_LR = 0.084WARMUP_STEPS = 085DIM = 20488687"""88## Data loading8990Next, we will load the STL-10 dataset. STL-10 is a dataset consisting of 100k unlabelled91images, 5k labelled training images, and 10k labelled test images. Due to this distribution,92STL-10 is commonly used as a benchmark for contrastive learning models.9394First lets load our unlabelled data95"""96train_ds = tfds.load("stl10", split="unlabelled")97train_ds = train_ds.map(98lambda entry: entry["image"], num_parallel_calls=tf.data.AUTOTUNE99)100train_ds = train_ds.map(101lambda image: tf.cast(image, tf.float32), num_parallel_calls=tf.data.AUTOTUNE102)103train_ds = train_ds.shuffle(buffer_size=8 * BATCH_SIZE, reshuffle_each_iteration=True)104105"""106Next, we need to prepare some labelled samples.107This is done so that TensorFlow similarity can probe the learned embedding to ensure108that the model is learning appropriately.109"""110111(x_raw_train, y_raw_train), ds_info = tfds.load(112"stl10", split="train", as_supervised=True, batch_size=-1, with_info=True113)114x_raw_train, y_raw_train = tf.cast(x_raw_train, tf.float32), tf.cast(115y_raw_train, tf.float32116)117x_test, y_test = tfds.load(118"stl10",119split="test",120as_supervised=True,121batch_size=-1,122)123x_test, y_test = tf.cast(x_test, tf.float32), tf.cast(y_test, tf.float32)124125"""126In self supervised learning, queries and indexes are labeled subset datasets used to127evaluate the quality of the produced latent embedding. The following code assembles128these datasets:129"""130131# Compute the indices for query, index, val, and train splits132query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []133for cid in range(ds_info.features["label"].num_classes):134idxs = tf.random.shuffle(tf.where(y_raw_train == cid))135idxs = tf.reshape(idxs, (-1,))136query_idxs.extend(idxs[:100]) # 200 query examples per class137index_idxs.extend(idxs[100:200]) # 200 index examples per class138val_idxs.extend(idxs[200:300]) # 100 validation examples per class139train_idxs.extend(idxs[300:]) # The remaining are used for training140141random.shuffle(query_idxs)142random.shuffle(index_idxs)143random.shuffle(val_idxs)144random.shuffle(train_idxs)145146147def create_split(idxs: list) -> tuple:148x, y = [], []149for idx in idxs:150x.append(x_raw_train[int(idx)])151y.append(y_raw_train[int(idx)])152return tf.convert_to_tensor(np.array(x), dtype=tf.float32), tf.convert_to_tensor(153np.array(y), dtype=tf.int64154)155156157x_query, y_query = create_split(query_idxs)158x_index, y_index = create_split(index_idxs)159x_val, y_val = create_split(val_idxs)160x_train, y_train = create_split(train_idxs)161162PRE_TRAIN_STEPS_PER_EPOCH = tf.data.experimental.cardinality(train_ds) // BATCH_SIZE163PRE_TRAIN_STEPS_PER_EPOCH = int(PRE_TRAIN_STEPS_PER_EPOCH.numpy())164165print(166tabulate(167[168["train", tf.data.experimental.cardinality(train_ds), None],169["val", x_val.shape, y_val.shape],170["query", x_query.shape, y_query.shape],171["index", x_index.shape, y_index.shape],172["test", x_test.shape, y_test.shape],173],174headers=["# of Examples", "Labels"],175)176)177178"""179## Augmentations180181Self-supervised networks require at least two augmented "views" of each example.182This can be created using a dataset and an augmentation function.183The dataset treats each example in the batch as its own class and then the augment184function produces two separate views for each example.185186This means the resulting batch will yield tuples containing the two views, i.e.,187Tuple[(BATCH_SIZE, 32, 32, 3), (BATCH_SIZE, 32, 32, 3)].188189Using KerasCV, it is trivial to construct an augmenter that performs as the one190described in the original SimSiam paper. Lets do that below.191"""192193target_size = (96, 96)194crop_area_factor = (0.08, 1)195aspect_ratio_factor = (3 / 4, 4 / 3)196grayscale_rate = 0.2197color_jitter_rate = 0.8198brightness_factor = 0.2199contrast_factor = 0.8200saturation_factor = (0.3, 0.7)201hue_factor = 0.2202203augmenter = keras.Sequential(204[205cv_layers.RandomFlip("horizontal"),206cv_layers.RandomCropAndResize(207target_size,208crop_area_factor=crop_area_factor,209aspect_ratio_factor=aspect_ratio_factor,210),211cv_layers.RandomApply(212cv_layers.Grayscale(output_channels=3), rate=grayscale_rate213),214cv_layers.RandomApply(215cv_layers.RandomColorJitter(216value_range=(0, 255),217brightness_factor=brightness_factor,218contrast_factor=contrast_factor,219saturation_factor=saturation_factor,220hue_factor=hue_factor,221),222rate=color_jitter_rate,223),224],225)226227"""228Next, lets pass our images through this pipeline.229Note that KerasCV supports batched augmentation, so batching before230augmentation dramatically improves performance231232"""233234235@tf.function()236def process(img):237return augmenter(img), augmenter(img)238239240def prepare_dataset(dataset):241dataset = dataset.repeat()242dataset = dataset.shuffle(1024)243dataset = dataset.batch(BATCH_SIZE)244dataset = dataset.map(process, num_parallel_calls=tf.data.AUTOTUNE)245return dataset.prefetch(tf.data.AUTOTUNE)246247248train_ds = prepare_dataset(train_ds)249250val_ds = tf.data.Dataset.from_tensor_slices(x_val)251val_ds = prepare_dataset(val_ds)252253print("train_ds", train_ds)254print("val_ds", val_ds)255256"""257Lets visualize our pairs using the `tfsim.visualization` utility package.258"""259display_imgs = next(train_ds.as_numpy_iterator())260max_pixel = np.max([display_imgs[0].max(), display_imgs[1].max()])261min_pixel = np.min([display_imgs[0].min(), display_imgs[1].min()])262263tfsim.visualization.visualize_views(264views=display_imgs,265num_imgs=16,266views_per_col=8,267max_pixel_value=max_pixel,268min_pixel_value=min_pixel,269)270271"""272## Model Creation273274Now that our data and augmentation pipeline is setup, we can move on to275constructing the contrastive learning pipeline. First, lets produce a backbone.276For this task, we will use a KerasCV ResNet18 model as the backbone.277"""278279280def get_backbone(input_shape):281inputs = layers.Input(shape=input_shape)282x = inputs283x = keras_cv.models.ResNet18(284input_shape=input_shape,285include_rescaling=True,286include_top=False,287pooling="avg",288)(x)289return tfsim.models.SimilarityModel(inputs, x)290291292backbone = get_backbone((96, 96, 3))293backbone.summary()294295"""296This MLP is common to all the self-supervised models and is typically a stack of 3297layers of the same size. However, SimSiam only uses 2 layers for the smaller CIFAR298images. Having too much capacity in the models can make it difficult for the loss to299stabilize and converge.300301Note: This is the model output that is returned by `ContrastiveModel.predict()` and302represents the distance based embedding. This embedding can be used for the KNN303lookups and matching classification metrics. However, when using the pre-train304model for downstream tasks, only the `ContrastiveModel.backbone` is used.305"""306307308def get_projector(input_dim, dim, activation="relu", num_layers: int = 3):309inputs = tf.keras.layers.Input((input_dim,), name="projector_input")310x = inputs311312for i in range(num_layers - 1):313x = tf.keras.layers.Dense(314dim,315use_bias=False,316kernel_initializer=tf.keras.initializers.LecunUniform(),317name=f"projector_layer_{i}",318)(x)319x = tf.keras.layers.BatchNormalization(320epsilon=1.001e-5, name=f"batch_normalization_{i}"321)(x)322x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_{i}")(323x324)325x = tf.keras.layers.Dense(326dim,327use_bias=False,328kernel_initializer=tf.keras.initializers.LecunUniform(),329name="projector_output",330)(x)331x = tf.keras.layers.BatchNormalization(332epsilon=1.001e-5,333center=False, # Page:5, Paragraph:2 of SimSiam paper334scale=False, # Page:5, Paragraph:2 of SimSiam paper335name=f"batch_normalization_ouput",336)(x)337# Metric Logging layer. Monitors the std of the layer activations.338# Degenerate solutions colapse to 0 while valid solutions will move339# towards something like 0.0220. The actual number will depend on the layer size.340o = tfsim.layers.ActivationStdLoggingLayer(name="proj_std")(x)341projector = tf.keras.Model(inputs, o, name="projector")342return projector343344345projector = get_projector(input_dim=backbone.output.shape[-1], dim=DIM, num_layers=2)346projector.summary()347348349"""350Finally, we must construct the predictor. The predictor is used in SimSiam, and is a351simple stack of two MLP layers, containing a bottleneck in the hidden layer.352"""353354355def get_predictor(input_dim, hidden_dim=512, activation="relu"):356inputs = tf.keras.layers.Input(shape=(input_dim,), name="predictor_input")357x = inputs358359x = tf.keras.layers.Dense(360hidden_dim,361use_bias=False,362kernel_initializer=tf.keras.initializers.LecunUniform(),363name="predictor_layer_0",364)(x)365x = tf.keras.layers.BatchNormalization(366epsilon=1.001e-5, name="batch_normalization_0"367)(x)368x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_0")(x)369370x = tf.keras.layers.Dense(371input_dim,372kernel_initializer=tf.keras.initializers.LecunUniform(),373name="predictor_output",374)(x)375# Metric Logging layer. Monitors the std of the layer activations.376# Degenerate solutions colapse to 0 while valid solutions will move377# towards something like 0.0220. The actual number will depend on the layer size.378o = tfsim.layers.ActivationStdLoggingLayer(name="pred_std")(x)379predictor = tf.keras.Model(inputs, o, name="predictor")380return predictor381382383predictor = get_predictor(input_dim=DIM, hidden_dim=512)384predictor.summary()385386387"""388## Training389390First, we need to initialize our training model, loss, and optimizer.391"""392loss = tfsim.losses.SimSiamLoss(projection_type="cosine_distance", name="simsiam")393394contrastive_model = tfsim.models.ContrastiveModel(395backbone=backbone,396projector=projector,397predictor=predictor, # NOTE: simiam requires predictor model.398algorithm="simsiam",399name="simsiam",400)401lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(402initial_learning_rate=INIT_LR,403decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,404)405wd_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(406initial_learning_rate=WEIGHT_DECAY,407decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,408)409optimizer = tfa.optimizers.SGDW(410learning_rate=lr_decayed_fn, weight_decay=wd_decayed_fn, momentum=0.9411)412413"""414Next we can compile the model the same way you compile any other Keras model.415"""416417contrastive_model.compile(418optimizer=optimizer,419loss=loss,420)421422"""423We track the training using `EvalCallback`.424`EvalCallback` creates an index at the end of each epoch and provides a proxy for the425nearest neighbor matching classification using `binary_accuracy`.426Calculates how often the query label matches the derived lookup label.427428Accuracy is technically (TP+TN)/(TP+FP+TN+FN), but here we filter all429queries above the distance threshold. In the case of binary matching, this430makes all the TPs and FPs below the distance threshold and all the TNs and431FNs above the distance threshold.432433As we are only concerned with the matches below the distance threshold, the434accuracy simplifies to TP/(TP+FP) and is equivalent to the precision with435respect to the unfiltered queries. However, we also want to consider the436query coverage at the distance threshold, i.e., the percentage of queries437that return a match, computed as (TP+FP)/(TP+FP+TN+FN). Therefore, we can438take $ precision \times query_coverage $ to produce a measure that capture439the precision scaled by the query coverage. This simplifies down to the440binary accuracy presented here, giving TP/(TP+FP+TN+FN).441"""442443DATA_PATH = Path("./")444log_dir = DATA_PATH / "models" / "logs" / f"{loss.name}_{time.time()}"445chkpt_dir = DATA_PATH / "models" / "checkpoints" / f"{loss.name}_{time.time()}"446447callbacks = [448tfsim.callbacks.EvalCallback(449tf.cast(x_query, tf.float32),450y_query,451tf.cast(x_index, tf.float32),452y_index,453metrics=["binary_accuracy"],454k=1,455tb_logdir=log_dir,456),457tf.keras.callbacks.TensorBoard(458log_dir=log_dir,459histogram_freq=1,460update_freq=100,461),462tf.keras.callbacks.ModelCheckpoint(463filepath=chkpt_dir,464monitor="val_loss",465mode="min",466save_best_only=True,467save_weights_only=True,468),469]470471"""472All that is left to do is run fit()!473"""474475print(train_ds)476print(val_ds)477history = contrastive_model.fit(478train_ds,479epochs=PRE_TRAIN_EPOCHS,480steps_per_epoch=PRE_TRAIN_STEPS_PER_EPOCH,481validation_data=val_ds,482validation_steps=VAL_STEPS_PER_EPOCH,483callbacks=callbacks,484)485486487"""488## Plotting and Evaluation489"""490491plt.figure(figsize=(15, 4))492plt.subplot(1, 3, 1)493plt.plot(history.history["loss"])494plt.grid()495plt.title(f"{loss.name} - loss")496497plt.subplot(1, 3, 2)498plt.plot(history.history["proj_std"], label="proj")499if "pred_std" in history.history:500plt.plot(history.history["pred_std"], label="pred")501plt.grid()502plt.title(f"{loss.name} - std metrics")503plt.legend()504505plt.subplot(1, 3, 3)506plt.plot(history.history["binary_accuracy"], label="acc")507plt.grid()508plt.title(f"{loss.name} - match metrics")509plt.legend()510511plt.show()512513514"""515## Fine Tuning on the Labelled Data516517As a final step we will fine tune a classifier on 10% of the training data. This will518allow us to evaluate the quality of our learned representation. First, we handle data519loading:520"""521522eval_augmenter = keras.Sequential(523[524keras_cv.layers.RandomCropAndResize(525(96, 96), crop_area_factor=(0.8, 1.0), aspect_ratio_factor=(1.0, 1.0)526),527keras_cv.layers.RandomFlip(mode="horizontal"),528]529)530531eval_train_ds = tf.data.Dataset.from_tensor_slices(532(x_raw_train, tf.keras.utils.to_categorical(y_raw_train, 10))533)534eval_train_ds = eval_train_ds.repeat()535eval_train_ds = eval_train_ds.shuffle(1024)536eval_train_ds = eval_train_ds.map(lambda x, y: (eval_augmenter(x), y), tf.data.AUTOTUNE)537eval_train_ds = eval_train_ds.batch(BATCH_SIZE)538eval_train_ds = eval_train_ds.prefetch(tf.data.AUTOTUNE)539540eval_val_ds = tf.data.Dataset.from_tensor_slices(541(x_test, tf.keras.utils.to_categorical(y_test, 10))542)543eval_val_ds = eval_val_ds.repeat()544eval_val_ds = eval_val_ds.shuffle(1024)545eval_val_ds = eval_val_ds.batch(BATCH_SIZE)546eval_val_ds = eval_val_ds.prefetch(tf.data.AUTOTUNE)547548"""549## Benchmark Against a Naive Model550551Finally, lets setup a naive model that does not leverage the unlabeled data corpus.552"""553554TEST_EPOCHS = 50555TEST_STEPS_PER_EPOCH = x_raw_train.shape[0] // BATCH_SIZE556557558def get_eval_model(img_size, backbone, total_steps, trainable=True, lr=1.8):559backbone.trainable = trainable560inputs = tf.keras.layers.Input((img_size, img_size, 3), name="eval_input")561x = backbone(inputs, training=trainable)562o = tf.keras.layers.Dense(10, activation="softmax")(x)563model = tf.keras.Model(inputs, o)564cosine_decayed_lr = tf.keras.experimental.CosineDecay(565initial_learning_rate=lr, decay_steps=total_steps566)567opt = tf.keras.optimizers.SGD(cosine_decayed_lr, momentum=0.9)568model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])569return model570571572no_pt_eval_model = get_eval_model(573img_size=96,574backbone=get_backbone((96, 96, 3)),575total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,576trainable=True,577lr=1e-3,578)579no_pt_history = no_pt_eval_model.fit(580eval_train_ds,581batch_size=BATCH_SIZE,582epochs=TEST_EPOCHS,583steps_per_epoch=TEST_STEPS_PER_EPOCH,584validation_data=eval_val_ds,585validation_steps=VAL_STEPS_PER_EPOCH,586)587588"""589Pretty bad results! Lets try fine-tuning our SimSiam pretrained model:590"""591592pt_eval_model = get_eval_model(593img_size=96,594backbone=contrastive_model.backbone,595total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,596trainable=False,597lr=30.0,598)599pt_eval_model.summary()600pt_history = pt_eval_model.fit(601eval_train_ds,602batch_size=BATCH_SIZE,603epochs=TEST_EPOCHS,604steps_per_epoch=TEST_STEPS_PER_EPOCH,605validation_data=eval_val_ds,606validation_steps=VAL_STEPS_PER_EPOCH,607)608609"""610All that is left to do is evaluate the models:611"""612print(613"no pretrain",614no_pt_eval_model.evaluate(615eval_val_ds,616steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,617),618)619print(620"pretrained",621pt_eval_model.evaluate(622eval_val_ds,623steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,624),625)626"""627Awesome! Our pretrained model stomped the non-pretrained model.628This accuracy is quite good for a ResNet18 on the STL-10 dataset.629For better results, try using an EfficientNetV2B0 instead.630Unfortunately, this will require a higher end graphics card as631SimSiam has a minimum batch size of 512.632633## Conclusion634635TensorFlow Similarity can be used to easily train KerasCV models using636contrastive algorithms such as SimCLR, SimSiam and BarlowTwins.637This allows you to leverage large corpuses of unlabelled data in your638model trainining pipeline.639640Some follow-up exercises to this tutorial:641642- Train a [`keras_cv.models.EfficientNetV2B0`](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/efficientnet_v2.py)643on STL-10644- Experiment with other data augmentation techniques in pretraining645- Train a model using the [BarlowTwins implementation](https://github.com/tensorflow/similarity/blob/master/examples/unsupervised_hello_world.ipynb) in TensorFlow similarity646- Try pretraining on your own dataset647"""648649650