
SimSiam Training with TensorFlow Similarity and KerasCV

Author: lukewood, Ian Stenbit, Owen Vallis
Date created: 2023/01/22
Last modified: 2023/01/22
Description: Train a KerasCV model using unlabelled data with SimSiam.



Overview

TensorFlow Similarity makes it easy to train KerasCV models on unlabelled corpora using contrastive learning algorithms such as SimCLR, SimSiam, and Barlow Twins. In this guide, we will train a KerasCV model using the SimSiam implementation from TensorFlow Similarity.


Background

Self-supervised learning is an approach to pre-training models using unlabelled data. It can substantially improve accuracy when you have very few labelled examples but a lot of unlabelled data. The key insight is that a self-supervised model can learn data representations by contrasting multiple augmented views of the same example. These learned representations capture data invariances, e.g., to object translation, color jitter, and noise. Training a simple linear classifier on top of the frozen representations is easier and requires fewer labels because the pre-trained model already produces meaningful and generally useful features.
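To make the "linear classifier on frozen representations" idea concrete, here is a minimal NumPy sketch of a linear probe: softmax regression trained with full-batch gradient descent on fixed feature vectors. The features are synthetic stand-ins for frozen backbone embeddings, and `train_linear_probe` and `predict` are hypothetical helpers for illustration only, not part of KerasCV or TensorFlow Similarity.

```python
import numpy as np


def train_linear_probe(features, labels, num_classes, lr=1.0, epochs=200):
    """Fit softmax regression on frozen features with full-batch gradient descent."""
    n, d = features.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = features @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # d(mean cross-entropy) / d(logits)
        w -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return w, b


def predict(features, w, b):
    return np.argmax(features @ w + b, axis=1)


# Synthetic stand-in for frozen embeddings: 3 clusters in 16 dimensions.
rng = np.random.default_rng(0)
centers = 3.0 * rng.standard_normal((3, 16))
labels = rng.integers(0, 3, size=300)
features = centers[labels] + rng.standard_normal((300, 16))
features /= np.linalg.norm(features, axis=1, keepdims=True)  # L2-normalize rows

w, b = train_linear_probe(features, labels, num_classes=3)
accuracy = np.mean(predict(features, w, b) == labels)
print(f"train accuracy: {accuracy:.2f}")
```

Only `w` and `b` are learned here; the features never change, which is what "frozen" means in practice.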

Overall, self-supervised pre-training tends to learn representations that are more generic and robust than those produced by other augmentation-based training and pre-training approaches. An overview of the general contrastive learning process is shown below:

[Figure: Contrastive learning overview]

In this tutorial, we will use the SimSiam algorithm for contrastive learning. As of 2022, SimSiam was considered a state-of-the-art contrastive learning algorithm, with strong scores on CIFAR-100 and other datasets.

You may need to install:

```shell
pip -q install tensorflow_similarity
pip -q install keras-cv
```

To get started, we will sort out some imports.

```python
import resource
import gc
import os
import random
import time
import tensorflow_addons as tfa
import keras_cv
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tabulate import tabulate
import tensorflow_similarity as tfsim  # main package
import tensorflow as tf
from keras_cv import layers as cv_layers
import tensorflow_datasets as tfds
```

Let's sort out some high-level configuration and define some constants. The resource limit increase is required to load STL-10, `tfsim.utils.tf_cap_memory()` prevents TensorFlow from hogging the GPU memory in a cluster, and `tfds.disable_progress_bar()` makes tfds less noisy.

```python
# Raise the open-file limit to its maximum (required to load STL-10).
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
tfsim.utils.tf_cap_memory()  # Avoid GPU memory blow up
tfds.disable_progress_bar()

BATCH_SIZE = 512
PRE_TRAIN_EPOCHS = 50
VAL_STEPS_PER_EPOCH = 20
WEIGHT_DECAY = 5e-4
INIT_LR = 3e-2 * int(BATCH_SIZE / 256)
WARMUP_LR = 0.0
WARMUP_STEPS = 0
DIM = 2048
```
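`INIT_LR` follows the familiar linear scaling rule (learning rate proportional to batch size, with 256 as the reference batch size). The plain-Python sketch below, with hypothetical helper names, shows the arithmetic for this guide's settings and one pitfall of the `int(BATCH_SIZE / 256)` form: it truncates, so batch sizes that are not multiples of 256 get a smaller learning rate than plain linear scaling would give, and batch sizes below 256 get a learning rate of zero. The 100,000 figure is the size of the STL-10 unlabelled split loaded below.

```python
BASE_LR = 3e-2  # learning rate per 256 examples, as in INIT_LR above
NUM_UNLABELLED = 100_000  # size of the STL-10 "unlabelled" split


def lr_truncating(batch_size):
    """Mirrors INIT_LR above: int() drops any fractional multiple of 256."""
    return BASE_LR * int(batch_size / 256)


def lr_linear(batch_size):
    """Plain linear scaling: lr = base_lr * batch_size / 256."""
    return BASE_LR * batch_size / 256


for bs in (128, 384, 512):
    print(bs, lr_truncating(bs), lr_linear(bs))

# Full batches per epoch over the unlabelled split (the 195/195 in the training log).
print(NUM_UNLABELLED // 512)  # 195
```

For the guide's `BATCH_SIZE = 512` both forms give 0.06, so the truncation only matters if you change the batch size.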

Data loading

Next, we will load the STL-10 dataset. STL-10 consists of 100k unlabelled images, 5k labelled training images, and 8k labelled test images. Because it pairs a large unlabelled set with a small labelled one, STL-10 is commonly used as a benchmark for contrastive learning models.

First, let's load the unlabelled data:

```python
train_ds = tfds.load("stl10", split="unlabelled")
train_ds = train_ds.map(
    lambda entry: entry["image"], num_parallel_calls=tf.data.AUTOTUNE
)
train_ds = train_ds.map(
    lambda image: tf.cast(image, tf.float32), num_parallel_calls=tf.data.AUTOTUNE
)
train_ds = train_ds.shuffle(buffer_size=8 * BATCH_SIZE, reshuffle_each_iteration=True)
```
```
Downloading and preparing dataset 2.46 GiB (download: 2.46 GiB, generated: 1.86 GiB, total: 4.32 GiB) to ~/tensorflow_datasets/stl10/1.0.0...
Dataset stl10 downloaded and prepared to ~/tensorflow_datasets/stl10/1.0.0. Subsequent calls will reuse this data.
```

Next, we need to prepare some labelled samples so that TensorFlow Similarity can probe the learned embedding and check that the model is learning appropriately.

```python
# Load the labelled splits as full tensors (batch_size=-1 returns the whole split).
(x_raw_train, y_raw_train), ds_info = tfds.load(
    "stl10", split="train", as_supervised=True, batch_size=-1, with_info=True
)
x_raw_train, y_raw_train = tf.cast(x_raw_train, tf.float32), tf.cast(
    y_raw_train, tf.float32
)
x_test, y_test = tfds.load(
    "stl10",
    split="test",
    as_supervised=True,
    batch_size=-1,
)
x_test, y_test = tf.cast(x_test, tf.float32), tf.cast(y_test, tf.float32)
```

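In self-supervised learning, the query and index sets are small labelled subsets used to evaluate the quality of the learned embedding: each query is matched against the index by nearest-neighbor lookup. The following code assembles these datasets, along with validation and training splits: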

```python
# Compute the indices for query, index, val, and train splits
query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []
for cid in range(ds_info.features["label"].num_classes):
    idxs = tf.random.shuffle(tf.where(y_raw_train == cid))
    idxs = tf.reshape(idxs, (-1,))
    query_idxs.extend(idxs[:100])  # 100 query examples per class
    index_idxs.extend(idxs[100:200])  # 100 index examples per class
    val_idxs.extend(idxs[200:300])  # 100 validation examples per class
    train_idxs.extend(idxs[300:])  # The remaining are used for training

random.shuffle(query_idxs)
random.shuffle(index_idxs)
random.shuffle(val_idxs)
random.shuffle(train_idxs)


def create_split(idxs: list) -> tuple:
    x, y = [], []
    for idx in idxs:
        x.append(x_raw_train[int(idx)])
        y.append(y_raw_train[int(idx)])
    return tf.convert_to_tensor(np.array(x), dtype=tf.float32), tf.convert_to_tensor(
        np.array(y), dtype=tf.int64
    )


x_query, y_query = create_split(query_idxs)
x_index, y_index = create_split(index_idxs)
x_val, y_val = create_split(val_idxs)
x_train, y_train = create_split(train_idxs)

# Number of full batches in the unlabelled set.
PRE_TRAIN_STEPS_PER_EPOCH = tf.data.experimental.cardinality(train_ds) // BATCH_SIZE
PRE_TRAIN_STEPS_PER_EPOCH = int(PRE_TRAIN_STEPS_PER_EPOCH.numpy())

print(
    tabulate(
        [
            ["train", tf.data.experimental.cardinality(train_ds), None],
            ["val", x_val.shape, y_val.shape],
            ["query", x_query.shape, y_query.shape],
            ["index", x_index.shape, y_index.shape],
            ["test", x_test.shape, y_test.shape],
        ],
        headers=["# of Examples", "Labels"],
    )
)
```
```
       # of Examples      Labels
-----  -----------------  --------
train  100000
val    (1000, 96, 96, 3)  (1000,)
query  (1000, 96, 96, 3)  (1000,)
index  (1000, 96, 96, 3)  (1000,)
test   (8000, 96, 96, 3)  (8000,)
```
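The per-class split above is a stratified shuffle: for every class, shuffle that class's indices and carve off fixed-size query, index, and validation slices, leaving the rest for training. The NumPy sketch below uses a hypothetical `stratified_splits` helper and synthetic labels shaped like the STL-10 labelled train split (10 classes of 500 images); it mirrors the sizes used above but is not the guide's code.

```python
import numpy as np


def stratified_splits(labels, sizes=(100, 100, 100), seed=0):
    """Per class: shuffle, take `sizes` slices in order, put the remainder last."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    splits = [[] for _ in range(len(sizes) + 1)]
    for cid in np.unique(labels):
        idxs = rng.permutation(np.flatnonzero(labels == cid))
        start = 0
        for i, size in enumerate(sizes):
            splits[i].extend(idxs[start : start + size])
            start += size
        splits[-1].extend(idxs[start:])  # remaining examples go to training
    return [np.array(s, dtype=np.int64) for s in splits]


labels = np.repeat(np.arange(10), 500)  # 10 classes x 500 labelled images
query, index, val, train = stratified_splits(labels)
print(len(query), len(index), len(val), len(train))  # 1000 1000 1000 2000
```

Because every class contributes the same number of query and index examples, the matching metrics reported during training are not skewed toward any single class.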
Augmentations

Self-supervised networks require at least two augmented "views" of each example. These can be created using a dataset and an augmentation function. The dataset treats each example in the batch as its own class, and the augmentation function then produces two separate views of each example. This means the resulting batch yields tuples containing the two views, i.e., `Tuple[(BATCH_SIZE, 96, 96, 3), (BATCH_SIZE, 96, 96, 3)]`.

Using KerasCV, it is straightforward to construct an augmenter that behaves like the one described in the original SimSiam paper. Let's do that below.

```python
target_size = (96, 96)
crop_area_factor = (0.08, 1)
aspect_ratio_factor = (3 / 4, 4 / 3)
grayscale_rate = 0.2
color_jitter_rate = 0.8
brightness_factor = 0.2
contrast_factor = 0.8
saturation_factor = (0.3, 0.7)
hue_factor = 0.2

augmenter = keras.Sequential(
    [
        cv_layers.RandomFlip("horizontal"),
        cv_layers.RandomCropAndResize(
            target_size,
            crop_area_factor=crop_area_factor,
            aspect_ratio_factor=aspect_ratio_factor,
        ),
        cv_layers.RandomApply(
            cv_layers.Grayscale(output_channels=3), rate=grayscale_rate
        ),
        cv_layers.RandomApply(
            cv_layers.RandomColorJitter(
                value_range=(0, 255),
                brightness_factor=brightness_factor,
                contrast_factor=contrast_factor,
                saturation_factor=saturation_factor,
                hue_factor=hue_factor,
            ),
            rate=color_jitter_rate,
        ),
    ],
)
```
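Conceptually, the two-view setup applies the same random augmentation function twice to each batch. The NumPy sketch below illustrates this with a hypothetical `toy_augment` (random horizontal flip plus brightness scaling) standing in for the KerasCV `augmenter`; it is far simpler than the SimSiam augmentations and only demonstrates the structure and shape of the output.

```python
import numpy as np


def toy_augment(images, rng):
    """Hypothetical stand-in for the KerasCV augmenter: random flip + brightness scaling."""
    out = images.copy()
    flip = rng.random(len(out)) < 0.5  # flip roughly half the batch
    out[flip] = out[flip][:, :, ::-1, :]  # reverse the width axis
    scale = rng.uniform(0.8, 1.2, size=(len(out), 1, 1, 1))
    return np.clip(out * scale, 0.0, 255.0).astype(images.dtype)


def two_views(images, rng):
    """Apply the same random augmentation twice to get two views of every image."""
    return toy_augment(images, rng), toy_augment(images, rng)


rng = np.random.default_rng(0)
batch = rng.uniform(0.0, 255.0, size=(8, 96, 96, 3)).astype(np.float32)
view_1, view_2 = two_views(batch, rng)
print(view_1.shape, view_2.shape)  # (8, 96, 96, 3) (8, 96, 96, 3)
```

Each call draws fresh random parameters, so the two views of an image differ while still depicting the same content, which is exactly what the model is asked to be invariant to.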

Next, let's pass our images through this pipeline. Note that KerasCV supports batched augmentation, so batching before augmentation can substantially improve throughput.

```python
@tf.function()
def process(img):
    # The augmenter is random, so two calls give two different views of each image.
    return augmenter(img), augmenter(img)


def prepare_dataset(dataset):
    dataset = dataset.repeat()
    dataset = dataset.shuffle(1024)
    dataset = dataset.batch(BATCH_SIZE)  # batch before augmenting
    dataset = dataset.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)


train_ds = prepare_dataset(train_ds)

val_ds = tf.data.Dataset.from_tensor_slices(x_val)
val_ds = prepare_dataset(val_ds)

print("train_ds", train_ds)
print("val_ds", val_ds)
```
Let's visualize our pairs using the `tfsim.visualization` utility package.

```python
display_imgs = next(train_ds.as_numpy_iterator())
max_pixel = np.max([display_imgs[0].max(), display_imgs[1].max()])
min_pixel = np.min([display_imgs[0].min(), display_imgs[1].min()])

tfsim.visualization.visualize_views(
    views=display_imgs,
    num_imgs=16,
    views_per_col=8,
    max_pixel_value=max_pixel,
    min_pixel_value=min_pixel,
)
```



Model Creation

Now that our data and augmentation pipeline are set up, we can move on to constructing the contrastive learning pipeline. First, let's produce a backbone. For this task, we will use a KerasCV ResNet18 model as the backbone.

```python
def get_backbone(input_shape):
    inputs = layers.Input(shape=input_shape)
    x = inputs
    x = keras_cv.models.ResNet18(
        input_shape=input_shape,
        include_rescaling=True,  # scale raw pixel values by 1/255 inside the model
        include_top=False,  # drop the classification head
        pooling="avg",  # global average pooling -> one vector per image
    )(x)
    return tfsim.models.SimilarityModel(inputs, x)


backbone = get_backbone((96, 96, 3))
backbone.summary()
```
```
Model: "similarity_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_1 (InputLayer)        [(None, 96, 96, 3)]       0

 resnet18 (Functional)       (None, 512)               11186112

=================================================================
Total params: 11,186,112
Trainable params: 11,176,512
Non-trainable params: 9,600
```
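With `pooling="avg"`, the backbone reduces its final spatial feature map to a single vector per image via global average pooling, which is why its output shape is `(None, 512)`. A minimal NumPy sketch; the `(3, 3)` spatial size is an assumption based on the usual total stride of 32 in ResNet-style backbones (96 / 32 = 3), not something read from the KerasCV model.

```python
import numpy as np


def global_average_pool(feature_maps):
    """(batch, height, width, channels) -> (batch, channels) by averaging over space."""
    return feature_maps.mean(axis=(1, 2))


# Assumed final feature map for a batch of 4 images: 3x3 spatial, 512 channels.
feature_maps = np.random.default_rng(0).standard_normal((4, 3, 3, 512))
pooled = global_average_pool(feature_maps)
print(pooled.shape)  # (4, 512)
```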


Next, we build the projector. This MLP is common to all the self-supervised models and is typically a stack of 3 layers of the same size. However, the SimSiam paper uses only 2 layers for the smaller CIFAR images, and we do the same here. Having too much capacity in the models can make it difficult for the loss to stabilize and converge.

Note: The projector output is the model output returned by `ContrastiveModel.predict()` and represents the distance-based embedding. This embedding can be used for KNN lookups and matching classification metrics. However, when using the pre-trained model for downstream tasks, only `ContrastiveModel.backbone` is used.

```python
def get_projector(input_dim, dim, activation="relu", num_layers: int = 3):
    inputs = tf.keras.layers.Input((input_dim,), name="projector_input")
    x = inputs

    # Hidden layers: Dense -> BatchNorm -> activation.
    for i in range(num_layers - 1):
        x = tf.keras.layers.Dense(
            dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.LecunUniform(),
            name=f"projector_layer_{i}",
        )(x)
        x = tf.keras.layers.BatchNormalization(
            epsilon=1.001e-5, name=f"batch_normalization_{i}"
        )(x)
        x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_{i}")(
            x
        )

    # Output layer: Dense -> BatchNorm without the learned shift and scale.
    x = tf.keras.layers.Dense(
        dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="projector_output",
    )(x)
    x = tf.keras.layers.BatchNormalization(
        epsilon=1.001e-5,
        center=False,  # Page:5, Paragraph:2 of SimSiam paper
        scale=False,  # Page:5, Paragraph:2 of SimSiam paper
        name="batch_normalization_ouput",
    )(x)
    # Metric Logging layer. Monitors the std of the layer activations.
    # Degenerate solutions collapse to 0 while valid solutions will move
    # towards something like 0.0220. The actual number will depend on the layer size.
    o = tfsim.layers.ActivationStdLoggingLayer(name="proj_std")(x)
    projector = tf.keras.Model(inputs, o, name="projector")
    return projector


projector = get_projector(input_dim=backbone.output.shape[-1], dim=DIM, num_layers=2)
projector.summary()
```
```
Model: "projector"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 projector_input (InputLayer  [(None, 512)]            0
 )

 projector_layer_0 (Dense)   (None, 2048)              1048576

 batch_normalization_0 (Batc  (None, 2048)             8192
 hNormalization)

 relu_activation_0 (Activati  (None, 2048)             0
 on)

 projector_output (Dense)    (None, 2048)              4194304

 batch_normalization_ouput (  (None, 2048)             4096
 BatchNormalization)

 proj_std (ActivationStdLogg  (None, 2048)             0
 ingLayer)

=================================================================
Total params: 5,255,168
Trainable params: 5,246,976
Non-trainable params: 8,192
```
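The `proj_std` and `pred_std` comments refer to a collapse check described in the SimSiam paper: if outputs are spread out roughly like an isotropic Gaussian, the per-dimension standard deviation of the L2-normalized outputs is about $1/\sqrt{d}$ (about 0.0221 for `DIM = 2048`), whereas a collapsed model that maps every input to the same vector gives 0. The NumPy sketch below reproduces that arithmetic on synthetic data with a hypothetical `normalized_output_std` helper; we assume, without having checked the implementation of `tfsim.layers.ActivationStdLoggingLayer`, that its logged value is comparable.

```python
import numpy as np


def normalized_output_std(z):
    """Mean over dimensions of the per-dimension std of L2-normalized rows of z."""
    z_norm = z / np.linalg.norm(z, axis=1, keepdims=True)
    return float(np.std(z_norm, axis=0).mean())


rng = np.random.default_rng(0)
dim = 2048  # DIM in this guide

healthy = rng.standard_normal((1024, dim))  # spread-out outputs
collapsed = np.tile(rng.standard_normal((1, dim)), (1024, 1))  # every row identical

healthy_std = normalized_output_std(healthy)
collapsed_std = normalized_output_std(collapsed)
print(round(1 / np.sqrt(dim), 4))  # 0.0221
print(healthy_std, collapsed_std)
```

A logged std drifting toward 0 is therefore a warning sign of collapse, while values in the neighborhood of $1/\sqrt{d}$ indicate the embedding is still using its dimensions.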


Finally, we must construct the predictor. The predictor is used in SimSiam and is a simple stack of two MLP layers with a bottleneck in the hidden layer (512 units here, versus the 2048-dimensional input and output).

```python
def get_predictor(input_dim, hidden_dim=512, activation="relu"):
    inputs = tf.keras.layers.Input(shape=(input_dim,), name="predictor_input")
    x = inputs

    # Bottleneck hidden layer: Dense -> BatchNorm -> activation.
    x = tf.keras.layers.Dense(
        hidden_dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_layer_0",
    )(x)
    x = tf.keras.layers.BatchNormalization(
        epsilon=1.001e-5, name="batch_normalization_0"
    )(x)
    x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_0")(x)

    # Output layer maps back to input_dim (no BatchNorm on the predictor output).
    x = tf.keras.layers.Dense(
        input_dim,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_output",
    )(x)
    # Metric Logging layer. Monitors the std of the layer activations.
    # Degenerate solutions collapse to 0 while valid solutions will move
    # towards something like 0.0220. The actual number will depend on the layer size.
    o = tfsim.layers.ActivationStdLoggingLayer(name="pred_std")(x)
    predictor = tf.keras.Model(inputs, o, name="predictor")
    return predictor


predictor = get_predictor(input_dim=DIM, hidden_dim=512)
predictor.summary()
```
```
Model: "predictor"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 predictor_input (InputLayer  [(None, 2048)]           0
 )

 predictor_layer_0 (Dense)   (None, 512)               1048576

 batch_normalization_0 (Batc  (None, 512)              2048
 hNormalization)

 relu_activation_0 (Activati  (None, 512)              0
 on)

 predictor_output (Dense)    (None, 2048)              1050624

 pred_std (ActivationStdLogg  (None, 2048)             0
 ingLayer)

=================================================================
Total params: 2,101,248
Trainable params: 2,100,224
Non-trainable params: 1,024
```
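The parameter counts in the projector and predictor summaries can be checked by hand: a `Dense` layer has `n_in * n_out` weights plus `n_out` biases when `use_bias=True`, and `BatchNormalization` keeps a non-trainable moving mean and variance per unit plus a trainable beta and gamma when `center` and `scale` are enabled. The hypothetical helpers below simply encode that arithmetic.

```python
def dense_params(n_in, n_out, use_bias=True):
    return n_in * n_out + (n_out if use_bias else 0)


def batchnorm_params(units, center=True, scale=True):
    """Returns (trainable, non_trainable): beta/gamma vs. moving mean/variance."""
    return units * (int(center) + int(scale)), 2 * units


# Projector: num_layers=2, DIM=2048, fed by the 512-d backbone output.
bn_hidden = batchnorm_params(2048)
bn_out = batchnorm_params(2048, center=False, scale=False)
projector_total = (
    dense_params(512, 2048, use_bias=False)
    + sum(bn_hidden)
    + dense_params(2048, 2048, use_bias=False)
    + sum(bn_out)
)
projector_non_trainable = bn_hidden[1] + bn_out[1]

# Predictor: 2048 -> 512 bottleneck -> 2048; the output Dense keeps its bias.
bn_pred = batchnorm_params(512)
predictor_total = (
    dense_params(2048, 512, use_bias=False) + sum(bn_pred) + dense_params(512, 2048)
)
predictor_non_trainable = bn_pred[1]

print(projector_total, projector_non_trainable)  # 5255168 8192
print(predictor_total, predictor_non_trainable)  # 2101248 1024
```

The bottleneck is visible in the numbers: the predictor has well under half the parameters of the projector even though both map into the same 2048-dimensional space.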


Training

First, we need to initialize our training model, loss, and optimizer.

```python
loss = tfsim.losses.SimSiamLoss(projection_type="cosine_distance", name="simsiam")

contrastive_model = tfsim.models.ContrastiveModel(
    backbone=backbone,
    projector=projector,
    predictor=predictor,  # NOTE: SimSiam requires a predictor model.
    algorithm="simsiam",
    name="simsiam",
)

# Cosine-decay both the learning rate and the (decoupled) weight decay over the full run.
lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=INIT_LR,
    decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,
)
wd_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=WEIGHT_DECAY,
    decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,
)
optimizer = tfa.optimizers.SGDW(
    learning_rate=lr_decayed_fn, weight_decay=wd_decayed_fn, momentum=0.9
)
```
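Both `CosineDecay` schedules follow the formula given in the Keras documentation for `tf.keras.optimizers.schedules.CosineDecay` (with the default `alpha=0.0`): the value starts at its initial setting and follows half a cosine down to zero over `decay_steps`. A plain-Python sketch of that formula with this guide's numbers (`INIT_LR = 0.06` and `50 * 195 = 9750` decay steps); it is an illustration, not the Keras implementation itself.

```python
import math


def cosine_decay(step, initial_value, decay_steps, alpha=0.0):
    """Cosine decay as documented for tf.keras.optimizers.schedules.CosineDecay."""
    step = min(step, decay_steps)  # the value stays at its floor after decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_value * ((1.0 - alpha) * cosine + alpha)


INIT_LR = 0.06
DECAY_STEPS = 50 * 195

for step in (0, DECAY_STEPS // 4, DECAY_STEPS // 2, DECAY_STEPS):
    print(step, round(cosine_decay(step, INIT_LR, DECAY_STEPS), 5))
```

The weight decay schedule has the same shape, starting from `WEIGHT_DECAY = 5e-4`, so the ratio between learning rate and weight decay stays constant throughout training.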

Next, we can compile the model the same way as any other Keras model.

```python
contrastive_model.compile(
    optimizer=optimizer,
    loss=loss,
)
```
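For intuition about what the `loss` object optimizes: the SimSiam paper's objective compares each view's predictor output `p` with the other view's projector output `z`, symmetrically, with a stop-gradient on `z`. The NumPy sketch below computes only the forward value of that objective using cosine distance (1 minus cosine similarity), which differs from the paper's negative cosine similarity by a constant. It is an illustration with hypothetical helper names, not TensorFlow Similarity's `SimSiamLoss` implementation, and since NumPy has no autodiff the stop-gradient appears only as a comment.

```python
import numpy as np


def cosine_distance(p, z):
    """Row-wise 1 - cosine_similarity(p, z)."""
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return 1.0 - np.sum(p * z, axis=1)


def simsiam_loss_value(p1, p2, z1, z2):
    """Symmetric objective 0.5 * D(p1, z2) + 0.5 * D(p2, z1), averaged over the batch.

    During training z1 and z2 would be wrapped in a stop-gradient (for example
    tf.stop_gradient), so gradients flow only through p1 and p2. NumPy has no
    autodiff, so this function computes the value only.
    """
    return float(np.mean(0.5 * cosine_distance(p1, z2) + 0.5 * cosine_distance(p2, z1)))


rng = np.random.default_rng(0)
z = rng.standard_normal((4, 8))
print(simsiam_loss_value(z, z, z, z))  # identical vectors: ~0.0
print(simsiam_loss_value(z, z, -z, -z))  # opposite vectors: ~2.0
```

With cosine distance the loss lies in [0, 2] and approaches 0 as the two branches agree.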

We track the training using `EvalCallback`. At the end of each epoch, `EvalCallback` builds an index and reports `binary_accuracy` as a proxy for nearest-neighbor matching classification, i.e., how often a query's label matches the label returned by the lookup.

Accuracy is technically $\frac{TP+TN}{TP+FP+TN+FN}$, but here we filter out all queries above the distance threshold. In the case of binary matching, this places all the TPs and FPs below the distance threshold and all the TNs and FNs above it.

Since we are only concerned with the matches below the distance threshold, the accuracy simplifies to $\frac{TP}{TP+FP}$, which is the precision over the unfiltered queries. However, we also want to consider the query coverage at the distance threshold, i.e., the fraction of queries that return a match, $\frac{TP+FP}{TP+FP+TN+FN}$. Multiplying the two, $\text{precision} \times \text{query coverage}$, gives a measure of precision scaled by query coverage, which simplifies to the binary accuracy presented here:

$$\frac{TP}{TP+FP} \times \frac{TP+FP}{TP+FP+TN+FN} = \frac{TP}{TP+FP+TN+FN}.$$
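The bookkeeping above can be made concrete with a small NumPy sketch: a 1-nearest-neighbor lookup under cosine distance, a distance threshold, and the resulting counts. The helper names, the use of cosine distance, and the exact TN/FN assignment for filtered queries are assumptions for illustration; `EvalCallback`'s internals may differ.

```python
import numpy as np


def cosine_distances(queries, index):
    """Pairwise cosine distance (1 - cosine similarity) between two sets of row vectors."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    x = index / np.linalg.norm(index, axis=1, keepdims=True)
    return 1.0 - q @ x.T


def matching_metrics(query_emb, query_labels, index_emb, index_labels, threshold):
    """1-NN lookup with a distance threshold; returns (binary_accuracy, precision, coverage)."""
    dists = cosine_distances(query_emb, index_emb)
    nn = np.argmin(dists, axis=1)  # k=1: nearest index entry per query
    nn_dist = dists[np.arange(len(nn)), nn]
    correct = np.asarray(index_labels)[nn] == np.asarray(query_labels)
    matched = nn_dist < threshold  # queries that return a match

    tp = np.sum(matched & correct)
    fp = np.sum(matched & ~correct)
    fn = np.sum(~matched & correct)  # would have been right, but filtered out
    tn = np.sum(~matched & ~correct)
    total = tp + fp + tn + fn

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    coverage = (tp + fp) / total
    binary_accuracy = tp / total  # equals precision * coverage
    return binary_accuracy, precision, coverage


index_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
index_labels = [0, 1]
query_emb = np.array([[1.0, 0.1], [0.1, 1.0], [1.0, 1.0], [-1.0, 0.0]])
query_labels = [0, 0, 1, 1]

acc, prec, cov = matching_metrics(query_emb, query_labels, index_emb, index_labels, threshold=0.1)
print(acc, prec, cov)  # 0.25 0.5 0.5
```

Here one query is a true positive, one a false positive, and two fall above the threshold, so precision is 0.5, coverage is 0.5, and their product is the binary accuracy of 0.25.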

```python
DATA_PATH = Path("./")
log_dir = DATA_PATH / "models" / "logs" / f"{loss.name}_{time.time()}"
chkpt_dir = DATA_PATH / "models" / "checkpoints" / f"{loss.name}_{time.time()}"

callbacks = [
    tfsim.callbacks.EvalCallback(
        tf.cast(x_query, tf.float32),
        y_query,
        tf.cast(x_index, tf.float32),
        y_index,
        metrics=["binary_accuracy"],
        k=1,
        tb_logdir=log_dir,
    ),
    tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        update_freq=100,
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=chkpt_dir,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=True,
    ),
]
```
```
TensorBoard logging enable in models/logs/simsiam_1674516693.2898047/index
```
All that is left to do is run `fit()`!

```python
print(train_ds)
print(val_ds)
history = contrastive_model.fit(
    train_ds,
    epochs=PRE_TRAIN_EPOCHS,
    steps_per_epoch=PRE_TRAIN_STEPS_PER_EPOCH,  # train_ds repeats indefinitely
    validation_data=val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
    callbacks=callbacks,
)
```
```
Epoch 1/50
195/195 [==============================] - 90s 398ms/step - loss: 0.3137 - proj_std: 0.0122 - pred_std: 0.0076 - val_loss: 0.0764 - val_proj_std: 0.0068 - val_pred_std: 0.0026 - binary_accuracy: 0.2270
Epoch 2/50
195/195 [==============================] - 76s 390ms/step - loss: 0.1469 - proj_std: 0.0101 - pred_std: 0.0048 - val_loss: 0.3514 - val_proj_std: 0.0162 - val_pred_std: 0.0096 - binary_accuracy: 0.2260
Epoch 3/50
195/195 [==============================] - 74s 379ms/step - loss: 0.2779 - proj_std: 0.0145 - pred_std: 0.0102 - val_loss: 0.4038 - val_proj_std: 0.0158 - val_pred_std: 0.0122 - binary_accuracy: 0.2550
Epoch 4/50
195/195 [==============================] - 75s 385ms/step - loss: 0.2763 - proj_std: 0.0167 - pred_std: 0.0145 - val_loss: 0.1668 - val_proj_std: 0.0135 - val_pred_std: 0.0114 - binary_accuracy: 0.2630
Epoch 5/50
195/195 [==============================] - 63s 326ms/step - loss: 0.2235 - proj_std: 0.0174 - pred_std: 0.0154 - val_loss: 0.1847 - val_proj_std: 0.0144 - val_pred_std: 0.0127 - binary_accuracy: 0.2530
Epoch 6/50
195/195 [==============================] - 68s 350ms/step - loss: 0.2091 - proj_std: 0.0181 - pred_std: 0.0165 - val_loss: 0.2072 - val_proj_std: 0.0189 - val_pred_std: 0.0176 - binary_accuracy: 0.2580
Epoch 7/50
195/195 [==============================] - 69s 354ms/step - loss: 0.2385 - proj_std: 0.0191 - pred_std: 0.0178 - val_loss: 0.4700 - val_proj_std: 0.0186 - val_pred_std: 0.0199 - binary_accuracy: 0.2700
Epoch 8/50
195/195 [==============================] - 68s 350ms/step - loss: 0.1986 - proj_std: 0.0186 - pred_std: 0.0174 - val_loss: 0.3135 - val_proj_std: 0.0192 - val_pred_std: 0.0181 - binary_accuracy: 0.2750
Epoch 9/50
195/195 [==============================] - 68s 350ms/step - loss: 0.2182 - proj_std: 0.0191 - pred_std: 0.0180 - val_loss: 0.2822 - val_proj_std: 0.0155 - val_pred_std: 0.0135 - binary_accuracy: 0.2670
Epoch 10/50
195/195 [==============================] - 69s 353ms/step - loss: 0.1991 - proj_std: 0.0185 - pred_std: 0.0173 - val_loss: 0.1550 - val_proj_std: 0.0134 - val_pred_std: 0.0117 - binary_accuracy: 0.3090
Epoch 11/50
195/195 [==============================] - 69s 353ms/step - loss: 0.2080 - proj_std: 0.0185 - pred_std: 0.0175 - val_loss: 0.2511 - val_proj_std: 0.0185 - val_pred_std: 0.0181 - binary_accuracy: 0.2840
Epoch 12/50
195/195 [==============================] - 68s 352ms/step - loss: 0.1934 - proj_std: 0.0186 - pred_std: 0.0176 - val_loss: 0.1785 - val_proj_std: 0.0122 - val_pred_std: 0.0104 - binary_accuracy: 0.2980
Epoch 13/50
195/195 [==============================] - 69s 356ms/step - loss: 0.1945 - proj_std: 0.0190 - pred_std: 0.0181 - val_loss: 0.1189 - val_proj_std: 0.0118 - val_pred_std: 0.0107 - binary_accuracy: 0.3020
Epoch 14/50
195/195 [==============================] - 68s 350ms/step - loss: 0.2009 - proj_std: 0.0194 - pred_std: 0.0187 - val_loss: 0.1736 - val_proj_std: 0.0127 - val_pred_std: 0.0114 - binary_accuracy: 0.3320
Epoch 15/50
195/195 [==============================] - 69s 353ms/step - loss: 0.2029 - proj_std: 0.0194 - pred_std: 0.0186 - val_loss: 0.1638 - val_proj_std: 0.0154 - val_pred_std: 0.0148 - binary_accuracy: 0.3320
Epoch 16/50
195/195 [==============================] - 68s 348ms/step - loss: 0.1972 - proj_std: 0.0196 - pred_std: 0.0190 - val_loss: 0.2987 - val_proj_std: 0.0203 - val_pred_std: 0.0202 - binary_accuracy: 0.3060
Epoch 17/50
195/195 [==============================] - 68s 350ms/step - loss: 0.1885 - proj_std: 0.0196 - pred_std: 0.0190 - val_loss: 0.1805 - val_proj_std: 0.0161 - val_pred_std: 0.0150 - binary_accuracy: 0.3050
Epoch 18/50
195/195 [==============================] - 68s 352ms/step - loss: 0.1933 - proj_std: 0.0196 - pred_std: 0.0189 - val_loss: 0.1917 - val_proj_std: 0.0159 - val_pred_std: 0.0151 - binary_accuracy: 0.3270
Epoch 19/50
195/195 [==============================] - 69s 353ms/step - loss: 0.1934 - proj_std: 0.0196 - pred_std: 0.0189 - val_loss: 0.1808 - val_proj_std: 0.0171 - val_pred_std: 0.0165 - binary_accuracy: 0.3260
Epoch 20/50
195/195 [==============================] - 68s 349ms/step - loss: 0.1757 - proj_std: 0.0194 - pred_std: 0.0187 - val_loss: 0.1957 - val_proj_std: 0.0176 - val_pred_std: 0.0167 - binary_accuracy: 0.3180
Epoch 21/50
195/195 [==============================] - 78s 403ms/step - loss: 0.1752 - proj_std: 0.0194 - pred_std: 0.0188 - val_loss: 0.2070 - val_proj_std: 0.0178 - val_pred_std: 0.0172 - binary_accuracy: 0.3000
Epoch 22/50
195/195 [==============================] - 62s 317ms/step - loss: 0.1743 - proj_std: 0.0195 - pred_std: 0.0190 - val_loss: 0.2240 - val_proj_std: 0.0181 - val_pred_std: 0.0176 - binary_accuracy: 0.3280
Epoch 23/50
195/195 [==============================] - 68s 348ms/step - loss: 0.1692 - proj_std: 0.0193 - pred_std: 0.0188 - val_loss: 0.1892 - val_proj_std: 0.0186 - val_pred_std: 0.0181 - binary_accuracy: 0.3140
Epoch 24/50
195/195 [==============================] - 69s 353ms/step - loss: 0.1529 - proj_std: 0.0190 - pred_std: 0.0184 - val_loss: 0.2405 - val_proj_std: 0.0196 - val_pred_std: 0.0194 - binary_accuracy: 0.3280
Epoch 25/50
195/195 [==============================] - 75s 384ms/step - loss: 0.1425 - proj_std: 0.0187 - pred_std: 0.0182 - val_loss: 0.1602 - val_proj_std: 0.0181 - val_pred_std: 0.0178 - binary_accuracy: 0.3560
Epoch 26/50
195/195 [==============================] - 63s 322ms/step - loss: 0.1277 - proj_std: 0.0186 - pred_std: 0.0182 - val_loss: 0.1815 - val_proj_std: 0.0193 - val_pred_std: 0.0192 - binary_accuracy: 0.3080
Epoch 27/50
195/195 [==============================] - 69s 357ms/step - loss: 0.1326 - proj_std: 0.0189 - pred_std: 0.0185 - val_loss: 0.1919 - val_proj_std: 0.0177 - val_pred_std: 0.0174 - binary_accuracy: 0.3540
Epoch 28/50
195/195 [==============================] - 75s 388ms/step - loss: 0.1383 - proj_std: 0.0187 - pred_std: 0.0183 - val_loss: 0.1795 - val_proj_std: 0.0170 - val_pred_std: 0.0165 - binary_accuracy: 0.4060
Epoch 29/50
195/195 [==============================] - 61s 312ms/step - loss: 0.1348 - proj_std: 0.0177 - pred_std: 0.0172 - val_loss: 0.2115 - val_proj_std: 0.0187 - val_pred_std: 0.0185 - binary_accuracy: 0.3410
Epoch 30/50
195/195 [==============================] - 78s 401ms/step - loss: 0.1198 - proj_std: 0.0178 - pred_std: 0.0174 - val_loss: 0.1277 - val_proj_std: 0.0124 - val_pred_std: 0.0115 - binary_accuracy: 0.3520
Epoch 31/50
195/195 [==============================] - 68s 349ms/step - loss: 0.1185 - proj_std: 0.0180 - pred_std: 0.0176 - val_loss: 0.1637 - val_proj_std: 0.0187 - val_pred_std: 0.0185 - binary_accuracy: 0.3840
Epoch 32/50
195/195 [==============================] - 61s 312ms/step - loss: 0.1228 - proj_std: 0.0181 - pred_std: 0.0177 - val_loss: 0.1381 - val_proj_std: 0.0185 - val_pred_std: 0.0182 - binary_accuracy: 0.3790
Epoch 33/50
195/195 [==============================] - 70s 358ms/step - loss: 0.1180 - proj_std: 0.0176 - pred_std: 0.0173 - val_loss: 0.1273 - val_proj_std: 0.0188 - val_pred_std: 0.0186 - binary_accuracy: 0.4050
Epoch 34/50
195/195 [==============================] - 67s 342ms/step - loss: 0.1145 - proj_std: 0.0176 - pred_std: 0.0173 - val_loss: 0.1958 - val_proj_std: 0.0191 - val_pred_std: 0.0193 - binary_accuracy: 0.3880
Epoch 35/50
195/195 [==============================] - 68s 348ms/step - loss: 0.1112 - proj_std: 0.0175 - pred_std: 0.0172 - val_loss: 0.1372 - val_proj_std: 0.0186 - val_pred_std: 0.0185 - binary_accuracy: 0.3840
Epoch 36/50
195/195 [==============================] - 67s 343ms/step - loss: 0.1149 - proj_std: 0.0173 - pred_std: 0.0171 - val_loss: 0.1284 - val_proj_std: 0.0165 - val_pred_std: 0.0163 - binary_accuracy: 0.4030
Epoch 37/50
195/195 [==============================] - 71s 366ms/step - loss: 0.1108 - proj_std: 0.0174 - pred_std: 0.0171 - val_loss: 0.1387 - val_proj_std: 0.0145 - val_pred_std: 0.0141 - binary_accuracy: 0.4100
Epoch 38/50
195/195 [==============================] - 66s 338ms/step - loss: 0.1028 - proj_std: 0.0174 - pred_std: 0.0172 - val_loss: 0.1183 - val_proj_std: 0.0182 - val_pred_std: 0.0180 - binary_accuracy: 0.4180
Epoch 39/50
195/195 [==============================] - 69s 357ms/step - loss: 0.1011 - proj_std: 0.0171 - pred_std: 0.0170 - val_loss: 0.1056 - val_proj_std: 0.0177 - val_pred_std: 0.0176 - binary_accuracy: 0.4020
Epoch 40/50
195/195 [==============================] - 67s 346ms/step - loss: 0.1081 - proj_std: 0.0167 - pred_std: 0.0165 - val_loss: 0.1144 - val_proj_std: 0.0182 - val_pred_std: 0.0182 - binary_accuracy: 0.4670
Epoch 41/50
195/195 [==============================] - 68s 349ms/step - loss: 0.1060 - proj_std: 0.0166 - pred_std: 0.0165 - val_loss: 0.1180 - val_proj_std: 0.0175 - val_pred_std: 0.0174 - binary_accuracy: 0.4280
Epoch 42/50
195/195 [==============================] - 69s 356ms/step - loss: 0.1063 - proj_std: 0.0163 - pred_std: 0.0162 - val_loss: 0.1143 - val_proj_std: 0.0173 - val_pred_std: 0.0171 - binary_accuracy: 0.4220
Epoch 43/50
195/195 [==============================] - 69s 353ms/step - loss: 0.1050 - proj_std: 0.0162 - pred_std: 0.0161 - val_loss: 0.1171 - val_proj_std: 0.0169 - val_pred_std: 0.0168 - binary_accuracy: 0.4310
Epoch 44/50
195/195 [==============================] - 75s 386ms/step - loss: 0.1013 - proj_std: 0.0159 - pred_std: 0.0157 - val_loss: 0.1106 - val_proj_std: 0.0161 - val_pred_std: 0.0159 - binary_accuracy: 0.4140
Epoch 45/50
195/195 [==============================] - 63s 324ms/step - loss: 0.1035 - proj_std: 0.0160 - pred_std: 0.0159 - val_loss: 0.1086 - val_proj_std: 0.0171 - val_pred_std: 0.0171 - binary_accuracy: 0.4350
Epoch 46/50
195/195 [==============================] - 69s 354ms/step - loss: 0.0999 - proj_std: 0.0157 - pred_std: 0.0157 - val_loss: 0.1000 - val_proj_std: 0.0164 - val_pred_std: 0.0164 - binary_accuracy: 0.4510
Epoch 47/50
195/195 [==============================] - 68s 351ms/step - loss: 0.1002 - proj_std: 0.0157 - pred_std: 0.0156 - val_loss: 0.1067 - val_proj_std: 0.0163 - val_pred_std: 0.0163 - binary_accuracy: 0.4680
Epoch 48/50
195/195 [==============================] - 68s 352ms/step - loss: 0.0980 - proj_std: 0.0155 - pred_std: 0.0153 - val_loss: 0.0986 - val_proj_std: 0.0159 - val_pred_std: 0.0159 - binary_accuracy: 0.4410
Epoch 49/50
195/195 [==============================] - 69s 355ms/step - loss: 0.0944 - proj_std: 0.0155 - pred_std: 0.0154 - val_loss: 0.0949 - val_proj_std: 0.0164 - val_pred_std: 0.0163 - binary_accuracy: 0.4520
Epoch 50/50
195/195 [==============================] - 67s 347ms/step - loss: 0.0937 - proj_std: 0.0155 - pred_std: 0.0154 - val_loss: 0.0978 - val_proj_std: 0.0166 - val_pred_std: 0.0165 - binary_accuracy: 0.4570
```
</div>

---
## Plotting and Evaluation

```python
plt.figure(figsize=(15, 4))

# Left panel: SimSiam training loss per epoch.
plt.subplot(1, 3, 1)
plt.plot(history.history["loss"])
plt.grid()
plt.title(f"{loss.name} - loss")

# Middle panel: std of the projector (and, if logged, predictor) outputs.
plt.subplot(1, 3, 2)
plt.plot(history.history["proj_std"], label="proj")
if "pred_std" in history.history:
    plt.plot(history.history["pred_std"], label="pred")
plt.grid()
plt.title(f"{loss.name} - std metrics")
plt.legend()

# Right panel: the binary_accuracy metric logged during pretraining.
plt.subplot(1, 3, 3)
plt.plot(history.history["binary_accuracy"], label="acc")
plt.grid()
plt.title(f"{loss.name} - match metrics")
plt.legend()

plt.show()
```

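The middle panel tracks `proj_std` and `pred_std`. The SimSiam paper uses the standard deviation of the L2-normalized outputs as a collapse diagnostic: for roughly isotropic outputs of dimension d it stays near 1/sqrt(d), while a collapsed model that maps every input to the same vector drives it to 0. The NumPy sketch below computes that quantity on synthetic data. It assumes TensorFlow Similarity's metrics follow the same definition (its exact implementation may differ), and `normalized_output_std` is a hypothetical helper, not part of the library.

```python
import numpy as np


def normalized_output_std(embeddings: np.ndarray) -> float:
    """Mean per-dimension std of L2-normalized embeddings over a batch.

    Hypothetical helper sketching the SimSiam collapse diagnostic; not
    necessarily identical to TF Similarity's `proj_std` / `pred_std`.
    """
    # Scale each embedding (row) to unit length.
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Std of each dimension across the batch, averaged over dimensions.
    return float(np.std(z, axis=0).mean())


rng = np.random.default_rng(0)
dim = 2048
# Healthy outputs: independent isotropic Gaussian vectors.
healthy = rng.normal(size=(512, dim))
# Collapsed outputs: the same vector repeated for every input.
collapsed = np.tile(rng.normal(size=(1, dim)), (512, 1))

print(normalized_output_std(healthy))    # close to 1/sqrt(dim), about 0.0221
print(normalized_output_std(collapsed))  # approximately 0
```

Under this reading, the logged `proj_std` and `pred_std` values stay well away from 0 throughout pretraining, which suggests the run did not collapse.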


---
## Fine Tuning on the Labelled Data

As a final step, we will train a classifier on 10% of the training data to evaluate the quality of our learned representation. First, we handle data loading:

```python
# Light augmentation for the labelled training images: random crops and horizontal flips.
eval_augmenter = keras.Sequential(
    [
        keras_cv.layers.RandomCropAndResize(
            (96, 96), crop_area_factor=(0.8, 1.0), aspect_ratio_factor=(1.0, 1.0)
        ),
        keras_cv.layers.RandomFlip(mode="horizontal"),
    ]
)

# Training set: one-hot labels, repeated indefinitely, shuffled, augmented, and batched.
eval_train_ds = tf.data.Dataset.from_tensor_slices(
    (x_raw_train, tf.keras.utils.to_categorical(y_raw_train, 10))
)
eval_train_ds = eval_train_ds.repeat()
eval_train_ds = eval_train_ds.shuffle(1024)
eval_train_ds = eval_train_ds.map(lambda x, y: (eval_augmenter(x), y), tf.data.AUTOTUNE)
eval_train_ds = eval_train_ds.batch(BATCH_SIZE)
eval_train_ds = eval_train_ds.prefetch(tf.data.AUTOTUNE)

# Validation set: the held-out test images, without augmentation.
# Both datasets repeat forever, so fit() needs explicit step counts.
eval_val_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, tf.keras.utils.to_categorical(y_test, 10))
)
eval_val_ds = eval_val_ds.repeat()
eval_val_ds = eval_val_ds.shuffle(1024)
eval_val_ds = eval_val_ds.batch(BATCH_SIZE)
eval_val_ds = eval_val_ds.prefetch(tf.data.AUTOTUNE)
```
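For the pretrained model, this evaluation is a linear probe: its backbone is frozen (`trainable=False`) and only a softmax `Dense` layer is fitted to the backbone's features. As a minimal NumPy sketch of that idea (not the Keras code used in this guide), the following trains a multinomial logistic-regression layer with plain gradient descent (the guide's Keras optimizer adds momentum) on fixed synthetic "features" that stand in for backbone outputs. The helper names and the data are hypothetical.

```python
import numpy as np


def train_linear_probe(features, labels, num_classes, lr=0.1, steps=500):
    """Fit a softmax layer (W, b) on frozen features with gradient descent.

    Hypothetical helper: `features` play the role of a frozen backbone's
    outputs and are never modified; only W and b are learned.
    """
    n, d = features.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]  # same encoding as to_categorical
    for _ in range(steps):
        logits = features @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        W -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b


def probe_accuracy(features, labels, W, b):
    return float(np.mean(np.argmax(features @ W + b, axis=1) == labels))


# Synthetic "embeddings": three Gaussian clusters in 16 dimensions.
rng = np.random.default_rng(0)
centers = rng.normal(size=(3, 16))
labels = rng.integers(0, 3, size=600)
features = centers[labels] + 0.3 * rng.normal(size=(600, 16))

W, b = train_linear_probe(features, labels, num_classes=3)
print(probe_accuracy(features, labels, W, b))
```

Because only `W` and `b` change, the probe's accuracy measures how linearly separable the frozen features already are, which is why it serves as a check on representation quality.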

---
## Benchmark Against a Naive Model

Next, let's set up a naive baseline model that does not leverage the unlabeled data corpus.
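Both evaluation models decay their learning rate with a cosine schedule over `TEST_EPOCHS * TEST_STEPS_PER_EPOCH` steps. The plain-Python sketch below reproduces the formula given in the Keras documentation for `CosineDecay` (default `alpha=0.0`, no warmup), so you can check which learning rate the optimizer sees at a given step. `cosine_decay` is a hypothetical helper for illustration; the 9 steps per epoch match the `9/9` progress bars in the fine-tuning logs.

```python
import math


def cosine_decay(step, initial_learning_rate, decay_steps, alpha=0.0):
    """Hypothetical helper mirroring the documented CosineDecay formula."""
    step = min(step, decay_steps)  # the rate stays at its floor after decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    decayed = (1.0 - alpha) * cosine + alpha
    return initial_learning_rate * decayed


# Baseline settings: lr=1e-3 over 50 epochs x 9 steps per epoch.
total_steps = 50 * 9
for step in (0, total_steps // 2, total_steps):
    print(step, cosine_decay(step, 1e-3, total_steps))
```

With `alpha=0.0`, the rate starts at `lr`, falls to half of it at the midpoint, and reaches zero at the final step.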

```python
TEST_EPOCHS = 50
TEST_STEPS_PER_EPOCH = x_raw_train.shape[0] // BATCH_SIZE


def get_eval_model(img_size, backbone, total_steps, trainable=True, lr=1.8):
    # Freeze or unfreeze the backbone depending on the evaluation mode.
    backbone.trainable = trainable
    inputs = tf.keras.layers.Input((img_size, img_size, 3), name="eval_input")
    x = backbone(inputs, training=trainable)
    # A single softmax layer maps the backbone's features to the 10 classes.
    o = tf.keras.layers.Dense(10, activation="softmax")(x)
    model = tf.keras.Model(inputs, o)
    cosine_decayed_lr = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=lr, decay_steps=total_steps
    )
    opt = tf.keras.optimizers.SGD(cosine_decayed_lr, momentum=0.9)
    model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])
    return model


# Baseline: a freshly built backbone trained end to end on the labelled data only.
no_pt_eval_model = get_eval_model(
    img_size=96,
    backbone=get_backbone((96, 96, 3)),
    total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    trainable=True,
    lr=1e-3,
)
# The datasets are already batched, so no batch_size argument is passed to fit().
no_pt_history = no_pt_eval_model.fit(
    eval_train_ds,
    epochs=TEST_EPOCHS,
    steps_per_epoch=TEST_STEPS_PER_EPOCH,
    validation_data=eval_val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
)
```
<div class="k-default-codeblock">
```
Epoch 1/50
9/9 [==============================] - 6s 249ms/step - loss: 2.4969 - acc: 0.1302 - val_loss: 2.2889 - val_acc: 0.1669
Epoch 2/50
9/9 [==============================] - 1s 139ms/step - loss: 2.2002 - acc: 0.1888 - val_loss: 2.1074 - val_acc: 0.2160
Epoch 3/50
9/9 [==============================] - 1s 139ms/step - loss: 2.0066 - acc: 0.2619 - val_loss: 1.9138 - val_acc: 0.2968
Epoch 4/50
9/9 [==============================] - 1s 139ms/step - loss: 1.8394 - acc: 0.3227 - val_loss: 1.7825 - val_acc: 0.3326
Epoch 5/50
9/9 [==============================] - 1s 140ms/step - loss: 1.7191 - acc: 0.3585 - val_loss: 1.7004 - val_acc: 0.3545
Epoch 6/50
9/9 [==============================] - 1s 140ms/step - loss: 1.6458 - acc: 0.3806 - val_loss: 1.6473 - val_acc: 0.3734
Epoch 7/50
9/9 [==============================] - 1s 139ms/step - loss: 1.5798 - acc: 0.4030 - val_loss: 1.6009 - val_acc: 0.3907
Epoch 8/50
9/9 [==============================] - 1s 139ms/step - loss: 1.5244 - acc: 0.4332 - val_loss: 1.5696 - val_acc: 0.4029
Epoch 9/50
9/9 [==============================] - 1s 140ms/step - loss: 1.4977 - acc: 0.4325 - val_loss: 1.5416 - val_acc: 0.4126
Epoch 10/50
9/9 [==============================] - 1s 139ms/step - loss: 1.4555 - acc: 0.4559 - val_loss: 1.5087 - val_acc: 0.4271
Epoch 11/50
9/9 [==============================] - 1s 140ms/step - loss: 1.4294 - acc: 0.4627 - val_loss: 1.4897 - val_acc: 0.4384
Epoch 12/50
9/9 [==============================] - 1s 139ms/step - loss: 1.4031 - acc: 0.4820 - val_loss: 1.4759 - val_acc: 0.4410
Epoch 13/50
9/9 [==============================] - 1s 141ms/step - loss: 1.3625 - acc: 0.4941 - val_loss: 1.4501 - val_acc: 0.4486
Epoch 14/50
9/9 [==============================] - 1s 140ms/step - loss: 1.3443 - acc: 0.5026 - val_loss: 1.4390 - val_acc: 0.4525
Epoch 15/50
9/9 [==============================] - 1s 139ms/step - loss: 1.3235 - acc: 0.5067 - val_loss: 1.4308 - val_acc: 0.4578
Epoch 16/50
9/9 [==============================] - 1s 139ms/step - loss: 1.2863 - acc: 0.5328 - val_loss: 1.4089 - val_acc: 0.4650
Epoch 17/50
9/9 [==============================] - 1s 140ms/step - loss: 1.2851 - acc: 0.5339 - val_loss: 1.3944 - val_acc: 0.4700
Epoch 18/50
9/9 [==============================] - 1s 141ms/step - loss: 1.2501 - acc: 0.5464 - val_loss: 1.3887 - val_acc: 0.4773
Epoch 19/50
9/9 [==============================] - 1s 139ms/step - loss: 1.2324 - acc: 0.5510 - val_loss: 1.3783 - val_acc: 0.4820
Epoch 20/50
9/9 [==============================] - 1s 140ms/step - loss: 1.2223 - acc: 0.5562 - val_loss: 1.3655 - val_acc: 0.4848
Epoch 21/50
9/9 [==============================] - 1s 140ms/step - loss: 1.2070 - acc: 0.5664 - val_loss: 1.3579 - val_acc: 0.4867
Epoch 22/50
9/9 [==============================] - 1s 141ms/step - loss: 1.1820 - acc: 0.5738 - val_loss: 1.3482 - val_acc: 0.4913
Epoch 23/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1688 - acc: 0.5790 - val_loss: 1.3375 - val_acc: 0.4964
Epoch 24/50
9/9 [==============================] - 1s 141ms/step - loss: 1.1514 - acc: 0.5896 - val_loss: 1.3403 - val_acc: 0.4966
Epoch 25/50
9/9 [==============================] - 1s 138ms/step - loss: 1.1307 - acc: 0.5961 - val_loss: 1.3321 - val_acc: 0.5025
Epoch 26/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1341 - acc: 0.6009 - val_loss: 1.3220 - val_acc: 0.5035
Epoch 27/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1177 - acc: 0.5987 - val_loss: 1.3149 - val_acc: 0.5074
Epoch 28/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1078 - acc: 0.6068 - val_loss: 1.3089 - val_acc: 0.5137
Epoch 29/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0929 - acc: 0.6046 - val_loss: 1.3015 - val_acc: 0.5139
Epoch 30/50
9/9 [==============================] - 1s 138ms/step - loss: 1.0915 - acc: 0.6139 - val_loss: 1.3064 - val_acc: 0.5149
Epoch 31/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0634 - acc: 0.6254 - val_loss: 1.2955 - val_acc: 0.5123
Epoch 32/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0675 - acc: 0.6254 - val_loss: 1.2979 - val_acc: 0.5167
Epoch 33/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0595 - acc: 0.6289 - val_loss: 1.2911 - val_acc: 0.5186
Epoch 34/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0397 - acc: 0.6328 - val_loss: 1.2906 - val_acc: 0.5208
Epoch 35/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0415 - acc: 0.6378 - val_loss: 1.2863 - val_acc: 0.5222
Epoch 36/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0435 - acc: 0.6257 - val_loss: 1.2830 - val_acc: 0.5215
Epoch 37/50
9/9 [==============================] - 1s 144ms/step - loss: 1.0242 - acc: 0.6461 - val_loss: 1.2820 - val_acc: 0.5268
Epoch 38/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0212 - acc: 0.6421 - val_loss: 1.2766 - val_acc: 0.5259
Epoch 39/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0213 - acc: 0.6385 - val_loss: 1.2770 - val_acc: 0.5259
Epoch 40/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0224 - acc: 0.6428 - val_loss: 1.2742 - val_acc: 0.5262
Epoch 41/50
9/9 [==============================] - 1s 142ms/step - loss: 0.9994 - acc: 0.6510 - val_loss: 1.2755 - val_acc: 0.5238
Epoch 42/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0154 - acc: 0.6474 - val_loss: 1.2784 - val_acc: 0.5244
Epoch 43/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0176 - acc: 0.6441 - val_loss: 1.2680 - val_acc: 0.5247
Epoch 44/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0101 - acc: 0.6471 - val_loss: 1.2711 - val_acc: 0.5288
Epoch 45/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0080 - acc: 0.6536 - val_loss: 1.2691 - val_acc: 0.5275
Epoch 46/50
9/9 [==============================] - 1s 143ms/step - loss: 1.0038 - acc: 0.6428 - val_loss: 1.2706 - val_acc: 0.5302
Epoch 47/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0070 - acc: 0.6573 - val_loss: 1.2678 - val_acc: 0.5293
Epoch 48/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0030 - acc: 0.6450 - val_loss: 1.2723 - val_acc: 0.5278
Epoch 49/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0080 - acc: 0.6447 - val_loss: 1.2691 - val_acc: 0.5252
Epoch 50/50
9/9 [==============================] - 1s 142ms/step - loss: 1.0093 - acc: 0.6497 - val_loss: 1.2712 - val_acc: 0.5278
```
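</div>

Pretty bad results! Let's try our SimSiam-pretrained model instead. Its backbone is frozen (`trainable=False`), so only the final `Dense` classifier is trained on top of the learned representation:

```python
pt_eval_model = get_eval_model(
    img_size=96,
    backbone=contrastive_model.backbone,
    total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    trainable=False,  # freeze the pretrained backbone; only the Dense head is trained
    lr=30.0,
)
pt_eval_model.summary()
# The datasets are already batched, so no batch_size argument is passed to fit().
pt_history = pt_eval_model.fit(
    eval_train_ds,
    epochs=TEST_EPOCHS,
    steps_per_epoch=TEST_STEPS_PER_EPOCH,
    validation_data=eval_val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
)
```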
<div class="k-default-codeblock">
```
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 eval_input (InputLayer)     [(None, 96, 96, 3)]       0

 similarity_model (Similarit  (None, 512)              11186112
 yModel)

 dense_1 (Dense)             (None, 10)                5130

=================================================================
Total params: 11,191,242
Trainable params: 5,130
Non-trainable params: 11,186,112

Epoch 1/50
9/9 [==============================] - 3s 172ms/step - loss: 18.2303 - acc: 0.2563 - val_loss: 16.9489 - val_acc: 0.3463
Epoch 2/50
9/9 [==============================] - 1s 109ms/step - loss: 24.5528 - acc: 0.3498 - val_loss: 19.1886 - val_acc: 0.4050
Epoch 3/50
9/9 [==============================] - 1s 110ms/step - loss: 18.3920 - acc: 0.4477 - val_loss: 20.0611 - val_acc: 0.4456
Epoch 4/50
9/9 [==============================] - 1s 113ms/step - loss: 15.4172 - acc: 0.4993 - val_loss: 12.2465 - val_acc: 0.5116
Epoch 5/50
9/9 [==============================] - 1s 110ms/step - loss: 10.5517 - acc: 0.5217 - val_loss: 8.5560 - val_acc: 0.5474
Epoch 6/50
9/9 [==============================] - 1s 112ms/step - loss: 7.4812 - acc: 0.5395 - val_loss: 7.9182 - val_acc: 0.5053
Epoch 7/50
9/9 [==============================] - 1s 112ms/step - loss: 8.4429 - acc: 0.5024 - val_loss: 7.7339 - val_acc: 0.5071
Epoch 8/50
9/9 [==============================] - 1s 113ms/step - loss: 8.6143 - acc: 0.5109 - val_loss: 10.4784 - val_acc: 0.5157
Epoch 9/50
9/9 [==============================] - 1s 111ms/step - loss: 8.7506 - acc: 0.5061 - val_loss: 7.8201 - val_acc: 0.4914
Epoch 10/50
9/9 [==============================] - 1s 111ms/step - loss: 8.7927 - acc: 0.4996 - val_loss: 9.6188 - val_acc: 0.4668
Epoch 11/50
9/9 [==============================] - 1s 114ms/step - loss: 7.2190 - acc: 0.4844 - val_loss: 8.8605 - val_acc: 0.4240
Epoch 12/50
9/9 [==============================] - 1s 111ms/step - loss: 8.0435 - acc: 0.4681 - val_loss: 8.2731 - val_acc: 0.4425
Epoch 13/50
9/9 [==============================] - 1s 112ms/step - loss: 7.1718 - acc: 0.5048 - val_loss: 5.4667 - val_acc: 0.5485
Epoch 14/50
9/9 [==============================] - 1s 111ms/step - loss: 6.7500 - acc: 0.5111 - val_loss: 5.5898 - val_acc: 0.5158
Epoch 15/50
9/9 [==============================] - 1s 110ms/step - loss: 5.1562 - acc: 0.5467 - val_loss: 3.7606 - val_acc: 0.5587
Epoch 16/50
9/9 [==============================] - 1s 111ms/step - loss: 3.5923 - acc: 0.5814 - val_loss: 5.0881 - val_acc: 0.5336
Epoch 17/50
9/9 [==============================] - 1s 110ms/step - loss: 5.1907 - acc: 0.5221 - val_loss: 7.7393 - val_acc: 0.3609
Epoch 18/50
9/9 [==============================] - 1s 112ms/step - loss: 8.0532 - acc: 0.4768 - val_loss: 7.2504 - val_acc: 0.5265
Epoch 19/50
9/9 [==============================] - 1s 111ms/step - loss: 6.5527 - acc: 0.5221 - val_loss: 6.8659 - val_acc: 0.4729
Epoch 20/50
9/9 [==============================] - 1s 113ms/step - loss: 7.0188 - acc: 0.4924 - val_loss: 6.5774 - val_acc: 0.4729
Epoch 21/50
9/9 [==============================] - 1s 112ms/step - loss: 4.8837 - acc: 0.5293 - val_loss: 4.5986 - val_acc: 0.5568
Epoch 22/50
9/9 [==============================] - 1s 113ms/step - loss: 4.5787 - acc: 0.5536 - val_loss: 4.9848 - val_acc: 0.5343
Epoch 23/50
9/9 [==============================] - 1s 111ms/step - loss: 5.3264 - acc: 0.5501 - val_loss: 6.1620 - val_acc: 0.5257
Epoch 24/50
9/9 [==============================] - 1s 118ms/step - loss: 4.6995 - acc: 0.5681 - val_loss: 2.9108 - val_acc: 0.6004
Epoch 25/50
9/9 [==============================] - 1s 111ms/step - loss: 3.0915 - acc: 0.6024 - val_loss: 2.9674 - val_acc: 0.6097
Epoch 26/50
9/9 [==============================] - 1s 112ms/step - loss: 2.9893 - acc: 0.5940 - val_loss: 2.7857 - val_acc: 0.5975
Epoch 27/50
9/9 [==============================] - 1s 112ms/step - loss: 3.0031 - acc: 0.5990 - val_loss: 3.3214 - val_acc: 0.5661
Epoch 28/50
9/9 [==============================] - 1s 110ms/step - loss: 2.4497 - acc: 0.6118 - val_loss: 2.5389 - val_acc: 0.5864
Epoch 29/50
9/9 [==============================] - 1s 112ms/step - loss: 2.2352 - acc: 0.6222 - val_loss: 2.6069 - val_acc: 0.5891
Epoch 30/50
9/9 [==============================] - 1s 110ms/step - loss: 2.0529 - acc: 0.6230 - val_loss: 2.2986 - val_acc: 0.6147
Epoch 31/50
9/9 [==============================] - 1s 113ms/step - loss: 2.1396 - acc: 0.6337 - val_loss: 2.3893 - val_acc: 0.6115
Epoch 32/50
9/9 [==============================] - 1s 110ms/step - loss: 2.0879 - acc: 0.6309 - val_loss: 2.0767 - val_acc: 0.6139
Epoch 33/50
9/9 [==============================] - 1s 111ms/step - loss: 1.9498 - acc: 0.6417 - val_loss: 2.5760 - val_acc: 0.6166
Epoch 34/50
9/9 [==============================] - 1s 111ms/step - loss: 2.0624 - acc: 0.6456 - val_loss: 2.2055 - val_acc: 0.6306
Epoch 35/50
9/9 [==============================] - 1s 113ms/step - loss: 1.9772 - acc: 0.6573 - val_loss: 1.8998 - val_acc: 0.6148
Epoch 36/50
9/9 [==============================] - 1s 110ms/step - loss: 1.7421 - acc: 0.6411 - val_loss: 1.7790 - val_acc: 0.6320
Epoch 37/50
9/9 [==============================] - 1s 112ms/step - loss: 1.6005 - acc: 0.6493 - val_loss: 1.7596 - val_acc: 0.6132
Epoch 38/50
9/9 [==============================] - 1s 111ms/step - loss: 1.4635 - acc: 0.6623 - val_loss: 1.8133 - val_acc: 0.6142
Epoch 39/50
9/9 [==============================] - 1s 112ms/step - loss: 1.4952 - acc: 0.6517 - val_loss: 1.8677 - val_acc: 0.5960
Epoch 40/50
9/9 [==============================] - 1s 113ms/step - loss: 1.4972 - acc: 0.6519 - val_loss: 1.7388 - val_acc: 0.6311
Epoch 41/50
9/9 [==============================] - 1s 113ms/step - loss: 1.4158 - acc: 0.6693 - val_loss: 1.6358 - val_acc: 0.6398
Epoch 42/50
9/9 [==============================] - 1s 110ms/step - loss: 1.3600 - acc: 0.6721 - val_loss: 1.5624 - val_acc: 0.6381
Epoch 43/50
9/9 [==============================] - 1s 112ms/step - loss: 1.2960 - acc: 0.6812 - val_loss: 1.5512 - val_acc: 0.6380
Epoch 44/50
9/9 [==============================] - 1s 111ms/step - loss: 1.3473 - acc: 0.6727 - val_loss: 1.4881 - val_acc: 0.6448
Epoch 45/50
9/9 [==============================] - 1s 111ms/step - loss: 1.1990 - acc: 0.6892 - val_loss: 1.4914 - val_acc: 0.6437
Epoch 46/50
9/9 [==============================] - 1s 111ms/step - loss: 1.2816 - acc: 0.6823 - val_loss: 1.4654 - val_acc: 0.6466
Epoch 47/50
9/9 [==============================] - 1s 113ms/step - loss: 1.2525 - acc: 0.6838 - val_loss: 1.4802 - val_acc: 0.6479
Epoch 48/50
9/9 [==============================] - 1s 111ms/step - loss: 1.2661 - acc: 0.6799 - val_loss: 1.4692 - val_acc: 0.6447
Epoch 49/50
9/9 [==============================] - 1s 111ms/step - loss: 1.2389 - acc: 0.6866 - val_loss: 1.4733 - val_acc: 0.6436
Epoch 50/50
9/9 [==============================] - 1s 113ms/step - loss: 1.2166 - acc: 0.6875 - val_loss: 1.4666 - val_acc: 0.6444
```

</div>

All that is left to do is evaluate the models:

```python
print(
    "no pretrain",
    no_pt_eval_model.evaluate(
        eval_val_ds,
        steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    ),
)
print(
    "pretrained",
    pt_eval_model.evaluate(
        eval_val_ds,
        steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    ),
)
```
<div class="k-default-codeblock">
```
450/450 [==============================] - 14s 30ms/step - loss: 1.2648 - acc: 0.5311
no pretrain [1.2647558450698853, 0.5310590267181396]
450/450 [==============================] - 12s 26ms/step - loss: 1.4653 - acc: 0.6474
pretrained [1.465279221534729, 0.6474305391311646]
```
</div>

Awesome! Our pretrained model stomped the non-pretrained model, reaching 64.7% test accuracy versus 53.1%. This accuracy is quite good for a ResNet18 on the STL-10 dataset. For better results, try using an EfficientNetV2B0 instead. Unfortunately, this will require a higher-end graphics card to fit the batch size of 512 used for SimSiam pretraining in this guide.

---
## Conclusion

TensorFlow Similarity can be used to easily train KerasCV models using contrastive algorithms such as SimCLR, SimSiam, and Barlow Twins. This allows you to leverage large corpora of unlabelled data in your model training pipeline.

Some follow-up exercises to this tutorial:

- Train a [`keras_cv.models.EfficientNetV2B0`](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/efficientnet_v2.py) on STL-10
- Experiment with other data augmentation techniques in pretraining
- Train a model using the [BarlowTwins implementation](https://github.com/tensorflow/similarity/blob/master/examples/unsupervised_hello_world.ipynb) in TensorFlow Similarity
- Try pretraining on your own dataset