# SimSiam Training with TensorFlow Similarity and KerasCV

**Author:** lukewood, Ian Stenbit, Owen Vallis<br>
**Date created:** 2023/01/22<br>
**Last modified:** 2023/01/22<br>
**Description:** Train a KerasCV model using unlabelled data with SimSiam.
---
## Overview
TensorFlow Similarity makes it easy to train KerasCV models on unlabelled corpora of data using contrastive learning algorithms such as SimCLR, SimSiam, and Barlow Twins. In this guide, we will train a KerasCV model using the SimSiam implementation from TensorFlow Similarity.
---
## Background
Self-supervised learning is an approach to pre-training models using unlabeled data. This approach drastically increases accuracy when you have very few labeled examples but a lot of unlabelled data. The key insight is that you can train a self-supervised model to learn data representations by contrasting multiple augmented views of the same example. These learned representations capture data invariants, e.g., object translation, color jitter, noise, etc. Training a simple linear classifier on top of the frozen representations is easier and requires fewer labels because the pre-trained model already produces meaningful and generally useful features.
Overall, self-supervised pre-training learns representations which are more generic and robust than other approaches to augmented training and pre-training. An overview of the general contrastive learning process is shown below:

In this tutorial, we will use the SimSiam algorithm for contrastive learning. As of 2022, SimSiam is the state-of-the-art algorithm for contrastive learning, allowing for unprecedented scores on CIFAR-100 and other datasets.
You may need to install:

```
pip -q install tensorflow_similarity
pip -q install keras-cv
```
To get started, we will sort out some imports.
```python
import resource
import gc
import os
import random
import time
import tensorflow_addons as tfa
import keras_cv
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tabulate import tabulate
import tensorflow_similarity as tfsim
import tensorflow as tf
from keras_cv import layers as cv_layers
import tensorflow_datasets as tfds
```
```
You do not have Waymo Open Dataset installed, so KerasCV Waymo metrics are not available.
Your CPU supports instructions that this binary was not compiled to use: SSE3 SSE4.1 SSE4.2 AVX AVX2
For maximum performance, you can install NMSLIB from sources: pip install --no-binary :all: nmslib
```
Let's sort out some high-level config issues and define some constants.
The resource limit increase is required to load STL-10, `tfsim.utils.tf_cap_memory()`
prevents TensorFlow from hogging the GPU memory in a cluster, and
`tfds.disable_progress_bar()` makes tfds less noisy.
```python
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
tfsim.utils.tf_cap_memory()
tfds.disable_progress_bar()
BATCH_SIZE = 512
PRE_TRAIN_EPOCHS = 50
VAL_STEPS_PER_EPOCH = 20
WEIGHT_DECAY = 5e-4
INIT_LR = 3e-2 * int(BATCH_SIZE / 256)  # Scale the base learning rate linearly with batch size
WARMUP_LR = 0.0
WARMUP_STEPS = 0
DIM = 2048
```
---
## Data loading
Next, we will load the STL-10 dataset. STL-10 is a dataset consisting of 100k unlabelled images, 5k labelled training images, and 8k labelled test images. Due to this distribution, STL-10 is commonly used as a benchmark for contrastive learning models.
First, let's load our unlabelled data:

```python
train_ds = tfds.load("stl10", split="unlabelled")
train_ds = train_ds.map(
    lambda entry: entry["image"], num_parallel_calls=tf.data.AUTOTUNE
)
train_ds = train_ds.map(
    lambda image: tf.cast(image, tf.float32), num_parallel_calls=tf.data.AUTOTUNE
)
train_ds = train_ds.shuffle(buffer_size=8 * BATCH_SIZE, reshuffle_each_iteration=True)
```

```
Downloading and preparing dataset 2.46 GiB (download: 2.46 GiB, generated: 1.86 GiB, total: 4.32 GiB) to ~/tensorflow_datasets/stl10/1.0.0...
Dataset stl10 downloaded and prepared to ~/tensorflow_datasets/stl10/1.0.0. Subsequent calls will reuse this data.
WARNING:tensorflow:From /home/lukewood/.local/lib/python3.7/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda functions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
```
Next, we need to prepare some labelled samples.
This is done so that TensorFlow Similarity can probe the learned embedding to ensure
that the model is learning appropriately.
```python
(x_raw_train, y_raw_train), ds_info = tfds.load(
    "stl10", split="train", as_supervised=True, batch_size=-1, with_info=True
)
x_raw_train, y_raw_train = tf.cast(x_raw_train, tf.float32), tf.cast(
    y_raw_train, tf.float32
)

x_test, y_test = tfds.load(
    "stl10",
    split="test",
    as_supervised=True,
    batch_size=-1,
)
x_test, y_test = tf.cast(x_test, tf.float32), tf.cast(y_test, tf.float32)
```

In self-supervised learning, queries and indexes are labelled subsets used to evaluate the quality of the produced latent embedding. The following code assembles these datasets:

```python
query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []
for cid in range(ds_info.features["label"].num_classes):
    idxs = tf.random.shuffle(tf.where(y_raw_train == cid))
    idxs = tf.reshape(idxs, (-1,))
    query_idxs.extend(idxs[:100])  # 100 query examples per class
    index_idxs.extend(idxs[100:200])  # 100 index examples per class
    val_idxs.extend(idxs[200:300])  # 100 validation examples per class
    train_idxs.extend(idxs[300:])  # The remaining examples for training

random.shuffle(query_idxs)
random.shuffle(index_idxs)
random.shuffle(val_idxs)
random.shuffle(train_idxs)


def create_split(idxs: list) -> tuple:
    x, y = [], []
    for idx in idxs:
        x.append(x_raw_train[int(idx)])
        y.append(y_raw_train[int(idx)])
    return tf.convert_to_tensor(np.array(x), dtype=tf.float32), tf.convert_to_tensor(
        np.array(y), dtype=tf.int64
    )


x_query, y_query = create_split(query_idxs)
x_index, y_index = create_split(index_idxs)
x_val, y_val = create_split(val_idxs)
x_train, y_train = create_split(train_idxs)

PRE_TRAIN_STEPS_PER_EPOCH = tf.data.experimental.cardinality(train_ds) // BATCH_SIZE
PRE_TRAIN_STEPS_PER_EPOCH = int(PRE_TRAIN_STEPS_PER_EPOCH.numpy())

print(
    tabulate(
        [
            ["train", tf.data.experimental.cardinality(train_ds), None],
            ["val", x_val.shape, y_val.shape],
            ["query", x_query.shape, y_query.shape],
            ["index", x_index.shape, y_index.shape],
            ["test", x_test.shape, y_test.shape],
        ],
        headers=["# of Examples", "Labels"],
    )
)
```
```
       # of Examples      Labels
-----  -----------------  --------
train  100000
val    (1000, 96, 96, 3)  (1000,)
query  (1000, 96, 96, 3)  (1000,)
index  (1000, 96, 96, 3)  (1000,)
test   (8000, 96, 96, 3)  (8000,)
```
---
## Augmentations
Self-supervised networks require at least two augmented "views" of each example.
These can be created using a dataset and an augmentation function.
The dataset treats each example in the batch as its own class and then the augment
function produces two separate views for each example.

This means the resulting batch will yield tuples containing the two views, i.e.,
`Tuple[(BATCH_SIZE, 96, 96, 3), (BATCH_SIZE, 96, 96, 3)]`.
Using KerasCV, it is trivial to construct an augmenter that matches the one
described in the original SimSiam paper. Let's do that below.
```python
target_size = (96, 96)
crop_area_factor = (0.08, 1)
aspect_ratio_factor = (3 / 4, 4 / 3)
grayscale_rate = 0.2
color_jitter_rate = 0.8
brightness_factor = 0.2
contrast_factor = 0.8
saturation_factor = (0.3, 0.7)
hue_factor = 0.2

augmenter = keras.Sequential(
    [
        cv_layers.RandomFlip("horizontal"),
        cv_layers.RandomCropAndResize(
            target_size,
            crop_area_factor=crop_area_factor,
            aspect_ratio_factor=aspect_ratio_factor,
        ),
        cv_layers.RandomApply(
            cv_layers.Grayscale(output_channels=3), rate=grayscale_rate
        ),
        cv_layers.RandomApply(
            cv_layers.RandomColorJitter(
                value_range=(0, 255),
                brightness_factor=brightness_factor,
                contrast_factor=contrast_factor,
                saturation_factor=saturation_factor,
                hue_factor=hue_factor,
            ),
            rate=color_jitter_rate,
        ),
    ],
)
```

Next, let's pass our images through this pipeline. Note that KerasCV supports batched augmentation, so batching before augmentation dramatically improves performance.

```python
@tf.function()
def process(img):
    return augmenter(img), augmenter(img)


def prepare_dataset(dataset):
    dataset = dataset.repeat()
    dataset = dataset.shuffle(1024)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)


train_ds = prepare_dataset(train_ds)

val_ds = tf.data.Dataset.from_tensor_slices(x_val)
val_ds = prepare_dataset(val_ds)

print("train_ds", train_ds)
print("val_ds", val_ds)
```
Let's visualize our pairs using the `tfsim.visualization` utility package.
```python
display_imgs = next(train_ds.as_numpy_iterator())
max_pixel = np.max([display_imgs[0].max(), display_imgs[1].max()])
min_pixel = np.min([display_imgs[0].min(), display_imgs[1].min()])

tfsim.visualization.visualize_views(
    views=display_imgs,
    num_imgs=16,
    views_per_col=8,
    max_pixel_value=max_pixel,
    min_pixel_value=min_pixel,
)

---
## Model Creation

Now that our data and augmentation pipelines are set up, we can move on to constructing the contrastive learning pipeline. First, let's produce a backbone. For this task, we will use a KerasCV ResNet18 model as the backbone.

```python
def get_backbone(input_shape):
    inputs = layers.Input(shape=input_shape)
    x = inputs
    x = keras_cv.models.ResNet18(
        input_shape=input_shape,
        include_rescaling=True,
        include_top=False,
        pooling="avg",
    )(x)
    return tfsim.models.SimilarityModel(inputs, x)


backbone = get_backbone((96, 96, 3))
backbone.summary()
```
Model: "similarity_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 96, 96, 3)] 0
resnet18 (Functional) (None, 512) 11186112
================================================================= Total params: 11,186,112 Trainable params: 11,176,512 Non-trainable params: 9,600
</div>
Next, we build the projector model. This MLP is common to all the self-supervised models and is typically a stack of 3 layers of the same size. However, SimSiam only uses 2 layers for the smaller CIFAR images, and we follow suit here. Having too much capacity in the models can make it difficult for the loss to stabilize and converge.
Note: This is the model output that is returned by `ContrastiveModel.predict()` and
represents the distance-based embedding. This embedding can be used for the KNN
lookups and matching classification metrics. However, when using the pre-trained
model for downstream tasks, only the `ContrastiveModel.backbone` is used.
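To illustrate that distinction, here is a hedged sketch, assuming `contrastive_model` has been assembled as in the Training section below and `images` is a hypothetical batch of `(96, 96, 3)` images:

```python
# `predict()` on the full contrastive model returns the projector output,
# i.e., the distance-based embedding used for KNN lookups and matching metrics.
embeddings = contrastive_model.predict(images)

# The backbone alone produces the features used for downstream tasks.
features = contrastive_model.backbone.predict(images)
```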
```python
def get_projector(input_dim, dim, activation="relu", num_layers: int = 3):
    inputs = tf.keras.layers.Input((input_dim,), name="projector_input")
    x = inputs

    for i in range(num_layers - 1):
        x = tf.keras.layers.Dense(
            dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.LecunUniform(),
            name=f"projector_layer_{i}",
        )(x)
        x = tf.keras.layers.BatchNormalization(
            epsilon=1.001e-5, name=f"batch_normalization_{i}"
        )(x)
        x = tf.keras.layers.Activation(
            activation, name=f"{activation}_activation_{i}"
        )(x)

    x = tf.keras.layers.Dense(
        dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="projector_output",
    )(x)
    x = tf.keras.layers.BatchNormalization(
        epsilon=1.001e-5,
        center=False,  # Page:5, Paragraph:2 of SimSiam paper
        scale=False,  # Page:5, Paragraph:2 of SimSiam paper
        name="batch_normalization_output",
    )(x)
    # Metric logging layer. Monitors the std of the layer activations.
    # Degenerate solutions collapse to 0 while valid solutions will move
    # towards something like 0.0220. The actual number will depend on the layer size.
    o = tfsim.layers.ActivationStdLoggingLayer(name="proj_std")(x)
    projector = tf.keras.Model(inputs, o, name="projector")
    return projector


projector = get_projector(input_dim=backbone.output.shape[-1], dim=DIM, num_layers=2)
projector.summary()
```
Model: "projector"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
projector_input (InputLayer [(None, 512)] 0
)
projector_layer_0 (Dense) (None, 2048) 1048576
batch_normalization_0 (Batc (None, 2048) 8192
hNormalization)
relu_activation_0 (Activati (None, 2048) 0
on)
projector_output (Dense) (None, 2048) 4194304
batch_normalization_ouput ( (None, 2048) 4096
BatchNormalization)
proj_std (ActivationStdLogg (None, 2048) 0
ingLayer)
================================================================= Total params: 5,255,168 Trainable params: 5,246,976 Non-trainable params: 8,192
</div>
Finally, we must construct the predictor. The predictor is used in SimSiam and is a
simple stack of two MLP layers containing a bottleneck in the hidden layer.
```python
def get_predictor(input_dim, hidden_dim=512, activation="relu"):
    inputs = tf.keras.layers.Input(shape=(input_dim,), name="predictor_input")
    x = inputs

    x = tf.keras.layers.Dense(
        hidden_dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_layer_0",
    )(x)
    x = tf.keras.layers.BatchNormalization(
        epsilon=1.001e-5, name="batch_normalization_0"
    )(x)
    x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_0")(x)

    x = tf.keras.layers.Dense(
        input_dim,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_output",
    )(x)
    # Metric logging layer, as in the projector above.
    o = tfsim.layers.ActivationStdLoggingLayer(name="pred_std")(x)
    predictor = tf.keras.Model(inputs, o, name="predictor")
    return predictor


predictor = get_predictor(input_dim=DIM, hidden_dim=512)
predictor.summary()
```
Model: "predictor"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
predictor_input (InputLayer [(None, 2048)] 0
)
predictor_layer_0 (Dense) (None, 512) 1048576
batch_normalization_0 (Batc (None, 512) 2048
hNormalization)
relu_activation_0 (Activati (None, 512) 0
on)
predictor_output (Dense) (None, 2048) 1050624
pred_std (ActivationStdLogg (None, 2048) 0
ingLayer)
================================================================= Total params: 2,101,248 Trainable params: 2,100,224 Non-trainable params: 1,024
</div>
---
## Training
First, we need to initialize our training model, loss, and optimizer.
```python
loss = tfsim.losses.SimSiamLoss(projection_type="cosine_distance", name="simsiam")

contrastive_model = tfsim.models.ContrastiveModel(
    backbone=backbone,
    projector=projector,
    predictor=predictor,  # NOTE: SimSiam requires a predictor model.
    algorithm="simsiam",
    name="simsiam",
)
lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=INIT_LR,
    decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,
)
wd_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=WEIGHT_DECAY,
    decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,
)
optimizer = tfa.optimizers.SGDW(
    learning_rate=lr_decayed_fn, weight_decay=wd_decayed_fn, momentum=0.9
)
```

Next, we can compile the model the same way you compile any other Keras model.

```python
contrastive_model.compile(
    optimizer=optimizer,
    loss=loss,
)
```

We track the training using `EvalCallback`. `EvalCallback` creates an index at the end of each epoch and provides a proxy for the nearest-neighbor matching classification using `binary_accuracy`, which measures how often the query label matches the derived lookup label.

Accuracy is technically (TP+TN)/(TP+FP+TN+FN), but here we filter all queries above the distance threshold. In the case of binary matching, this makes all the TPs and FPs fall below the distance threshold and all the TNs and FNs fall above it.

As we are only concerned with the matches below the distance threshold, the accuracy simplifies to TP/(TP+FP) and is equivalent to the precision with respect to the unfiltered queries. However, we also want to consider the query coverage at the distance threshold, i.e., the percentage of queries that return a match, computed as (TP+FP)/(TP+FP+TN+FN). Therefore, we can take precision × query coverage to produce a measure that captures the precision scaled by the query coverage. This simplifies down to the binary accuracy presented here, giving TP/(TP+FP+TN+FN).
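To make the arithmetic concrete, here is a small worked example with hypothetical counts for 100 queries at a fixed distance threshold:

```python
# Hypothetical confusion counts: TP/FP fall below the distance threshold,
# TN/FN fall above it.
tp, fp, tn, fn = 30, 10, 50, 10

precision = tp / (tp + fp)  # 0.75, w.r.t. the filtered (matched) queries
query_coverage = (tp + fp) / (tp + fp + tn + fn)  # 0.40, fraction of queries matched
binary_accuracy = precision * query_coverage  # 0.30 == tp / (tp + fp + tn + fn)

print(precision, query_coverage, binary_accuracy)  # 0.75 0.4 0.3
```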
```python
DATA_PATH = Path("./")
log_dir = DATA_PATH / "models" / "logs" / f"{loss.name}_{time.time()}"
chkpt_dir = DATA_PATH / "models" / "checkpoints" / f"{loss.name}_{time.time()}"

callbacks = [
    tfsim.callbacks.EvalCallback(
        tf.cast(x_query, tf.float32),
        y_query,
        tf.cast(x_index, tf.float32),
        y_index,
        metrics=["binary_accuracy"],
        k=1,
        tb_logdir=log_dir,
    ),
    tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        update_freq=100,
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=chkpt_dir,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=True,
    ),
]
```
```
TensorBoard logging enable in models/logs/simsiam_1674516693.2898047/index
```
All that is left to do is run `fit()`!
```python
print(train_ds)
print(val_ds)
history = contrastive_model.fit(
    train_ds,
    epochs=PRE_TRAIN_EPOCHS,
    steps_per_epoch=PRE_TRAIN_STEPS_PER_EPOCH,
    validation_data=val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
    callbacks=callbacks,
)
```
```
Epoch 1/50
195/195 [==============================] - ETA: 0s - loss: 0.3137 - proj_std: 0.0122 - pred_std: 0.0076binary_accuracy: 0.2270
195/195 [==============================] - 90s 398ms/step - loss: 0.3137 - proj_std: 0.0122 - pred_std: 0.0076 - val_loss: 0.0764 - val_proj_std: 0.0068 - val_pred_std: 0.0026 - binary_accuracy: 0.2270
Epoch 2/50
195/195 [==============================] - ETA: 0s - loss: 0.1469 - proj_std: 0.0101 - pred_std: 0.0048binary_accuracy: 0.2260
195/195 [==============================] - 76s 390ms/step - loss: 0.1469 - proj_std: 0.0101 - pred_std: 0.0048 - val_loss: 0.3514 - val_proj_std: 0.0162 - val_pred_std: 0.0096 - binary_accuracy: 0.2260
Epoch 3/50
195/195 [==============================] - ETA: 0s - loss: 0.2779 - proj_std: 0.0145 - pred_std: 0.0102binary_accuracy: 0.2550
195/195 [==============================] - 74s 379ms/step - loss: 0.2779 - proj_std: 0.0145 - pred_std: 0.0102 - val_loss: 0.4038 - val_proj_std: 0.0158 - val_pred_std: 0.0122 - binary_accuracy: 0.2550
Epoch 4/50
195/195 [==============================] - ETA: 0s - loss: 0.2763 - proj_std: 0.0167 - pred_std: 0.0145binary_accuracy: 0.2630
195/195 [==============================] - 75s 385ms/step - loss: 0.2763 - proj_std: 0.0167 - pred_std: 0.0145 - val_loss: 0.1668 - val_proj_std: 0.0135 - val_pred_std: 0.0114 - binary_accuracy: 0.2630
Epoch 5/50
195/195 [==============================] - ETA: 0s - loss: 0.2235 - proj_std: 0.0174 - pred_std: 0.0154binary_accuracy: 0.2530
195/195 [==============================] - 63s 326ms/step - loss: 0.2235 - proj_std: 0.0174 - pred_std: 0.0154 - val_loss: 0.1847 - val_proj_std: 0.0144 - val_pred_std: 0.0127 - binary_accuracy: 0.2530
Epoch 6/50
195/195 [==============================] - ETA: 0s - loss: 0.2091 - proj_std: 0.0181 - pred_std: 0.0165binary_accuracy: 0.2580
195/195 [==============================] - 68s 350ms/step - loss: 0.2091 - proj_std: 0.0181 - pred_std: 0.0165 - val_loss: 0.2072 - val_proj_std: 0.0189 - val_pred_std: 0.0176 - binary_accuracy: 0.2580
Epoch 7/50
195/195 [==============================] - ETA: 0s - loss: 0.2385 - proj_std: 0.0191 - pred_std: 0.0178binary_accuracy: 0.2700
195/195 [==============================] - 69s 354ms/step - loss: 0.2385 - proj_std: 0.0191 - pred_std: 0.0178 - val_loss: 0.4700 - val_proj_std: 0.0186 - val_pred_std: 0.0199 - binary_accuracy: 0.2700
Epoch 8/50
195/195 [==============================] - ETA: 0s - loss: 0.1986 - proj_std: 0.0186 - pred_std: 0.0174binary_accuracy: 0.2750
195/195 [==============================] - 68s 350ms/step - loss: 0.1986 - proj_std: 0.0186 - pred_std: 0.0174 - val_loss: 0.3135 - val_proj_std: 0.0192 - val_pred_std: 0.0181 - binary_accuracy: 0.2750
Epoch 9/50
195/195 [==============================] - ETA: 0s - loss: 0.2182 - proj_std: 0.0191 - pred_std: 0.0180binary_accuracy: 0.2670
195/195 [==============================] - 68s 350ms/step - loss: 0.2182 - proj_std: 0.0191 - pred_std: 0.0180 - val_loss: 0.2822 - val_proj_std: 0.0155 - val_pred_std: 0.0135 - binary_accuracy: 0.2670
Epoch 10/50
195/195 [==============================] - ETA: 0s - loss: 0.1991 - proj_std: 0.0185 - pred_std: 0.0173binary_accuracy: 0.3090
195/195 [==============================] - 69s 353ms/step - loss: 0.1991 - proj_std: 0.0185 - pred_std: 0.0173 - val_loss: 0.1550 - val_proj_std: 0.0134 - val_pred_std: 0.0117 - binary_accuracy: 0.3090
Epoch 11/50
195/195 [==============================] - ETA: 0s - loss: 0.2080 - proj_std: 0.0185 - pred_std: 0.0175binary_accuracy: 0.2840
195/195 [==============================] - 69s 353ms/step - loss: 0.2080 - proj_std: 0.0185 - pred_std: 0.0175 - val_loss: 0.2511 - val_proj_std: 0.0185 - val_pred_std: 0.0181 - binary_accuracy: 0.2840
Epoch 12/50
195/195 [==============================] - ETA: 0s - loss: 0.1934 - proj_std: 0.0186 - pred_std: 0.0176binary_accuracy: 0.2980
195/195 [==============================] - 68s 352ms/step - loss: 0.1934 - proj_std: 0.0186 - pred_std: 0.0176 - val_loss: 0.1785 - val_proj_std: 0.0122 - val_pred_std: 0.0104 - binary_accuracy: 0.2980
Epoch 13/50
195/195 [==============================] - ETA: 0s - loss: 0.1945 - proj_std: 0.0190 - pred_std: 0.0181binary_accuracy: 0.3020
195/195 [==============================] - 69s 356ms/step - loss: 0.1945 - proj_std: 0.0190 - pred_std: 0.0181 - val_loss: 0.1189 - val_proj_std: 0.0118 - val_pred_std: 0.0107 - binary_accuracy: 0.3020
Epoch 14/50
195/195 [==============================] - ETA: 0s - loss: 0.2009 - proj_std: 0.0194 - pred_std: 0.0187binary_accuracy: 0.3320
195/195 [==============================] - 68s 350ms/step - loss: 0.2009 - proj_std: 0.0194 - pred_std: 0.0187 - val_loss: 0.1736 - val_proj_std: 0.0127 - val_pred_std: 0.0114 - binary_accuracy: 0.3320
Epoch 15/50
195/195 [==============================] - ETA: 0s - loss: 0.2029 - proj_std: 0.0194 - pred_std: 0.0186binary_accuracy: 0.3320
195/195 [==============================] - 69s 353ms/step - loss: 0.2029 - proj_std: 0.0194 - pred_std: 0.0186 - val_loss: 0.1638 - val_proj_std: 0.0154 - val_pred_std: 0.0148 - binary_accuracy: 0.3320
Epoch 16/50
195/195 [==============================] - ETA: 0s - loss: 0.1972 - proj_std: 0.0196 - pred_std: 0.0190binary_accuracy: 0.3060
195/195 [==============================] - 68s 348ms/step - loss: 0.1972 - proj_std: 0.0196 - pred_std: 0.0190 - val_loss: 0.2987 - val_proj_std: 0.0203 - val_pred_std: 0.0202 - binary_accuracy: 0.3060
Epoch 17/50
195/195 [==============================] - ETA: 0s - loss: 0.1885 - proj_std: 0.0196 - pred_std: 0.0190binary_accuracy: 0.3050
195/195 [==============================] - 68s 350ms/step - loss: 0.1885 - proj_std: 0.0196 - pred_std: 0.0190 - val_loss: 0.1805 - val_proj_std: 0.0161 - val_pred_std: 0.0150 - binary_accuracy: 0.3050
Epoch 18/50
195/195 [==============================] - ETA: 0s - loss: 0.1933 - proj_std: 0.0196 - pred_std: 0.0189binary_accuracy: 0.3270
195/195 [==============================] - 68s 352ms/step - loss: 0.1933 - proj_std: 0.0196 - pred_std: 0.0189 - val_loss: 0.1917 - val_proj_std: 0.0159 - val_pred_std: 0.0151 - binary_accuracy: 0.3270
Epoch 19/50
195/195 [==============================] - ETA: 0s - loss: 0.1934 - proj_std: 0.0196 - pred_std: 0.0189binary_accuracy: 0.3260
195/195 [==============================] - 69s 353ms/step - loss: 0.1934 - proj_std: 0.0196 - pred_std: 0.0189 - val_loss: 0.1808 - val_proj_std: 0.0171 - val_pred_std: 0.0165 - binary_accuracy: 0.3260
Epoch 20/50
195/195 [==============================] - ETA: 0s - loss: 0.1757 - proj_std: 0.0194 - pred_std: 0.0187binary_accuracy: 0.3180
195/195 [==============================] - 68s 349ms/step - loss: 0.1757 - proj_std: 0.0194 - pred_std: 0.0187 - val_loss: 0.1957 - val_proj_std: 0.0176 - val_pred_std: 0.0167 - binary_accuracy: 0.3180
Epoch 21/50
195/195 [==============================] - ETA: 0s - loss: 0.1752 - proj_std: 0.0194 - pred_std: 0.0188binary_accuracy: 0.3000
195/195 [==============================] - 78s 403ms/step - loss: 0.1752 - proj_std: 0.0194 - pred_std: 0.0188 - val_loss: 0.2070 - val_proj_std: 0.0178 - val_pred_std: 0.0172 - binary_accuracy: 0.3000
Epoch 22/50
195/195 [==============================] - ETA: 0s - loss: 0.1743 - proj_std: 0.0195 - pred_std: 0.0190binary_accuracy: 0.3280
195/195 [==============================] - 62s 317ms/step - loss: 0.1743 - proj_std: 0.0195 - pred_std: 0.0190 - val_loss: 0.2240 - val_proj_std: 0.0181 - val_pred_std: 0.0176 - binary_accuracy: 0.3280
Epoch 23/50
195/195 [==============================] - ETA: 0s - loss: 0.1692 - proj_std: 0.0193 - pred_std: 0.0188binary_accuracy: 0.3140
195/195 [==============================] - 68s 348ms/step - loss: 0.1692 - proj_std: 0.0193 - pred_std: 0.0188 - val_loss: 0.1892 - val_proj_std: 0.0186 - val_pred_std: 0.0181 - binary_accuracy: 0.3140
Epoch 24/50
195/195 [==============================] - ETA: 0s - loss: 0.1529 - proj_std: 0.0190 - pred_std: 0.0184binary_accuracy: 0.3280
195/195 [==============================] - 69s 353ms/step - loss: 0.1529 - proj_std: 0.0190 - pred_std: 0.0184 - val_loss: 0.2405 - val_proj_std: 0.0196 - val_pred_std: 0.0194 - binary_accuracy: 0.3280
Epoch 25/50
195/195 [==============================] - ETA: 0s - loss: 0.1425 - proj_std: 0.0187 - pred_std: 0.0182binary_accuracy: 0.3560
195/195 [==============================] - 75s 384ms/step - loss: 0.1425 - proj_std: 0.0187 - pred_std: 0.0182 - val_loss: 0.1602 - val_proj_std: 0.0181 - val_pred_std: 0.0178 - binary_accuracy: 0.3560
Epoch 26/50
195/195 [==============================] - ETA: 0s - loss: 0.1277 - proj_std: 0.0186 - pred_std: 0.0182binary_accuracy: 0.3080
195/195 [==============================] - 63s 322ms/step - loss: 0.1277 - proj_std: 0.0186 - pred_std: 0.0182 - val_loss: 0.1815 - val_proj_std: 0.0193 - val_pred_std: 0.0192 - binary_accuracy: 0.3080
Epoch 27/50
195/195 [==============================] - ETA: 0s - loss: 0.1326 - proj_std: 0.0189 - pred_std: 0.0185binary_accuracy: 0.3540
195/195 [==============================] - 69s 357ms/step - loss: 0.1326 - proj_std: 0.0189 - pred_std: 0.0185 - val_loss: 0.1919 - val_proj_std: 0.0177 - val_pred_std: 0.0174 - binary_accuracy: 0.3540
Epoch 28/50
195/195 [==============================] - ETA: 0s - loss: 0.1383 - proj_std: 0.0187 - pred_std: 0.0183binary_accuracy: 0.4060
195/195 [==============================] - 75s 388ms/step - loss: 0.1383 - proj_std: 0.0187 - pred_std: 0.0183 - val_loss: 0.1795 - val_proj_std: 0.0170 - val_pred_std: 0.0165 - binary_accuracy: 0.4060
Epoch 29/50
195/195 [==============================] - ETA: 0s - loss: 0.1348 - proj_std: 0.0177 - pred_std: 0.0172binary_accuracy: 0.3410
195/195 [==============================] - 61s 312ms/step - loss: 0.1348 - proj_std: 0.0177 - pred_std: 0.0172 - val_loss: 0.2115 - val_proj_std: 0.0187 - val_pred_std: 0.0185 - binary_accuracy: 0.3410
Epoch 30/50
195/195 [==============================] - ETA: 0s - loss: 0.1198 - proj_std: 0.0178 - pred_std: 0.0174binary_accuracy: 0.3520
195/195 [==============================] - 78s 401ms/step - loss: 0.1198 - proj_std: 0.0178 - pred_std: 0.0174 - val_loss: 0.1277 - val_proj_std: 0.0124 - val_pred_std: 0.0115 - binary_accuracy: 0.3520
Epoch 31/50
195/195 [==============================] - ETA: 0s - loss: 0.1185 - proj_std: 0.0180 - pred_std: 0.0176binary_accuracy: 0.3840
195/195 [==============================] - 68s 349ms/step - loss: 0.1185 - proj_std: 0.0180 - pred_std: 0.0176 - val_loss: 0.1637 - val_proj_std: 0.0187 - val_pred_std: 0.0185 - binary_accuracy: 0.3840
Epoch 32/50
195/195 [==============================] - ETA: 0s - loss: 0.1228 - proj_std: 0.0181 - pred_std: 0.0177binary_accuracy: 0.3790
195/195 [==============================] - 61s 312ms/step - loss: 0.1228 - proj_std: 0.0181 - pred_std: 0.0177 - val_loss: 0.1381 - val_proj_std: 0.0185 - val_pred_std: 0.0182 - binary_accuracy: 0.3790
Epoch 33/50
195/195 [==============================] - ETA: 0s - loss: 0.1180 - proj_std: 0.0176 - pred_std: 0.0173binary_accuracy: 0.4050
195/195 [==============================] - 70s 358ms/step - loss: 0.1180 - proj_std: 0.0176 - pred_std: 0.0173 - val_loss: 0.1273 - val_proj_std: 0.0188 - val_pred_std: 0.0186 - binary_accuracy: 0.4050
Epoch 34/50
195/195 [==============================] - ETA: 0s - loss: 0.1145 - proj_std: 0.0176 - pred_std: 0.0173binary_accuracy: 0.3880
195/195 [==============================] - 67s 342ms/step - loss: 0.1145 - proj_std: 0.0176 - pred_std: 0.0173 - val_loss: 0.1958 - val_proj_std: 0.0191 - val_pred_std: 0.0193 - binary_accuracy: 0.3880
Epoch 35/50
195/195 [==============================] - ETA: 0s - loss: 0.1112 - proj_std: 0.0175 - pred_std: 0.0172binary_accuracy: 0.3840
195/195 [==============================] - 68s 348ms/step - loss: 0.1112 - proj_std: 0.0175 - pred_std: 0.0172 - val_loss: 0.1372 - val_proj_std: 0.0186 - val_pred_std: 0.0185 - binary_accuracy: 0.3840
Epoch 36/50
195/195 [==============================] - ETA: 0s - loss: 0.1149 - proj_std: 0.0173 - pred_std: 0.0171binary_accuracy: 0.4030
195/195 [==============================] - 67s 343ms/step - loss: 0.1149 - proj_std: 0.0173 - pred_std: 0.0171 - val_loss: 0.1284 - val_proj_std: 0.0165 - val_pred_std: 0.0163 - binary_accuracy: 0.4030
Epoch 37/50
195/195 [==============================] - ETA: 0s - loss: 0.1108 - proj_std: 0.0174 - pred_std: 0.0171binary_accuracy: 0.4100
195/195 [==============================] - 71s 366ms/step - loss: 0.1108 - proj_std: 0.0174 - pred_std: 0.0171 - val_loss: 0.1387 - val_proj_std: 0.0145 - val_pred_std: 0.0141 - binary_accuracy: 0.4100
Epoch 38/50
195/195 [==============================] - ETA: 0s - loss: 0.1028 - proj_std: 0.0174 - pred_std: 0.0172binary_accuracy: 0.4180
195/195 [==============================] - 66s 338ms/step - loss: 0.1028 - proj_std: 0.0174 - pred_std: 0.0172 - val_loss: 0.1183 - val_proj_std: 0.0182 - val_pred_std: 0.0180 - binary_accuracy: 0.4180
Epoch 39/50
195/195 [==============================] - ETA: 0s - loss: 0.1011 - proj_std: 0.0171 - pred_std: 0.0170binary_accuracy: 0.4020
195/195 [==============================] - 69s 357ms/step - loss: 0.1011 - proj_std: 0.0171 - pred_std: 0.0170 - val_loss: 0.1056 - val_proj_std: 0.0177 - val_pred_std: 0.0176 - binary_accuracy: 0.4020
Epoch 40/50
195/195 [==============================] - ETA: 0s - loss: 0.1081 - proj_std: 0.0167 - pred_std: 0.0165binary_accuracy: 0.4670
195/195 [==============================] - 67s 346ms/step - loss: 0.1081 - proj_std: 0.0167 - pred_std: 0.0165 - val_loss: 0.1144 - val_proj_std: 0.0182 - val_pred_std: 0.0182 - binary_accuracy: 0.4670
Epoch 41/50
195/195 [==============================] - ETA: 0s - loss: 0.1060 - proj_std: 0.0166 - pred_std: 0.0165binary_accuracy: 0.4280
195/195 [==============================] - 68s 349ms/step - loss: 0.1060 - proj_std: 0.0166 - pred_std: 0.0165 - val_loss: 0.1180 - val_proj_std: 0.0175 - val_pred_std: 0.0174 - binary_accuracy: 0.4280
Epoch 42/50
195/195 [==============================] - ETA: 0s - loss: 0.1063 - proj_std: 0.0163 - pred_std: 0.0162binary_accuracy: 0.4220
195/195 [==============================] - 69s 356ms/step - loss: 0.1063 - proj_std: 0.0163 - pred_std: 0.0162 - val_loss: 0.1143 - val_proj_std: 0.0173 - val_pred_std: 0.0171 - binary_accuracy: 0.4220
Epoch 43/50
195/195 [==============================] - ETA: 0s - loss: 0.1050 - proj_std: 0.0162 - pred_std: 0.0161binary_accuracy: 0.4310
195/195 [==============================] - 69s 353ms/step - loss: 0.1050 - proj_std: 0.0162 - pred_std: 0.0161 - val_loss: 0.1171 - val_proj_std: 0.0169 - val_pred_std: 0.0168 - binary_accuracy: 0.4310
Epoch 44/50
195/195 [==============================] - ETA: 0s - loss: 0.1013 - proj_std: 0.0159 - pred_std: 0.0157binary_accuracy: 0.4140
195/195 [==============================] - 75s 386ms/step - loss: 0.1013 - proj_std: 0.0159 - pred_std: 0.0157 - val_loss: 0.1106 - val_proj_std: 0.0161 - val_pred_std: 0.0159 - binary_accuracy: 0.4140
Epoch 45/50
195/195 [==============================] - ETA: 0s - loss: 0.1035 - proj_std: 0.0160 - pred_std: 0.0159binary_accuracy: 0.4350
195/195 [==============================] - 63s 324ms/step - loss: 0.1035 - proj_std: 0.0160 - pred_std: 0.0159 - val_loss: 0.1086 - val_proj_std: 0.0171 - val_pred_std: 0.0171 - binary_accuracy: 0.4350
Epoch 46/50
195/195 [==============================] - ETA: 0s - loss: 0.0999 - proj_std: 0.0157 - pred_std: 0.0157binary_accuracy: 0.4510
195/195 [==============================] - 69s 354ms/step - loss: 0.0999 - proj_std: 0.0157 - pred_std: 0.0157 - val_loss: 0.1000 - val_proj_std: 0.0164 - val_pred_std: 0.0164 - binary_accuracy: 0.4510
Epoch 47/50
195/195 [==============================] - ETA: 0s - loss: 0.1002 - proj_std: 0.0157 - pred_std: 0.0156binary_accuracy: 0.4680
195/195 [==============================] - 68s 351ms/step - loss: 0.1002 - proj_std: 0.0157 - pred_std: 0.0156 - val_loss: 0.1067 - val_proj_std: 0.0163 - val_pred_std: 0.0163 - binary_accuracy: 0.4680
Epoch 48/50
195/195 [==============================] - ETA: 0s - loss: 0.0980 - proj_std: 0.0155 - pred_std: 0.0153binary_accuracy: 0.4410
195/195 [==============================] - 68s 352ms/step - loss: 0.0980 - proj_std: 0.0155 - pred_std: 0.0153 - val_loss: 0.0986 - val_proj_std: 0.0159 - val_pred_std: 0.0159 - binary_accuracy: 0.4410
Epoch 49/50
195/195 [==============================] - ETA: 0s - loss: 0.0944 - proj_std: 0.0155 - pred_std: 0.0154binary_accuracy: 0.4520
195/195 [==============================] - 69s 355ms/step - loss: 0.0944 - proj_std: 0.0155 - pred_std: 0.0154 - val_loss: 0.0949 - val_proj_std: 0.0164 - val_pred_std: 0.0163 - binary_accuracy: 0.4520
Epoch 50/50
195/195 [==============================] - ETA: 0s - loss: 0.0937 - proj_std: 0.0155 - pred_std: 0.0154binary_accuracy: 0.4570
195/195 [==============================] - 67s 347ms/step - loss: 0.0937 - proj_std: 0.0155 - pred_std: 0.0154 - val_loss: 0.0978 - val_proj_std: 0.0166 - val_pred_std: 0.0165 - binary_accuracy: 0.4570
```
---
```python
plt.figure(figsize=(15, 4))
plt.subplot(1, 3, 1)
plt.plot(history.history["loss"])
plt.grid()
plt.title(f"{loss.name} - loss")

plt.subplot(1, 3, 2)
plt.plot(history.history["proj_std"], label="proj")
if "pred_std" in history.history:
    plt.plot(history.history["pred_std"], label="pred")
plt.grid()
plt.title(f"{loss.name} - std metrics")
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(history.history["binary_accuracy"], label="acc")
plt.grid()
plt.title(f"{loss.name} - match metrics")
plt.legend()

plt.show()

---
## Fine Tuning on the Labelled Data

As a final step, we will fine-tune a classifier on the small labelled training split. This will allow us to evaluate the quality of our learned representation. First, we handle data loading:

```python
eval_augmenter = keras.Sequential(
    [
        keras_cv.layers.RandomCropAndResize(
            (96, 96), crop_area_factor=(0.8, 1.0), aspect_ratio_factor=(1.0, 1.0)
        ),
        keras_cv.layers.RandomFlip(mode="horizontal"),
    ]
)

eval_train_ds = tf.data.Dataset.from_tensor_slices(
    (x_raw_train, tf.keras.utils.to_categorical(y_raw_train, 10))
)
eval_train_ds = eval_train_ds.repeat()
eval_train_ds = eval_train_ds.shuffle(1024)
eval_train_ds = eval_train_ds.map(lambda x, y: (eval_augmenter(x), y), tf.data.AUTOTUNE)
eval_train_ds = eval_train_ds.batch(BATCH_SIZE)
eval_train_ds = eval_train_ds.prefetch(tf.data.AUTOTUNE)

eval_val_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, tf.keras.utils.to_categorical(y_test, 10))
)
eval_val_ds = eval_val_ds.repeat()
eval_val_ds = eval_val_ds.shuffle(1024)
eval_val_ds = eval_val_ds.batch(BATCH_SIZE)
eval_val_ds = eval_val_ds.prefetch(tf.data.AUTOTUNE)
```
---
## Benchmark Against a Naive Model

Finally, let's set up a naive model that does not leverage the unlabelled data corpus.

```python
TEST_EPOCHS = 50
TEST_STEPS_PER_EPOCH = x_raw_train.shape[0] // BATCH_SIZE


def get_eval_model(img_size, backbone, total_steps, trainable=True, lr=1.8):
    backbone.trainable = trainable
    inputs = tf.keras.layers.Input((img_size, img_size, 3), name="eval_input")
    x = backbone(inputs, training=trainable)
    o = tf.keras.layers.Dense(10, activation="softmax")(x)
    model = tf.keras.Model(inputs, o)
    cosine_decayed_lr = tf.keras.experimental.CosineDecay(
        initial_learning_rate=lr, decay_steps=total_steps
    )
    opt = tf.keras.optimizers.SGD(cosine_decayed_lr, momentum=0.9)
    model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])
    return model


no_pt_eval_model = get_eval_model(
    img_size=96,
    backbone=get_backbone((96, 96, 3)),
    total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    trainable=True,
    lr=1e-3,
)
no_pt_history = no_pt_eval_model.fit(
    eval_train_ds,
    batch_size=BATCH_SIZE,
    epochs=TEST_EPOCHS,
    steps_per_epoch=TEST_STEPS_PER_EPOCH,
    validation_data=eval_val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
)
```
```
Epoch 1/50
9/9 [==============================] - 6s 249ms/step - loss: 2.4969 - acc: 0.1302 - val_loss: 2.2889 - val_acc: 0.1669
Epoch 2/50
9/9 [==============================] - 1s 139ms/step - loss: 2.2002 - acc: 0.1888 - val_loss: 2.1074 - val_acc: 0.2160
Epoch 3/50
9/9 [==============================] - 1s 139ms/step - loss: 2.0066 - acc: 0.2619 - val_loss: 1.9138 - val_acc: 0.2968
Epoch 4/50
9/9 [==============================] - 1s 139ms/step - loss: 1.8394 - acc: 0.3227 - val_loss: 1.7825 - val_acc: 0.3326
Epoch 5/50
9/9 [==============================] - 1s 140ms/step - loss: 1.7191 - acc: 0.3585 - val_loss: 1.7004 - val_acc: 0.3545
Epoch 6/50
9/9 [==============================] - 1s 140ms/step - loss: 1.6458 - acc: 0.3806 - val_loss: 1.6473 - val_acc: 0.3734
Epoch 7/50
9/9 [==============================] - 1s 139ms/step - loss: 1.5798 - acc: 0.4030 - val_loss: 1.6009 - val_acc: 0.3907
Epoch 8/50
9/9 [==============================] - 1s 139ms/step - loss: 1.5244 - acc: 0.4332 - val_loss: 1.5696 - val_acc: 0.4029
Epoch 9/50
9/9 [==============================] - 1s 140ms/step - loss: 1.4977 - acc: 0.4325 - val_loss: 1.5416 - val_acc: 0.4126
Epoch 10/50
9/9 [==============================] - 1s 139ms/step - loss: 1.4555 - acc: 0.4559 - val_loss: 1.5087 - val_acc: 0.4271
Epoch 11/50
9/9 [==============================] - 1s 140ms/step - loss: 1.4294 - acc: 0.4627 - val_loss: 1.4897 - val_acc: 0.4384
Epoch 12/50
9/9 [==============================] - 1s 139ms/step - loss: 1.4031 - acc: 0.4820 - val_loss: 1.4759 - val_acc: 0.4410
Epoch 13/50
9/9 [==============================] - 1s 141ms/step - loss: 1.3625 - acc: 0.4941 - val_loss: 1.4501 - val_acc: 0.4486
Epoch 14/50
9/9 [==============================] - 1s 140ms/step - loss: 1.3443 - acc: 0.5026 - val_loss: 1.4390 - val_acc: 0.4525
Epoch 15/50
9/9 [==============================] - 1s 139ms/step - loss: 1.3235 - acc: 0.5067 - val_loss: 1.4308 - val_acc: 0.4578
Epoch 16/50
9/9 [==============================] - 1s 139ms/step - loss: 1.2863 - acc: 0.5328 - val_loss: 1.4089 - val_acc: 0.4650
Epoch 17/50
9/9 [==============================] - 1s 140ms/step - loss: 1.2851 - acc: 0.5339 - val_loss: 1.3944 - val_acc: 0.4700
Epoch 18/50
9/9 [==============================] - 1s 141ms/step - loss: 1.2501 - acc: 0.5464 - val_loss: 1.3887 - val_acc: 0.4773
Epoch 19/50
9/9 [==============================] - 1s 139ms/step - loss: 1.2324 - acc: 0.5510 - val_loss: 1.3783 - val_acc: 0.4820
Epoch 20/50
9/9 [==============================] - 1s 140ms/step - loss: 1.2223 - acc: 0.5562 - val_loss: 1.3655 - val_acc: 0.4848
Epoch 21/50
9/9 [==============================] - 1s 140ms/step - loss: 1.2070 - acc: 0.5664 - val_loss: 1.3579 - val_acc: 0.4867
Epoch 22/50
9/9 [==============================] - 1s 141ms/step - loss: 1.1820 - acc: 0.5738 - val_loss: 1.3482 - val_acc: 0.4913
Epoch 23/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1688 - acc: 0.5790 - val_loss: 1.3375 - val_acc: 0.4964
Epoch 24/50
9/9 [==============================] - 1s 141ms/step - loss: 1.1514 - acc: 0.5896 - val_loss: 1.3403 - val_acc: 0.4966
Epoch 25/50
9/9 [==============================] - 1s 138ms/step - loss: 1.1307 - acc: 0.5961 - val_loss: 1.3321 - val_acc: 0.5025
Epoch 26/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1341 - acc: 0.6009 - val_loss: 1.3220 - val_acc: 0.5035
Epoch 27/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1177 - acc: 0.5987 - val_loss: 1.3149 - val_acc: 0.5074
Epoch 28/50
9/9 [==============================] - 1s 139ms/step - loss: 1.1078 - acc: 0.6068 - val_loss: 1.3089 - val_acc: 0.5137
Epoch 29/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0929 - acc: 0.6046 - val_loss: 1.3015 - val_acc: 0.5139
Epoch 30/50
9/9 [==============================] - 1s 138ms/step - loss: 1.0915 - acc: 0.6139 - val_loss: 1.3064 - val_acc: 0.5149
Epoch 31/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0634 - acc: 0.6254 - val_loss: 1.2955 - val_acc: 0.5123
Epoch 32/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0675 - acc: 0.6254 - val_loss: 1.2979 - val_acc: 0.5167
Epoch 33/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0595 - acc: 0.6289 - val_loss: 1.2911 - val_acc: 0.5186
Epoch 34/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0397 - acc: 0.6328 - val_loss: 1.2906 - val_acc: 0.5208
Epoch 35/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0415 - acc: 0.6378 - val_loss: 1.2863 - val_acc: 0.5222
Epoch 36/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0435 - acc: 0.6257 - val_loss: 1.2830 - val_acc: 0.5215
Epoch 37/50
9/9 [==============================] - 1s 144ms/step - loss: 1.0242 - acc: 0.6461 - val_loss: 1.2820 - val_acc: 0.5268
Epoch 38/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0212 - acc: 0.6421 - val_loss: 1.2766 - val_acc: 0.5259
Epoch 39/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0213 - acc: 0.6385 - val_loss: 1.2770 - val_acc: 0.5259
Epoch 40/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0224 - acc: 0.6428 - val_loss: 1.2742 - val_acc: 0.5262
Epoch 41/50
9/9 [==============================] - 1s 142ms/step - loss: 0.9994 - acc: 0.6510 - val_loss: 1.2755 - val_acc: 0.5238
Epoch 42/50
9/9 [==============================] - 1s 141ms/step - loss: 1.0154 - acc: 0.6474 - val_loss: 1.2784 - val_acc: 0.5244
Epoch 43/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0176 - acc: 0.6441 - val_loss: 1.2680 - val_acc: 0.5247
Epoch 44/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0101 - acc: 0.6471 - val_loss: 1.2711 - val_acc: 0.5288
Epoch 45/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0080 - acc: 0.6536 - val_loss: 1.2691 - val_acc: 0.5275
Epoch 46/50
9/9 [==============================] - 1s 143ms/step - loss: 1.0038 - acc: 0.6428 - val_loss: 1.2706 - val_acc: 0.5302
Epoch 47/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0070 - acc: 0.6573 - val_loss: 1.2678 - val_acc: 0.5293
Epoch 48/50
9/9 [==============================] - 1s 140ms/step - loss: 1.0030 - acc: 0.6450 - val_loss: 1.2723 - val_acc: 0.5278
Epoch 49/50
9/9 [==============================] - 1s 139ms/step - loss: 1.0080 - acc: 0.6447 - val_loss: 1.2691 - val_acc: 0.5252
Epoch 50/50
9/9 [==============================] - 1s 142ms/step - loss: 1.0093 - acc: 0.6497 - val_loss: 1.2712 - val_acc: 0.5278
```
Pretty bad results! Let's try fine-tuning our SimSiam pretrained model:
```python
pt_eval_model = get_eval_model(
    img_size=96,
    backbone=contrastive_model.backbone,
    total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    trainable=False,
    lr=30.0,
)
pt_eval_model.summary()
pt_history = pt_eval_model.fit(
    eval_train_ds,
    batch_size=BATCH_SIZE,
    epochs=TEST_EPOCHS,
    steps_per_epoch=TEST_STEPS_PER_EPOCH,
    validation_data=eval_val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
)
```
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
eval_input (InputLayer) [(None, 96, 96, 3)] 0
similarity_model (Similarit (None, 512) 11186112
yModel)
dense_1 (Dense) (None, 10) 5130
================================================================= Total params: 11,191,242 Trainable params: 5,130 Non-trainable params: 11,186,112
Epoch 1/50 9/9 [==============================] - 3s 172ms/step - loss: 18.2303 - acc: 0.2563 - val_loss: 16.9489 - val_acc: 0.3463 Epoch 2/50 9/9 [==============================] - 1s 109ms/step - loss: 24.5528 - acc: 0.3498 - val_loss: 19.1886 - val_acc: 0.4050 Epoch 3/50 9/9 [==============================] - 1s 110ms/step - loss: 18.3920 - acc: 0.4477 - val_loss: 20.0611 - val_acc: 0.4456 Epoch 4/50 9/9 [==============================] - 1s 113ms/step - loss: 15.4172 - acc: 0.4993 - val_loss: 12.2465 - val_acc: 0.5116 Epoch 5/50 9/9 [==============================] - 1s 110ms/step - loss: 10.5517 - acc: 0.5217 - val_loss: 8.5560 - val_acc: 0.5474 Epoch 6/50 9/9 [==============================] - 1s 112ms/step - loss: 7.4812 - acc: 0.5395 - val_loss: 7.9182 - val_acc: 0.5053 Epoch 7/50 9/9 [==============================] - 1s 112ms/step - loss: 8.4429 - acc: 0.5024 - val_loss: 7.7339 - val_acc: 0.5071 Epoch 8/50 9/9 [==============================] - 1s 113ms/step - loss: 8.6143 - acc: 0.5109 - val_loss: 10.4784 - val_acc: 0.5157 Epoch 9/50 9/9 [==============================] - 1s 111ms/step - loss: 8.7506 - acc: 0.5061 - val_loss: 7.8201 - val_acc: 0.4914 Epoch 10/50 9/9 [==============================] - 1s 111ms/step - loss: 8.7927 - acc: 0.4996 - val_loss: 9.6188 - val_acc: 0.4668 Epoch 11/50 9/9 [==============================] - 1s 114ms/step - loss: 7.2190 - acc: 0.4844 - val_loss: 8.8605 - val_acc: 0.4240 Epoch 12/50 9/9 [==============================] - 1s 111ms/step - loss: 8.0435 - acc: 0.4681 - val_loss: 8.2731 - val_acc: 0.4425 Epoch 13/50 9/9 [==============================] - 1s 112ms/step - loss: 7.1718 - acc: 0.5048 - val_loss: 5.4667 - val_acc: 0.5485 Epoch 14/50 9/9 [==============================] - 1s 111ms/step - loss: 6.7500 - acc: 0.5111 - val_loss: 5.5898 - val_acc: 0.5158 Epoch 15/50 9/9 [==============================] - 1s 110ms/step - loss: 5.1562 - acc: 0.5467 - val_loss: 3.7606 - val_acc: 0.5587 Epoch 16/50 9/9 [==============================] - 1s 111ms/step - loss: 3.5923 - acc: 0.5814 - val_loss: 5.0881 - val_acc: 0.5336 Epoch 17/50 9/9 [==============================] - 1s 110ms/step - loss: 5.1907 - acc: 0.5221 - val_loss: 7.7393 - val_acc: 0.3609 Epoch 18/50 9/9 [==============================] - 1s 112ms/step - loss: 8.0532 - acc: 0.4768 - val_loss: 7.2504 - val_acc: 0.5265 Epoch 19/50 9/9 [==============================] - 1s 111ms/step - loss: 6.5527 - acc: 0.5221 - val_loss: 6.8659 - val_acc: 0.4729 Epoch 20/50 9/9 [==============================] - 1s 113ms/step - loss: 7.0188 - acc: 0.4924 - val_loss: 6.5774 - val_acc: 0.4729 Epoch 21/50 9/9 [==============================] - 1s 112ms/step - loss: 4.8837 - acc: 0.5293 - val_loss: 4.5986 - val_acc: 0.5568 Epoch 22/50 9/9 [==============================] - 1s 113ms/step - loss: 4.5787 - acc: 0.5536 - val_loss: 4.9848 - val_acc: 0.5343 Epoch 23/50 9/9 [==============================] - 1s 111ms/step - loss: 5.3264 - acc: 0.5501 - val_loss: 6.1620 - val_acc: 0.5257 Epoch 24/50 9/9 [==============================] - 1s 118ms/step - loss: 4.6995 - acc: 0.5681 - val_loss: 2.9108 - val_acc: 0.6004 Epoch 25/50 9/9 [==============================] - 1s 111ms/step - loss: 3.0915 - acc: 0.6024 - val_loss: 2.9674 - val_acc: 0.6097 Epoch 26/50 9/9 [==============================] - 1s 112ms/step - loss: 2.9893 - acc: 0.5940 - val_loss: 2.7857 - val_acc: 0.5975 Epoch 27/50 9/9 [==============================] - 1s 112ms/step - loss: 3.0031 - acc: 0.5990 - val_loss: 3.3214 - val_acc: 0.5661 Epoch 28/50 9/9 
[==============================] - 1s 110ms/step - loss: 2.4497 - acc: 0.6118 - val_loss: 2.5389 - val_acc: 0.5864 Epoch 29/50 9/9 [==============================] - 1s 112ms/step - loss: 2.2352 - acc: 0.6222 - val_loss: 2.6069 - val_acc: 0.5891 Epoch 30/50 9/9 [==============================] - 1s 110ms/step - loss: 2.0529 - acc: 0.6230 - val_loss: 2.2986 - val_acc: 0.6147 Epoch 31/50 9/9 [==============================] - 1s 113ms/step - loss: 2.1396 - acc: 0.6337 - val_loss: 2.3893 - val_acc: 0.6115 Epoch 32/50 9/9 [==============================] - 1s 110ms/step - loss: 2.0879 - acc: 0.6309 - val_loss: 2.0767 - val_acc: 0.6139 Epoch 33/50 9/9 [==============================] - 1s 111ms/step - loss: 1.9498 - acc: 0.6417 - val_loss: 2.5760 - val_acc: 0.6166 Epoch 34/50 9/9 [==============================] - 1s 111ms/step - loss: 2.0624 - acc: 0.6456 - val_loss: 2.2055 - val_acc: 0.6306 Epoch 35/50 9/9 [==============================] - 1s 113ms/step - loss: 1.9772 - acc: 0.6573 - val_loss: 1.8998 - val_acc: 0.6148 Epoch 36/50 9/9 [==============================] - 1s 110ms/step - loss: 1.7421 - acc: 0.6411 - val_loss: 1.7790 - val_acc: 0.6320 Epoch 37/50 9/9 [==============================] - 1s 112ms/step - loss: 1.6005 - acc: 0.6493 - val_loss: 1.7596 - val_acc: 0.6132 Epoch 38/50 9/9 [==============================] - 1s 111ms/step - loss: 1.4635 - acc: 0.6623 - val_loss: 1.8133 - val_acc: 0.6142 Epoch 39/50 9/9 [==============================] - 1s 112ms/step - loss: 1.4952 - acc: 0.6517 - val_loss: 1.8677 - val_acc: 0.5960 Epoch 40/50 9/9 [==============================] - 1s 113ms/step - loss: 1.4972 - acc: 0.6519 - val_loss: 1.7388 - val_acc: 0.6311 Epoch 41/50 9/9 [==============================] - 1s 113ms/step - loss: 1.4158 - acc: 0.6693 - val_loss: 1.6358 - val_acc: 0.6398 Epoch 42/50 9/9 [==============================] - 1s 110ms/step - loss: 1.3600 - acc: 0.6721 - val_loss: 1.5624 - val_acc: 0.6381 Epoch 43/50 9/9 [==============================] - 1s 112ms/step - loss: 1.2960 - acc: 0.6812 - val_loss: 1.5512 - val_acc: 0.6380 Epoch 44/50 9/9 [==============================] - 1s 111ms/step - loss: 1.3473 - acc: 0.6727 - val_loss: 1.4881 - val_acc: 0.6448 Epoch 45/50 9/9 [==============================] - 1s 111ms/step - loss: 1.1990 - acc: 0.6892 - val_loss: 1.4914 - val_acc: 0.6437 Epoch 46/50 9/9 [==============================] - 1s 111ms/step - loss: 1.2816 - acc: 0.6823 - val_loss: 1.4654 - val_acc: 0.6466 Epoch 47/50 9/9 [==============================] - 1s 113ms/step - loss: 1.2525 - acc: 0.6838 - val_loss: 1.4802 - val_acc: 0.6479 Epoch 48/50 9/9 [==============================] - 1s 111ms/step - loss: 1.2661 - acc: 0.6799 - val_loss: 1.4692 - val_acc: 0.6447 Epoch 49/50 9/9 [==============================] - 1s 111ms/step - loss: 1.2389 - acc: 0.6866 - val_loss: 1.4733 - val_acc: 0.6436 Epoch 50/50 9/9 [==============================] - 1s 113ms/step - loss: 1.2166 - acc: 0.6875 - val_loss: 1.4666 - val_acc: 0.6444
</div>
All that is left to do is evaluate the models:
```python
print(
    "no pretrain",
    no_pt_eval_model.evaluate(
        eval_val_ds,
        steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    ),
)
print(
    "pretrained",
    pt_eval_model.evaluate(
        eval_val_ds,
        steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    ),
)
```
```
450/450 [==============================] - 14s 30ms/step - loss: 1.2648 - acc: 0.5311
no pretrain [1.2647558450698853, 0.5310590267181396]
450/450 [==============================] - 12s 26ms/step - loss: 1.4653 - acc: 0.6474
pretrained [1.465279221534729, 0.6474305391311646]
```
Awesome! Our pretrained model stomped the non-pretrained model, and this accuracy is quite good for a ResNet18 on the STL-10 dataset. For better results, try using an EfficientNetV2B0 instead; a minimal sketch of this swap is shown below. Unfortunately, this will require a higher-end graphics card, as SimSiam has a minimum batch size of 512.
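For reference, here is a hedged sketch of the backbone swap, assuming `keras_cv.models.EfficientNetV2B0` accepts the same constructor arguments we used for `ResNet18` above:

```python
def get_effnet_backbone(input_shape):
    # Same structure as get_backbone() above, with the backbone swapped out.
    inputs = layers.Input(shape=input_shape)
    x = keras_cv.models.EfficientNetV2B0(
        input_shape=input_shape,
        include_rescaling=True,
        include_top=False,
        pooling="avg",
    )(inputs)
    return tfsim.models.SimilarityModel(inputs, x)


backbone = get_effnet_backbone((96, 96, 3))
```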
---
## Conclusion
TensorFlow Similarity can be used to easily train KerasCV models using
contrastive algorithms such as SimCLR, SimSiam and Barlow Twins.
This allows you to leverage large corpora of unlabelled data in your
model training pipeline.
Some follow-up exercises to this tutorial:
- Train a [`keras_cv.models.EfficientNetV2B0`](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/efficientnet_v2.py)
on STL-10
- Experiment with other data augmentation techniques in pretraining
- Train a model using the [Barlow Twins implementation](https://github.com/tensorflow/similarity/blob/master/examples/unsupervised_hello_world.ipynb) in TensorFlow Similarity
- Try pretraining on your own dataset