Metric learning for image similarity search using TensorFlow Similarity
Author: Owen Vallis
Date created: 2021/09/30
Last modified: 2022/02/29
Description: Example of using similarity metric learning on CIFAR-10 images.
Overview
This example is based on the "Metric learning for image similarity search" example. We aim to use the same dataset, but implement the model using TensorFlow Similarity.
Metric learning aims to train models that can embed inputs into a high-dimensional space such that "similar" inputs are pulled closer to each other and "dissimilar" inputs are pushed farther apart. Once trained, these models can produce embeddings for downstream systems where such similarity is useful, for instance as a ranking signal for search or as a form of pretrained embedding model for another supervised problem.
For a more detailed overview of metric learning, see the original "Metric learning for image similarity search" example.
Setup
This tutorial will use the TensorFlow Similarity library to learn and evaluate the similarity embedding. TensorFlow Similarity provides components that:
- Make training contrastive models simple and fast.
- Make it easier to ensure that batches contain pairs of examples.
- Enable the evaluation of the quality of the embedding.
TensorFlow Similarity can be installed easily via pip, as follows:
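The snippet below is a minimal sketch of that setup: the pip package name and the tfsim import alias follow the TensorFlow Similarity documentation, and the remaining imports are the ones assumed by the later code sketches in this example.

```python
# Install the library first (run in a shell or notebook cell):
#   pip install tensorflow_similarity

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras

import tensorflow_similarity as tfsim

print("TensorFlow:", tf.__version__)
print("TensorFlow Similarity:", tfsim.__version__)
```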
Dataset samplers
We will be using the CIFAR-10 dataset for this tutorial.
For a similarity model to learn efficiently, each batch must contain at least 2 examples of each class.
To make this easy, TensorFlow Similarity offers Sampler objects that enable you to set both the number of classes and the minimum number of examples of each class per batch.
The train and validation datasets will be created using the TFDatasetMultiShotMemorySampler object. This creates a sampler that loads datasets from TensorFlow Datasets and yields batches containing a target number of classes and a target number of examples per class. Additionally, we can restrict the sampler to only yield the subset of classes defined in class_list, enabling us to train on a subset of the classes and then test how the embedding generalizes to the unseen classes. This can be useful when working on few-shot learning problems.
The following cell creates a train_ds sampler that:
- Loads the CIFAR-10 dataset from TFDS and then takes the examples_per_class_per_batch.
- Ensures the sampler restricts the classes to those defined in class_list.
- Ensures each batch contains 10 different classes with 8 examples each.
We also create a validation dataset in the same way, but we limit the total number of examples per class to 100 and the examples per class per batch is set to the default of 2.
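As a sketch (the sampler class and argument names follow the tfsim.samplers API described above; exact constructor arguments may vary by library version), the two samplers could be created like this:

```python
class_list = list(range(10))  # the subset of CIFAR-10 class ids to train on
classes_per_batch = 10
examples_per_class_per_batch = 8

train_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    splits="train",
    classes_per_batch=classes_per_batch,
    examples_per_class_per_batch=examples_per_class_per_batch,
    class_list=class_list,
)

# Validation sampler: 100 examples per class in total, with the default of
# 2 examples per class per batch.
val_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    splits="test",
    classes_per_batch=classes_per_batch,
    total_examples_per_class=100,
)
```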
Visualize the dataset
The samplers will shuffle the dataset, so we can get a sense of the dataset by plotting the first 25 images.
The samplers provide a get_slice(begin, size) method that allows us to easily select a block of samples.
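For example, a quick sketch of plotting the first 25 images with get_slice (the plot layout is illustrative):

```python
# Select the first 25 samples from the (already shuffled) training sampler.
x_slice, y_slice = train_ds.get_slice(begin=0, size=25)

fig, axes = plt.subplots(5, 5, figsize=(6, 6))
for ax, img in zip(axes.flat, x_slice):
    ax.imshow(img)
    ax.axis("off")
plt.show()
```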
Alternatively, we can use the generate_batch() method to yield a batch. This allows us to check that a batch contains the expected number of classes and examples per class.
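A sketch of such a check (assuming generate_batch takes the batch index as its argument):

```python
# Yield one batch and count the examples per class; with the training
# sampler above we expect 10 classes with 8 examples each.
x_batch, y_batch = train_ds.generate_batch(0)
classes, counts = np.unique(np.array(y_batch), return_counts=True)
print(dict(zip(classes.tolist(), counts.tolist())))
```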
Embedding model
Next we define a SimilarityModel using the Keras Functional API. The model is a standard convnet with the addition of a MetricEmbedding layer that applies L2 normalization. The metric embedding layer is helpful when using Cosine distance, as we only care about the angle between the vectors.
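A minimal sketch of such a model (the filter sizes and embedding dimension here are illustrative rather than a fixed reference architecture):

```python
embedding_size = 256

inputs = keras.layers.Input((32, 32, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
x = keras.layers.Conv2D(64, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(128, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D((4, 4))(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.GlobalMaxPool2D()(x)
# MetricEmbedding is a dense projection followed by L2 normalization.
outputs = tfsim.layers.MetricEmbedding(embedding_size)(x)

model = tfsim.models.SimilarityModel(inputs, outputs)
model.summary()
```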
Additionally, the SimilarityModel provides a number of helper methods for:
- Indexing embedded examples
- Performing example lookups
- Evaluating the classification
- Evaluating the quality of the embedding space
See the TensorFlow Similarity documentation for more details.
Similarity loss
The similarity loss expects batches containing at least 2 examples of each class, from which it computes the loss over the pairwise positive and negative distances. Here we are using MultiSimilarityLoss() (paper), one of several losses in TensorFlow Similarity. This loss attempts to use all informative pairs in the batch, taking into account the self-similarity, positive similarity, and negative similarity.
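A sketch of compiling and fitting with this loss (the learning rate, number of epochs, and validation steps are illustrative values, not tuned settings):

```python
loss = tfsim.losses.MultiSimilarityLoss()

model.compile(optimizer=keras.optimizers.Adam(0.002), loss=loss)
history = model.fit(
    train_ds,
    epochs=3,
    validation_data=val_ds,
    validation_steps=50,
)
```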
Indexing
Now that we have trained our model, we can create an index of examples. Here we batch index the first 200 validation examples by passing x and y to the index, and store the images in the data parameter. The x_index values are embedded and then added to the index to make them searchable. The y_index and data parameters are optional, but allow the user to associate metadata with the embedded example.
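A sketch of that indexing step:

```python
# Embed the first 200 validation examples and add them to the index.
x_index, y_index = val_ds.get_slice(begin=0, size=200)

model.reset_index()
model.index(x_index, y_index, data=x_index)
```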
Calibration
Once the index is built, we can calibrate a distance threshold using a matching strategy and a calibration metric.
Here we are searching for the optimal F1 score while using K=1 as our classifier. All matches at or below the calibrated threshold distance will be labeled as a Positive match between the query example and the label associated with the match result, while all matches above the threshold distance will be labeled as a Negative match.
We also pass in extra metrics to compute. All values in the output are computed at the calibrated threshold.
Finally, model.calibrate() returns a CalibrationResults object containing:
- "cutpoints": A Python dict mapping the cutpoint name to a dict containing the ClassificationMetric values associated with a particular distance threshold, e.g., "optimal" : {"acc": 0.90, "f1": 0.92}.
- "thresholds": A Python dict mapping ClassificationMetric names to a list containing the metric's value computed at each of the distance thresholds, e.g., {"f1": [0.99, 0.80], "distance": [0.0, 1.0]}.
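A sketch of a calibration call along these lines (the metric names and matcher follow the TensorFlow Similarity API described above; the 1,000-example calibration slice is illustrative):

```python
x_calib, y_calib = train_ds.get_slice(begin=0, size=1000)

calibration = model.calibrate(
    x_calib,
    y_calib,
    calibration_metric="f1",
    matcher="match_nearest",
    extra_metrics=["precision", "recall", "binary_accuracy"],
    verbose=1,
)
```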
Visualization
It may be difficult to get a sense of the model quality from the metrics alone. A complementary approach is to manually inspect a set of query results to get a feel for the match quality.
Here we take 10 validation examples and plot them with their 5 nearest neighbors and the distances to the query example. Looking at the results, we see that while they are imperfect they still represent meaningfully similar images, and that the model is able to find similar images irrespective of their pose or image illumination.
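A sketch of that lookup and plot (the class names and the offset into the validation data are illustrative; note that the visualization helper is spelled viz_neigbors_imgs in the library):

```python
num_neighbors = 5
labels = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck", "Unknown",
]
class_mapping = {i: name for i, name in enumerate(labels)}

# Query with 10 validation examples that were not added to the index.
x_display, y_display = val_ds.get_slice(begin=200, size=10)
nns = model.lookup(x_display, k=num_neighbors)

for idx in np.argsort(y_display):
    tfsim.visualization.viz_neigbors_imgs(
        x_display[idx],
        y_display[idx],
        nns[idx],
        class_mapping=class_mapping,
        fig_size=(16, 2),
    )
```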
We can also see that the model is very confident with certain images, resulting in very small distances between the query and the neighbors. Conversely, we see more mistakes in the class labels as the distances become larger. This is one of the reasons why calibration is critical for matching applications.
Metrics
We can also plot the extra metrics contained in the CalibrationResults to get a sense of the matching performance as the distance threshold increases.
The following plots show the Precision, Recall, and F1 Score. We can see that the matching precision degrades as the distance increases, but that the percentage of the queries that we accept as positive matches (recall) grows faster up to the calibrated distance threshold.
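A sketch of plotting those curves from the calibration results (assuming the extra metrics requested in the calibration sketch above):

```python
fig, ax = plt.subplots(figsize=(8, 5))
x = calibration.thresholds["distance"]

ax.plot(x, calibration.thresholds["precision"], label="precision")
ax.plot(x, calibration.thresholds["recall"], label="recall")
ax.plot(x, calibration.thresholds["f1"], label="f1 score")
ax.set_xlabel("Distance threshold")
ax.set_ylabel("Metric value")
ax.legend()
plt.show()
```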
We can also take 100 examples for each class and plot the confusion matrix of each example against its nearest match. We also add an extra "no match" class to represent queries whose nearest match falls above the calibrated distance threshold.
We can see that most of the errors are between the animal classes with an interesting number of confusions between Airplane and Bird. Additionally, we see that only a few of the 100 examples for each class returned matches outside of the calibrated distance threshold.
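A sketch of that confusion matrix, reusing the labels list defined above (this assumes get_slice accepts -1 to mean "all remaining samples"; the extra class uses label 10 for queries whose nearest match falls above the calibrated threshold):

```python
cutpoint = "optimal"

# The val_ds sampler was limited to 100 examples per class, so this slice
# covers all of them.
x_confusion, y_confusion = val_ds.get_slice(0, -1)

# Matches above the calibrated distance threshold are assigned label 10.
matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=10)
cm = tfsim.visualization.confusion_matrix(
    matches,
    y_confusion,
    labels=labels,
    title=f"Confusion matrix for cutpoint: {cutpoint}",
)
```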
No Match
We can plot the examples outside of the calibrated threshold to see which images are not matching any indexed examples.
This may provide insight into what other examples may need to be indexed or surface anomalous examples within the class.
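A sketch of pulling out and plotting the unmatched queries, reusing matches and x_confusion from the confusion-matrix step above:

```python
# Queries whose nearest match fell above the calibrated threshold were
# assigned label 10 ("no match") in the previous step.
idx_no_match = np.where(np.array(matches) == 10)[0]
no_match_queries = x_confusion[idx_no_match]
print(f"{len(no_match_queries)} queries had no match within the calibrated threshold")

# Plot up to the first 25 of them.
fig, axes = plt.subplots(5, 5, figsize=(6, 6))
for ax in axes.flat:
    ax.axis("off")
for ax, img in zip(axes.flat, no_match_queries[:25]):
    ax.imshow(img)
plt.show()
```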
Visualize clusters
One of the best ways to quickly get a sense of how well the model is doing and understand its shortcomings is to project the embeddings into a 2D space.
This allows us to inspect clusters of images and understand which classes are entangled.
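A sketch of projecting the validation embeddings with the library's interactive projector (the sample size and thumbnail size are illustrative, and the projector's keyword arguments may vary by library version):

```python
num_examples_to_cluster = 1000
vx, vy = val_ds.get_slice(0, num_examples_to_cluster)

# Embed the examples and launch the interactive 2D projector.
tfsim.visualization.projector(
    model.predict(vx),
    labels=vy,
    images=vx,
    image_size=96,
)
```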