"""
2
Title: SimSiam Training with TensorFlow Similarity and KerasCV
3
Author: [lukewood](https://lukewood.xyz), Ian Stenbit, Owen Vallis
4
Date created: 2023/01/22
5
Last modified: 2023/01/22
6
Description: Train a KerasCV model using unlabelled data with SimSiam.
7
"""
8
9
"""
10
## Overview
11
12
[TensorFlow similarity](https://github.com/tensorflow/similarity) makes it easy to train
13
KerasCV models on unlabelled corpuses of data using contrastive learning algorithms such
14
as SimCLR, SimSiam, and Barlow Twins. In this guide, we will train a KerasCV model
15
using the SimSiam implementation from TensorFlow Similarity.
16
17
## Background
18
19
Self-supervised learning is an approach to pre-training models using unlabeled data.
20
This approach drastically increases accuracy when you have very few labeled examples but
21
a lot of unlabelled data.
22
The key insight is that you can train a self-supervised model to learn data
23
representations by contrasting multiple augmented views of the same example.
24
These learned representations capture data invariants, e.g., object translation, color
25
jitter, noise, etc. Training a simple linear classifier on top of the frozen
26
representations is easier and requires fewer labels because the pre-trained model
27
already produces meaningful and generally useful features.
28
29
Overall, self-supervised pre-training learns representations which are [more generic and
30
robust than other approaches to augmented training and pre-training](https://arxiv.org/abs/2002.05709).
31
An overview of the general contrastive learning process is shown below:
32
33
![Contrastive overview](https://i.imgur.com/mzaEq3C.png)
34
35
In this tutorial, we will use the [SimSiam](https://arxiv.org/abs/2011.10566) algorithm
36
for contrastive learning. As of 2022, SimSiam is the state of the art algorithm for
37
contrastive learning; allowing for unprecedented scores on CIFAR-100 and other datasets.
38
39
You may need to install:
40
41
```
42
pip -q install tensorflow_similarity
43
pip -q install keras-cv
44
```
45
46
To get started, we will sort out some imports.
47
"""
import resource
import gc
import os
import random
import time
import tensorflow_addons as tfa
import keras_cv
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tabulate import tabulate
import tensorflow_similarity as tfsim  # main package
import tensorflow as tf
from keras_cv import layers as cv_layers

import tensorflow_datasets as tfds

"""
68
Lets sort out some high level config issues and define some constants.
69
The resource limit increase is required to load STL-10, `tfsim.utils.tf_cap_memory()`
70
prevents TensorFlow from hogging the GPU memory in a cluster, and
71
`tfds.disable_progress_bar()` makes tfds less noisy.
72
"""
73
74
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
75
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
76
tfsim.utils.tf_cap_memory() # Avoid GPU memory blow up
77
tfds.disable_progress_bar()
78
79
BATCH_SIZE = 512
80
PRE_TRAIN_EPOCHS = 50
81
VAL_STEPS_PER_EPOCH = 20
82
WEIGHT_DECAY = 5e-4
83
INIT_LR = 3e-2 * int(BATCH_SIZE / 256)
84
WARMUP_LR = 0.0
85
WARMUP_STEPS = 0
86
DIM = 2048
87
88
"""
89
## Data loading
90
91
Next, we will load the STL-10 dataset. STL-10 is a dataset consisting of 100k unlabelled
92
images, 5k labelled training images, and 10k labelled test images. Due to this distribution,
93
STL-10 is commonly used as a benchmark for contrastive learning models.
94
95
First lets load our unlabelled data
96
"""
97
train_ds = tfds.load("stl10", split="unlabelled")
98
train_ds = train_ds.map(
99
lambda entry: entry["image"], num_parallel_calls=tf.data.AUTOTUNE
100
)
101
train_ds = train_ds.map(
102
lambda image: tf.cast(image, tf.float32), num_parallel_calls=tf.data.AUTOTUNE
103
)
104
train_ds = train_ds.shuffle(buffer_size=8 * BATCH_SIZE, reshuffle_each_iteration=True)
105
106
"""
107
Next, we need to prepare some labelled samples.
108
This is done so that TensorFlow similarity can probe the learned embedding to ensure
109
that the model is learning appropriately.
110
"""
111
112
(x_raw_train, y_raw_train), ds_info = tfds.load(
113
"stl10", split="train", as_supervised=True, batch_size=-1, with_info=True
114
)
115
x_raw_train, y_raw_train = tf.cast(x_raw_train, tf.float32), tf.cast(
116
y_raw_train, tf.float32
117
)
118
x_test, y_test = tfds.load(
119
"stl10",
120
split="test",
121
as_supervised=True,
122
batch_size=-1,
123
)
124
x_test, y_test = tf.cast(x_test, tf.float32), tf.cast(y_test, tf.float32)
125
126
"""
127
In self supervised learning, queries and indexes are labeled subset datasets used to
128
evaluate the quality of the produced latent embedding. The following code assembles
129
these datasets:
130
"""
131
132
# Compute the indices for query, index, val, and train splits
133
query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []
134
for cid in range(ds_info.features["label"].num_classes):
135
idxs = tf.random.shuffle(tf.where(y_raw_train == cid))
136
idxs = tf.reshape(idxs, (-1,))
137
query_idxs.extend(idxs[:100]) # 200 query examples per class
138
index_idxs.extend(idxs[100:200]) # 200 index examples per class
139
val_idxs.extend(idxs[200:300]) # 100 validation examples per class
140
train_idxs.extend(idxs[300:]) # The remaining are used for training
141
142
random.shuffle(query_idxs)
143
random.shuffle(index_idxs)
144
random.shuffle(val_idxs)
145
random.shuffle(train_idxs)
146
147
def create_split(idxs: list) -> tuple:
    x, y = [], []
    for idx in idxs:
        x.append(x_raw_train[int(idx)])
        y.append(y_raw_train[int(idx)])
    return tf.convert_to_tensor(np.array(x), dtype=tf.float32), tf.convert_to_tensor(
        np.array(y), dtype=tf.int64
    )


x_query, y_query = create_split(query_idxs)
x_index, y_index = create_split(index_idxs)
x_val, y_val = create_split(val_idxs)
x_train, y_train = create_split(train_idxs)

PRE_TRAIN_STEPS_PER_EPOCH = tf.data.experimental.cardinality(train_ds) // BATCH_SIZE
PRE_TRAIN_STEPS_PER_EPOCH = int(PRE_TRAIN_STEPS_PER_EPOCH.numpy())

print(
    tabulate(
        [
            ["train", tf.data.experimental.cardinality(train_ds), None],
            ["val", x_val.shape, y_val.shape],
            ["query", x_query.shape, y_query.shape],
            ["index", x_index.shape, y_index.shape],
            ["test", x_test.shape, y_test.shape],
        ],
        headers=["Split", "# of Examples", "Labels"],
    )
)

"""
180
## Augmentations
181
182
Self-supervised networks require at least two augmented "views" of each example.
183
This can be created using a dataset and an augmentation function.
184
The dataset treats each example in the batch as its own class and then the augment
185
function produces two separate views for each example.
186
187
This means the resulting batch will yield tuples containing the two views, i.e.,
188
Tuple[(BATCH_SIZE, 32, 32, 3), (BATCH_SIZE, 32, 32, 3)].
189
190
Using KerasCV, it is trivial to construct an augmenter that performs as the one
191
described in the original SimSiam paper. Lets do that below.
192
"""
193
194
target_size = (96, 96)
crop_area_factor = (0.08, 1)
aspect_ratio_factor = (3 / 4, 4 / 3)
grayscale_rate = 0.2
color_jitter_rate = 0.8
brightness_factor = 0.2
contrast_factor = 0.8
saturation_factor = (0.3, 0.7)
hue_factor = 0.2

augmenter = keras.Sequential(
    [
        cv_layers.RandomFlip("horizontal"),
        cv_layers.RandomCropAndResize(
            target_size,
            crop_area_factor=crop_area_factor,
            aspect_ratio_factor=aspect_ratio_factor,
        ),
        cv_layers.RandomApply(
            cv_layers.Grayscale(output_channels=3), rate=grayscale_rate
        ),
        cv_layers.RandomApply(
            cv_layers.RandomColorJitter(
                value_range=(0, 255),
                brightness_factor=brightness_factor,
                contrast_factor=contrast_factor,
                saturation_factor=saturation_factor,
                hue_factor=hue_factor,
            ),
            rate=color_jitter_rate,
        ),
    ],
)

"""
229
Next, lets pass our images through this pipeline.
230
Note that KerasCV supports batched augmentation, so batching before
231
augmentation dramatically improves performance
232
233
"""
234
235
236
@tf.function()
237
def process(img):
238
return augmenter(img), augmenter(img)
239
240
241
def prepare_dataset(dataset):
242
dataset = dataset.repeat()
243
dataset = dataset.shuffle(1024)
244
dataset = dataset.batch(BATCH_SIZE)
245
dataset = dataset.map(process, num_parallel_calls=tf.data.AUTOTUNE)
246
return dataset.prefetch(tf.data.AUTOTUNE)
247
248
249
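"""
As a rough illustration of the batching note above, we can time one batched augmenter
call against a handful of per-image calls. This is an informal sketch only (not a
rigorous benchmark; names like `sample_batch` are introduced just for this check), and
the absolute numbers will vary with your hardware:
"""

sample_batch = tf.random.uniform((BATCH_SIZE, 96, 96, 3), maxval=255.0)

start = time.time()
augmenter(sample_batch)  # a single batched augmentation call
batched_seconds = time.time() - start

start = time.time()
for image in sample_batch[:16]:  # per-image calls on a small subset
    augmenter(image[tf.newaxis, ...])
per_image_seconds = (time.time() - start) / 16 * BATCH_SIZE

print(f"batched: {batched_seconds:.3f}s, per-image (extrapolated): {per_image_seconds:.3f}s")
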
train_ds = prepare_dataset(train_ds)

val_ds = tf.data.Dataset.from_tensor_slices(x_val)
val_ds = prepare_dataset(val_ds)

print("train_ds", train_ds)
print("val_ds", val_ds)

"""
258
Lets visualize our pairs using the `tfsim.visualization` utility package.
259
"""
260
display_imgs = next(train_ds.as_numpy_iterator())
261
max_pixel = np.max([display_imgs[0].max(), display_imgs[1].max()])
262
min_pixel = np.min([display_imgs[0].min(), display_imgs[1].min()])
263
264
tfsim.visualization.visualize_views(
265
views=display_imgs,
266
num_imgs=16,
267
views_per_col=8,
268
max_pixel_value=max_pixel,
269
min_pixel_value=min_pixel,
270
)
271
272
"""
273
## Model Creation
274
275
Now that our data and augmentation pipeline is setup, we can move on to
276
constructing the contrastive learning pipeline. First, lets produce a backbone.
277
For this task, we will use a KerasCV ResNet18 model as the backbone.
278
"""
279
280
281
def get_backbone(input_shape):
282
inputs = layers.Input(shape=input_shape)
283
x = inputs
284
x = keras_cv.models.ResNet18(
285
input_shape=input_shape,
286
include_rescaling=True,
287
include_top=False,
288
pooling="avg",
289
)(x)
290
return tfsim.models.SimilarityModel(inputs, x)
291
292
293
backbone = get_backbone((96, 96, 3))
294
backbone.summary()
295
296
"""
297
This MLP is common to all the self-supervised models and is typically a stack of 3
298
layers of the same size. However, SimSiam only uses 2 layers for the smaller CIFAR
299
images. Having too much capacity in the models can make it difficult for the loss to
300
stabilize and converge.
301
302
Note: This is the model output that is returned by `ContrastiveModel.predict()` and
303
represents the distance based embedding. This embedding can be used for the KNN
304
lookups and matching classification metrics. However, when using the pre-train
305
model for downstream tasks, only the `ContrastiveModel.backbone` is used.
306
"""
307
308
309
def get_projector(input_dim, dim, activation="relu", num_layers: int = 3):
    inputs = tf.keras.layers.Input((input_dim,), name="projector_input")
    x = inputs

    for i in range(num_layers - 1):
        x = tf.keras.layers.Dense(
            dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.LecunUniform(),
            name=f"projector_layer_{i}",
        )(x)
        x = tf.keras.layers.BatchNormalization(
            epsilon=1.001e-5, name=f"batch_normalization_{i}"
        )(x)
        x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_{i}")(
            x
        )
    x = tf.keras.layers.Dense(
        dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="projector_output",
    )(x)
    x = tf.keras.layers.BatchNormalization(
        epsilon=1.001e-5,
        center=False,  # Page 5, paragraph 2 of the SimSiam paper
        scale=False,  # Page 5, paragraph 2 of the SimSiam paper
        name="batch_normalization_output",
    )(x)
    # Metric Logging layer. Monitors the std of the layer activations.
    # Degenerate solutions collapse to 0 while valid solutions will move
    # towards something like 0.0220. The actual number will depend on the layer size.
    o = tfsim.layers.ActivationStdLoggingLayer(name="proj_std")(x)
    projector = tf.keras.Model(inputs, o, name="projector")
    return projector


projector = get_projector(input_dim=backbone.output.shape[-1], dim=DIM, num_layers=2)
projector.summary()

"""
351
Finally, we must construct the predictor. The predictor is used in SimSiam, and is a
352
simple stack of two MLP layers, containing a bottleneck in the hidden layer.
353
"""
354
355
356
def get_predictor(input_dim, hidden_dim=512, activation="relu"):
    inputs = tf.keras.layers.Input(shape=(input_dim,), name="predictor_input")
    x = inputs

    x = tf.keras.layers.Dense(
        hidden_dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_layer_0",
    )(x)
    x = tf.keras.layers.BatchNormalization(
        epsilon=1.001e-5, name="batch_normalization_0"
    )(x)
    x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_0")(x)

    x = tf.keras.layers.Dense(
        input_dim,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_output",
    )(x)
    # Metric Logging layer. Monitors the std of the layer activations.
    # Degenerate solutions collapse to 0 while valid solutions will move
    # towards something like 0.0220. The actual number will depend on the layer size.
    o = tfsim.layers.ActivationStdLoggingLayer(name="pred_std")(x)
    predictor = tf.keras.Model(inputs, o, name="predictor")
    return predictor


predictor = get_predictor(input_dim=DIM, hidden_dim=512)
predictor.summary()

"""
389
## Training
390
391
First, we need to initialize our training model, loss, and optimizer.
392
"""
393
loss = tfsim.losses.SimSiamLoss(projection_type="cosine_distance", name="simsiam")
394
395
contrastive_model = tfsim.models.ContrastiveModel(
396
backbone=backbone,
397
projector=projector,
398
predictor=predictor, # NOTE: simiam requires predictor model.
399
algorithm="simsiam",
400
name="simsiam",
401
)
402
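
"""
For intuition: the SimSiam objective is the symmetrized negative cosine similarity
between each view's predictor output and a stop-gradient copy of the other view's
projection. Below is a minimal, illustrative sketch of that computation. The actual
training loss is the `tfsim.losses.SimSiamLoss` configured above;
`simsiam_loss_sketch` is just a name we introduce here:
"""


def simsiam_loss_sketch(z1, z2, p1, p2):
    # z1, z2: projector outputs; p1, p2: predictor outputs, shape (batch, dim).
    def neg_cosine(p, z):
        z = tf.stop_gradient(z)  # The stop-gradient is SimSiam's key ingredient.
        p = tf.math.l2_normalize(p, axis=1)
        z = tf.math.l2_normalize(z, axis=1)
        return -tf.reduce_mean(tf.reduce_sum(p * z, axis=1))

    return neg_cosine(p1, z2) / 2 + neg_cosine(p2, z1) / 2

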
lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=INIT_LR,
    decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,
)
wd_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=WEIGHT_DECAY,
    decay_steps=PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH,
)
optimizer = tfa.optimizers.SGDW(
    learning_rate=lr_decayed_fn, weight_decay=wd_decayed_fn, momentum=0.9
)
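
"""
Optionally, we can sample the cosine schedule at a few steps to sanity-check how the
learning rate will decay over pre-training (purely illustrative; `total_steps` is a
helper name introduced here):
"""

total_steps = PRE_TRAIN_EPOCHS * PRE_TRAIN_STEPS_PER_EPOCH
for step in [0, total_steps // 4, total_steps // 2, total_steps - 1]:
    print(f"step {step}: lr = {lr_decayed_fn(step).numpy():.4f}")
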
"""
415
Next we can compile the model the same way you compile any other Keras model.
416
"""
417
418
contrastive_model.compile(
419
optimizer=optimizer,
420
loss=loss,
421
)
422
423
"""
424
We track the training using `EvalCallback`.
425
`EvalCallback` creates an index at the end of each epoch and provides a proxy for the
426
nearest neighbor matching classification using `binary_accuracy`.
427
Calculates how often the query label matches the derived lookup label.
428
429
Accuracy is technically (TP+TN)/(TP+FP+TN+FN), but here we filter all
430
queries above the distance threshold. In the case of binary matching, this
431
makes all the TPs and FPs below the distance threshold and all the TNs and
432
FNs above the distance threshold.
433
434
As we are only concerned with the matches below the distance threshold, the
435
accuracy simplifies to TP/(TP+FP) and is equivalent to the precision with
436
respect to the unfiltered queries. However, we also want to consider the
437
query coverage at the distance threshold, i.e., the percentage of queries
438
that return a match, computed as (TP+FP)/(TP+FP+TN+FN). Therefore, we can
439
take $ precision \times query_coverage $ to produce a measure that capture
440
the precision scaled by the query coverage. This simplifies down to the
441
binary accuracy presented here, giving TP/(TP+FP+TN+FN).
442
"""
443
444
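
"""
To make that relationship concrete, here is a small numeric sketch of the
simplification (the confusion-matrix counts below are made up purely for illustration;
they are not results from this guide):
"""

# Hypothetical counts at a fixed distance threshold, for illustration only.
tp, fp, tn, fn = 70, 10, 15, 5

precision = tp / (tp + fp)  # 0.875: accuracy over the retained queries
coverage = (tp + fp) / (tp + fp + tn + fn)  # 0.8: fraction of queries retained
binary_accuracy = tp / (tp + fp + tn + fn)  # 0.7: the reported metric

# precision * query coverage simplifies to TP / (TP + FP + TN + FN).
assert abs(precision * coverage - binary_accuracy) < 1e-9
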
DATA_PATH = Path("./")
445
log_dir = DATA_PATH / "models" / "logs" / f"{loss.name}_{time.time()}"
446
chkpt_dir = DATA_PATH / "models" / "checkpoints" / f"{loss.name}_{time.time()}"
447
448
callbacks = [
449
tfsim.callbacks.EvalCallback(
450
tf.cast(x_query, tf.float32),
451
y_query,
452
tf.cast(x_index, tf.float32),
453
y_index,
454
metrics=["binary_accuracy"],
455
k=1,
456
tb_logdir=log_dir,
457
),
458
tf.keras.callbacks.TensorBoard(
459
log_dir=log_dir,
460
histogram_freq=1,
461
update_freq=100,
462
),
463
tf.keras.callbacks.ModelCheckpoint(
464
filepath=chkpt_dir,
465
monitor="val_loss",
466
mode="min",
467
save_best_only=True,
468
save_weights_only=True,
469
),
470
]
471
472
"""
473
All that is left to do is run fit()!
474
"""
475
476
print(train_ds)
477
print(val_ds)
478
history = contrastive_model.fit(
479
train_ds,
480
epochs=PRE_TRAIN_EPOCHS,
481
steps_per_epoch=PRE_TRAIN_STEPS_PER_EPOCH,
482
validation_data=val_ds,
483
validation_steps=VAL_STEPS_PER_EPOCH,
484
callbacks=callbacks,
485
)
486
487
488
"""
489
## Plotting and Evaluation
490
"""
491
492
plt.figure(figsize=(15, 4))
493
plt.subplot(1, 3, 1)
494
plt.plot(history.history["loss"])
495
plt.grid()
496
plt.title(f"{loss.name} - loss")
497
498
plt.subplot(1, 3, 2)
499
plt.plot(history.history["proj_std"], label="proj")
500
if "pred_std" in history.history:
501
plt.plot(history.history["pred_std"], label="pred")
502
plt.grid()
503
plt.title(f"{loss.name} - std metrics")
504
plt.legend()
505
506
plt.subplot(1, 3, 3)
507
plt.plot(history.history["binary_accuracy"], label="acc")
508
plt.grid()
509
plt.title(f"{loss.name} - match metrics")
510
plt.legend()
511
512
plt.show()
513
514
515
"""
516
## Fine Tuning on the Labelled Data
517
518
As a final step we will fine tune a classifier on 10% of the training data. This will
519
allow us to evaluate the quality of our learned representation. First, we handle data
520
loading:
521
"""
522
523
eval_augmenter = keras.Sequential(
    [
        keras_cv.layers.RandomCropAndResize(
            (96, 96), crop_area_factor=(0.8, 1.0), aspect_ratio_factor=(1.0, 1.0)
        ),
        keras_cv.layers.RandomFlip(mode="horizontal"),
    ]
)

eval_train_ds = tf.data.Dataset.from_tensor_slices(
    (x_raw_train, tf.keras.utils.to_categorical(y_raw_train, 10))
)
eval_train_ds = eval_train_ds.repeat()
eval_train_ds = eval_train_ds.shuffle(1024)
eval_train_ds = eval_train_ds.map(lambda x, y: (eval_augmenter(x), y), tf.data.AUTOTUNE)
eval_train_ds = eval_train_ds.batch(BATCH_SIZE)
eval_train_ds = eval_train_ds.prefetch(tf.data.AUTOTUNE)

eval_val_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, tf.keras.utils.to_categorical(y_test, 10))
)
eval_val_ds = eval_val_ds.repeat()
eval_val_ds = eval_val_ds.shuffle(1024)
eval_val_ds = eval_val_ds.batch(BATCH_SIZE)
eval_val_ds = eval_val_ds.prefetch(tf.data.AUTOTUNE)

"""
550
## Benchmark Against a Naive Model
551
552
Finally, lets setup a naive model that does not leverage the unlabeled data corpus.
553
"""
554
555
TEST_EPOCHS = 50
556
TEST_STEPS_PER_EPOCH = x_raw_train.shape[0] // BATCH_SIZE
557
558
559
def get_eval_model(img_size, backbone, total_steps, trainable=True, lr=1.8):
    backbone.trainable = trainable
    inputs = tf.keras.layers.Input((img_size, img_size, 3), name="eval_input")
    x = backbone(inputs, training=trainable)
    o = tf.keras.layers.Dense(10, activation="softmax")(x)
    model = tf.keras.Model(inputs, o)
    cosine_decayed_lr = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=lr, decay_steps=total_steps
    )
    opt = tf.keras.optimizers.SGD(cosine_decayed_lr, momentum=0.9)
    model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])
    return model


no_pt_eval_model = get_eval_model(
    img_size=96,
    backbone=get_backbone((96, 96, 3)),
    total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    trainable=True,
    lr=1e-3,
)
no_pt_history = no_pt_eval_model.fit(
    eval_train_ds,
    batch_size=BATCH_SIZE,
    epochs=TEST_EPOCHS,
    steps_per_epoch=TEST_STEPS_PER_EPOCH,
    validation_data=eval_val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
)

"""
590
Pretty bad results! Lets try fine-tuning our SimSiam pretrained model:
591
"""
592
593
pt_eval_model = get_eval_model(
    img_size=96,
    backbone=contrastive_model.backbone,
    total_steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
    trainable=False,
    lr=30.0,
)
pt_eval_model.summary()
pt_history = pt_eval_model.fit(
    eval_train_ds,
    batch_size=BATCH_SIZE,
    epochs=TEST_EPOCHS,
    steps_per_epoch=TEST_STEPS_PER_EPOCH,
    validation_data=eval_val_ds,
    validation_steps=VAL_STEPS_PER_EPOCH,
)

"""
611
All that is left to do is evaluate the models:
612
"""
613
print(
614
"no pretrain",
615
no_pt_eval_model.evaluate(
616
eval_val_ds,
617
steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
618
),
619
)
620
print(
621
"pretrained",
622
pt_eval_model.evaluate(
623
eval_val_ds,
624
steps=TEST_EPOCHS * TEST_STEPS_PER_EPOCH,
625
),
626
)
627
"""
628
Awesome! Our pretrained model stomped the non-pretrained model.
629
This accuracy is quite good for a ResNet18 on the STL-10 dataset.
630
For better results, try using an EfficientNetV2B0 instead.
631
Unfortunately, this will require a higher end graphics card as
632
SimSiam has a minimum batch size of 512.
633
634
## Conclusion
635
636
TensorFlow Similarity can be used to easily train KerasCV models using
637
contrastive algorithms such as SimCLR, SimSiam and BarlowTwins.
638
This allows you to leverage large corpuses of unlabelled data in your
639
model trainining pipeline.
640
641
Some follow-up exercises to this tutorial:
642
643
- Train a [`keras_cv.models.EfficientNetV2B0`](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/efficientnet_v2.py)
644
on STL-10
645
- Experiment with other data augmentation techniques in pretraining
646
- Train a model using the [BarlowTwins implementation](https://github.com/tensorflow/similarity/blob/master/examples/unsupervised_hello_world.ipynb) in TensorFlow similarity
647
- Try pretraining on your own dataset
648
"""
649
650