Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/metric_learning.py
8035 views
"""
Title: Metric learning for image similarity search
Author: [Mat Kelcey](https://twitter.com/mat_kelcey)
Date created: 2020/06/05
Last modified: 2020/06/09
Description: Example of using similarity metric learning on CIFAR-10 images.
Accelerator: GPU
"""
"""
## Overview

Metric learning aims to train models that can embed inputs into a high-dimensional space
such that "similar" inputs, as defined by the training scheme, are located close to each
other. These models once trained can produce embeddings for downstream systems where such
similarity is useful; examples include use as a ranking signal for search or as a form of
pretrained embedding model for another supervised problem.

For a more detailed overview of metric learning see:

* [What is metric learning?](http://contrib.scikit-learn.org/metric-learn/introduction.html)
* ["Using crossentropy for metric learning" tutorial](https://www.youtube.com/watch?v=Jb4Ewl5RzkI)
"""
"""
## Setup

Set Keras backend to tensorflow.
"""
import os
31
32
os.environ["KERAS_BACKEND"] = "tensorflow"
33
34
import random
35
import matplotlib.pyplot as plt
36
import numpy as np
37
import tensorflow as tf
38
from collections import defaultdict
39
from PIL import Image
40
from sklearn.metrics import ConfusionMatrixDisplay
41
import keras
42
from keras import layers
43
44
"""
## Dataset

For this example we will be using the
[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
"""
from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Images arrive as uint8 in [0, 255]; rescale to float32 in [0, 1].
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Labels arrive with shape (n, 1); drop the singleton axis to get shape (n,).
y_train = np.squeeze(y_train)
y_test = np.squeeze(y_test)
"""
To get a sense of the dataset we can visualise a grid of 25 random examples.
"""
height_width = 32


def show_collage(examples):
    """Render a grid of images as a single PIL image.

    `examples` is an array of shape (rows, cols, height_width, height_width, 3)
    holding floats in [0, 1]. Each image is pasted into its own cell (with a
    2-pixel border) and the whole collage is upscaled 2x for easier viewing.
    """
    cell = height_width + 2
    n_rows, n_cols = examples.shape[:2]

    canvas = Image.new(
        mode="RGB",
        size=(n_cols * cell, n_rows * cell),
        color=(250, 250, 250),
    )
    for r in range(n_rows):
        for c in range(n_cols):
            pixels = (np.array(examples[r, c]) * 255).astype(np.uint8)
            canvas.paste(Image.fromarray(pixels), (c * cell, r * cell))

    # Double size for visualisation.
    return canvas.resize((2 * n_cols * cell, 2 * n_rows * cell))
# Show a collage of 5x5 random images. Derive the sampling range from the
# data itself rather than hard-coding the CIFAR-10 training-set size (50000).
sample_idxs = np.random.randint(0, len(x_train), size=(5, 5))
examples = x_train[sample_idxs]
show_collage(examples)
"""
Metric learning provides training data not as explicit `(X, y)` pairs but instead uses
multiple instances that are related in the way we want to express similarity. In our
example we will use instances of the same class to represent similarity; a single
training instance will not be one image, but a pair of images of the same class. When
referring to the images in this pair we'll use the common metric learning names of the
`anchor` (a randomly chosen image) and the `positive` (another randomly chosen image of
the same class).

To facilitate this we need to build a form of lookup that maps from classes to the
instances of that class. When generating data for training we will sample from this
lookup.
"""
# Build class -> list-of-dataset-indices lookups, used to sample same-class
# (anchor, positive) pairs during training and to pick per-class examples at
# evaluation time.
class_idx_to_train_idxs = defaultdict(list)
for row, label in enumerate(y_train):
    class_idx_to_train_idxs[label].append(row)

class_idx_to_test_idxs = defaultdict(list)
for row, label in enumerate(y_test):
    class_idx_to_test_idxs[label].append(row)
"""
For this example we are using the simplest approach to training; a batch will consist of
`(anchor, positive)` pairs spread across the classes. The goal of learning will be to
move the anchor and positive pairs closer together and further away from other instances
in the batch. In this case the batch size will be dictated by the number of classes; for
CIFAR-10 this is 10.
"""
num_classes = 10


class AnchorPositivePairs(keras.utils.Sequence):
    """Yields batches of (anchor, positive) image pairs, one pair per class.

    Each item has shape (2, num_classes, height_width, height_width, 3):
    index 0 along the first axis holds the anchors, index 1 the positives.
    """

    def __init__(self, num_batches):
        super().__init__()
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __getitem__(self, _idx):
        # The index is ignored: every batch is freshly sampled at random.
        batch = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
        for class_idx in range(num_classes):
            candidates = class_idx_to_train_idxs[class_idx]
            anchor_idx = random.choice(candidates)
            positive_idx = random.choice(candidates)
            # Re-draw until the positive is a different image than the anchor.
            while positive_idx == anchor_idx:
                positive_idx = random.choice(candidates)
            batch[0, class_idx] = x_train[anchor_idx]
            batch[1, class_idx] = x_train[positive_idx]
        return batch
"""
We can visualise a batch in another collage. The top row shows randomly chosen anchors
from the 10 classes, the bottom row shows the corresponding 10 positives.
"""
# Grab one batch and render it: the top row holds the anchors, the bottom
# row the matching positives.
batch_source = AnchorPositivePairs(num_batches=1)
examples = next(iter(batch_source))
show_collage(examples)
"""
## Embedding model

We define a custom model with a `train_step` that first embeds both anchors and positives
and then uses their pairwise dot products as logits for a softmax.
"""
class EmbeddingModel(keras.Model):
    """Keras model with a custom train step for (anchor, positive) pair batches.

    Each training batch is a single tensor of shape
    (2, num_classes, height, width, 3): index 0 holds the anchors, index 1
    the positives. The loss is a softmax cross-entropy over the pairwise
    similarity matrix, pushing each anchor towards its own positive and away
    from every other example in the batch.
    """

    def train_step(self, data):
        # Note: Workaround for open issue, to be removed.
        if isinstance(data, tuple):
            data = data[0]
        anchors, positives = data[0], data[1]

        with tf.GradientTape() as tape:
            # Run both anchors and positives through model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)

            # Calculate cosine similarity between anchors and positives. As they have
            # been normalised this is just the pair wise dot products.
            # similarities[i, j] = <anchor_i, positive_j>, shape (num_classes, num_classes).
            similarities = keras.ops.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )

            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyper parameter.
            temperature = 0.2
            similarities /= temperature

            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence [0, 1, 2, ..., num_classes - 1] since we
            # want the main diagonal values, which correspond to the anchor/positive
            # pairs, to be high. This loss will move embeddings for the
            # anchor/positive pairs together and move all other pairs apart.
            sparse_labels = keras.ops.arange(num_classes)
            loss = self.compute_loss(y=sparse_labels, y_pred=similarities)

        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update and return metrics (specifically the one for the loss value).
        for metric in self.metrics:
            # Calling `self.compile` will by default add a `keras.metrics.Mean` loss
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(sparse_labels, similarities)

        return {m.name: m.result() for m in self.metrics}
"""
Next we describe the architecture that maps from an image to an embedding. This model
simply consists of a sequence of 2d convolutions followed by global pooling with a final
linear projection to an embedding space. As is common in metric learning we normalise the
embeddings so that we can use simple dot products to measure similarity. For simplicity
this model is intentionally small.
"""
# Small convnet: three strided conv blocks, global pooling, then a linear
# projection to an 8-dimensional embedding, normalised to unit length so that
# dot products measure cosine similarity.
inputs = layers.Input(shape=(height_width, height_width, 3))
x = inputs
for n_filters in (32, 64, 128):
    x = layers.Conv2D(filters=n_filters, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = layers.UnitNormalization()(embeddings)

model = EmbeddingModel(inputs, embeddings)
"""
Finally we run the training. On a Google Colab GPU instance this takes about a minute.
"""
# Sparse categorical crossentropy over the similarity logits, optimised with
# Adam at a fixed learning rate.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss_fn)

# Train on 1000 freshly sampled pair batches per epoch, then plot the loss curve.
history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)
plt.plot(history.history["loss"])
plt.show()
"""
## Testing

We can review the quality of this model by applying it to the test set and considering
near neighbours in the embedding space.

First we embed the test set and calculate all near neighbours. Recall that since the
embeddings are unit length we can calculate cosine similarity via dot products.
"""
253
near_neighbours_per_example = 10
254
255
embeddings = model.predict(x_test)
256
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
257
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
258
259
"""
As a visual check of these embeddings we can build a collage of the near neighbours for 5
random examples. The first column of the image below is a randomly selected image, the
following 10 columns show the nearest neighbours in order of similarity.
"""
num_collage_examples = 5

# One row per query image: the query itself in column 0, then its 10 nearest
# neighbours ordered from most to least similar.
examples = np.empty(
    (num_collage_examples, near_neighbours_per_example + 1, height_width, height_width, 3),
    dtype=np.float32,
)
for row_idx in range(num_collage_examples):
    examples[row_idx, 0] = x_test[row_idx]
    # Drop the last entry (the query itself, which has the highest similarity)
    # and reverse so neighbours run from most to least similar.
    for col_idx, nn_idx in enumerate(reversed(near_neighbours[row_idx][:-1])):
        examples[row_idx, col_idx + 1] = x_test[nn_idx]

show_collage(examples)
"""
We can also get a quantified view of the performance by considering the correctness of
near neighbours in terms of a confusion matrix.

Let us sample 10 examples from each of the 10 classes and consider their near neighbours
as a form of prediction; that is, does the example and its near neighbours share the same
class?

We observe that each animal class does generally well, and is confused the most with the
other animal classes. The vehicle classes follow the same pattern.
"""
confusion_matrix = np.zeros((num_classes, num_classes))

# For 10 test examples of every class, tally the classes of their near
# neighbours as a soft "prediction" of each query's class. The final
# neighbour entry is the query itself, so it is skipped.
for class_idx in range(num_classes):
    for example_idx in class_idx_to_test_idxs[class_idx][:10]:
        for nn_idx in near_neighbours[example_idx][:-1]:
            confusion_matrix[class_idx, y_test[nn_idx]] += 1

# Display the tallies as a labelled confusion matrix.
labels = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()