Building an Image Similarity System with 🤗 Transformers
In this notebook, you'll learn to build an image similarity system with 🤗 Transformers. Finding the similarity between a query image and potential candidates is an important use case for information retrieval systems, such as reverse image search. All the system is trying to answer is: given a query image and a set of candidate images, which candidates are the most similar to the query image?
🤗 Datasets library
This notebook leverages the datasets library as it seamlessly supports parallel processing, which will come in handy when building this system.
Any model and dataset
Although the notebook uses a ViT-based model (nateraw/vit-base-beans) and a particular dataset (Beans), it can be easily extended to use other models supporting the vision modality and other image datasets. Some notable models you could try:
The approach presented in the notebook can potentially be extended to other modalities as well.
Before we start, let's install the datasets and transformers libraries.
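If you're running this in a notebook, the install can be as simple as (version pins omitted):

```python
# Install the libraries used in this notebook (notebook cell syntax).
!pip install -q datasets transformers
```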
If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.
We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.
Building an image similarity system
To build this system, we first need to define how we want to compute the similarity between two images. One widely used approach is to compute dense representations (embeddings) of the given images and then use the cosine similarity metric to determine how similar the two images are.
For this tutorial, we'll be using “embeddings” to represent images in vector space. This gives us a nice way to meaningfully compress the high-dimensional pixel space of images (224 x 224 x 3, for example) to something much lower dimensional (768, for example). The primary advantage of doing this is the reduced computation time in the subsequent steps.
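For reference, the cosine similarity between two embedding vectors u and v is:

$$\operatorname{sim}(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$$

It lies in [-1, 1]; higher values mean the two images are more similar.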
Don't worry if this doesn't fully make sense yet. We will discuss it in more detail shortly.
Loading a base model to compute embeddings
"Embeddings" encode the semantic information of images. To compute the embeddings from the images, we'll use a vision model that has some understanding of how to represent the input images in the vector space. This type of model is also commonly referred to as an image encoder.
For loading the model, we leverage the AutoModel class. It provides an interface for us to load any compatible model checkpoint from the Hugging Face Hub. Alongside the model, we also load the processor associated with the model for data preprocessing.
In this case, the checkpoint was obtained by fine-tuning a Vision Transformer based model on the beans dataset. To learn more about the model, just click the model link and check out its model card.
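A minimal sketch of loading the checkpoint and its processor (assuming a transformers version that ships AutoImageProcessor):

```python
from transformers import AutoImageProcessor, AutoModel

model_ckpt = "nateraw/vit-base-beans"
processor = AutoImageProcessor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
model.eval()  # we only run inference with this model
```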
The warning is telling us that the underlying model didn't use anything from the classifier. Why did we not use AutoModelForImageClassification?
This is because we want to obtain dense representations of the images and not discrete categories, which are what AutoModelForImageClassification would have provided.
Then comes another question - why this checkpoint in particular?
We're using a specific dataset to build the system as mentioned earlier. So, instead of using a generalist model (like the ones trained on the ImageNet-1k dataset, for example), it's better to use a model that has been fine-tuned on the dataset being used. That way, the underlying model has a better understanding of the input images.
Now that we have a model for computing the embeddings, we need some candidate images to query against.
Loading the dataset for candidate images
To find similar images, we need a set of candidate images to query against. We'll use the train split of the beans dataset for that purpose. To know more about the dataset, just follow the link and explore its dataset card.
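A sketch of loading the split, assuming the dataset is available on the Hub under the beans identifier:

```python
from datasets import load_dataset

dataset = load_dataset("beans", split="train")
```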
The dataset has three columns / features:
Next, we set up two dictionaries for our upcoming utilities:
- label2id, which maps the class labels to integers.
- id2label, doing the opposite of label2id.
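A minimal sketch, assuming the class names live in a ClassLabel feature called labels:

```python
# Class names come from the dataset's ClassLabel feature (column name assumed).
labels = dataset.features["labels"].names
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for label, idx in label2id.items()}
```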
With these components, we can proceed to build our image similarity system. To demonstrate this, we'll use 100 samples from the candidate image dataset to keep the overall runtime short.
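One way to draw that subset (the seed and sample count are arbitrary choices):

```python
num_samples = 100
seed = 42
candidate_subset = dataset.shuffle(seed=seed).select(range(num_samples))
```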
Below, you can find a pictorial overview of the process of fetching similar images.
Breaking down the above figure a bit, we have:
1. Extract the embeddings from the candidate images (candidate_subset), storing them in a matrix.
2. Take a query image and extract its embeddings.
3. Iterate over the embedding matrix (computed in step 1) and compute the similarity score between the query embedding and the current candidate embedding. We usually maintain a dictionary-like mapping between some identifier of each candidate image and its similarity score.
4. Sort the mapping structure with respect to the similarity scores and return the underlying identifiers. We use these identifiers to fetch the candidate samples.
In the next cells, we implement the above procedure in code.
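Here is a hedged sketch of the embedding-extraction step. It reuses the model and processor loaded earlier, assumes an image column in the dataset, and takes the [CLS] token of the last hidden state as the image embedding:

```python
import torch

def extract_embeddings(batch):
    """Embed a batch of images with the ViT encoder."""
    inputs = processor(images=batch["image"], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the [CLS] token representation as the image-level embedding (768-d).
    batch["embeddings"] = outputs.last_hidden_state[:, 0].cpu().numpy()
    return batch

candidate_subset_emb = candidate_subset.map(extract_embeddings, batched=True, batch_size=16)
```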
Next, for convenience, we create a list containing the identifiers of the candidate images.
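For instance, a hypothetical identifier that concatenates each row index with its label:

```python
candidate_ids = []
for idx, label in enumerate(candidate_subset_emb["labels"]):
    # Keep both the row index and the label in a single string identifier.
    candidate_ids.append(f"{idx}_{label}")
```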
We'll use the matrix of the embeddings of all the candidate images for computing the similarity scores with a query image. We have already computed the candidate image embeddings. In the next cell, we just gather them together in a matrix.
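Gathering them into a single matrix might look like:

```python
import numpy as np

# Shape: (num_candidates, 768)
all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"], dtype="float32")
```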
We'll use cosine similarity to compute the similarity score between two embedding vectors. We'll then use it to fetch similar candidate samples given a query sample.
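Below is a sketch of these utilities; compute_scores and fetch_similar are illustrative names, and they rely on the candidate_ids and all_candidate_embeddings objects built above:

```python
import torch

def compute_scores(candidate_embs, query_emb):
    """Cosine similarity between every candidate embedding and the query embedding."""
    scores = torch.nn.functional.cosine_similarity(candidate_embs, query_emb)
    return scores.numpy().tolist()

def fetch_similar(image, top_k=5):
    """Return indices and labels of the top_k candidates most similar to the query image."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        query_emb = model(**inputs).last_hidden_state[:, 0]

    candidate_embs = torch.tensor(all_candidate_embeddings)
    sim_scores = compute_scores(candidate_embs, query_emb)

    # Map each candidate identifier to its similarity score, then sort descending.
    similarity_mapping = dict(zip(candidate_ids, sim_scores))
    ranked = sorted(similarity_mapping.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

    ids = [int(entry.split("_")[0]) for entry, _ in ranked]
    labels = [int(entry.split("_")[-1]) for entry, _ in ranked]
    return ids, labels
```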
Now, we can put these utilities to the test.
We can see that, given the query image, candidate images with similar labels were fetched.
Now, we can visualize all this.
We now have a working image similarity system. But in reality, you'll be dealing with many more candidate images. Considering that, our current procedure has multiple drawbacks:
If we store the embeddings as is, the memory requirements can shoot up quickly, especially when dealing with millions of candidate images. The embeddings are 768-d in our case, which can still be relatively high in the large-scale regime, and high-dimensional embeddings directly affect the subsequent computations involved in the retrieval part. So, if we can somehow reduce the dimensionality of the embeddings without disturbing their meaning, we can maintain a good trade-off between speed and retrieval quality.
So, in the following sections, we'll implement the hashing utilities to optimize the runtime of our image similarity system.
Random projection and locality-sensitive hashing (LSH)
We can choose to just compute the embeddings with our base model and then apply a similarity metric for the system. But in realistic settings, the embeddings are still high dimensional (in this case, (768,)). This eats up storage and also increases the query time.
To mitigate that effect, we'll implement the following things:
- First, we reduce the dimensionality of the embeddings with random projection. The main idea is that if the distance between a group of vectors can roughly be preserved on a plane, the dimensionality of the plane can be further reduced.
- We then compute the bitwise hash values of the projected vectors to determine their hash buckets. Similar images will likely be closer in the embedding space. Therefore, they will likely also have the same hash values and are likely to go into the same hash bucket. From a deployment perspective, bitwise hash values are cheaper to store and operate on. If you're unfamiliar with the relevant concepts of hashing, this resource could be helpful.
Following is a pictorial representation of the hashing process (figure source):
Next, we define a utility that can be mapped to our dataset for computing hashes of the training images in a parallel manner.
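A hedged sketch of such a utility. The hash size, number of tables, and column names are assumptions; each image is projected onto random hyperplanes and the sign pattern is split into one bit string per hash table:

```python
import numpy as np
import torch

hash_size = 8      # bits per hash table (assumption)
num_tables = 10    # number of hash tables (assumption)
embedding_dim = 768

# One set of random hyperplanes; each table gets its own hash_size-bit slice.
np.random.seed(42)
random_vectors = np.random.randn(embedding_dim, hash_size * num_tables)

def hash_func(embedding, random_vectors=random_vectors):
    """Randomly project an embedding and binarize it into per-table bit strings."""
    embedding = np.asarray(embedding).reshape(-1)
    bools = np.dot(embedding, random_vectors) > 0
    bits = "".join("1" if b else "0" for b in bools)
    return [bits[i * hash_size : (i + 1) * hash_size] for i in range(num_tables)]

def compute_hash(batch):
    """Mapped over the dataset: embeds a batch of images and stores their hashes."""
    inputs = processor(images=batch["image"], return_tensors="pt")
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state[:, 0].cpu().numpy()
    batch["hashes"] = [hash_func(emb) for emb in embeddings]
    return batch

candidate_dataset_hashed = candidate_subset.map(compute_hash, batched=True, batch_size=16)
```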
Next, we build three utility classes for our hash tables:
- Table
- LSH
- BuildLSHTable
Collectively, these classes implement Locality-Sensitive Hashing (the idea that points that are close together are likely to share the same hashes).
Disclaimer: Some code has been used from this resource for writing these classes.
The Table class
The Table class has two methods:
- add() lets us build a dictionary mapping the hashes of the candidate images to their identifiers.
- query() lets us take as inputs the query hashes and check if they exist in the table.
The table built in this class is referred to as a hash bucket.
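A minimal sketch of such a class; the identifier format is the same assumption used earlier:

```python
from typing import List

class Table:
    """A single hash table mapping a hash value (bucket) to candidate identifiers."""

    def __init__(self, hash_size: int):
        self.hash_size = hash_size
        self.table = {}

    def add(self, id: int, hash_value: str, label: int):
        # Store an identifier that also remembers the candidate's label.
        entry = f"{id}_{label}"
        self.table.setdefault(hash_value, []).append(entry)

    def query(self, hash_value: str) -> List[str]:
        # Return every candidate that fell into the same bucket, if any.
        return self.table.get(hash_value, [])
```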
The LSH class
Our dimensionality reduction technique involves a degree of randomness. This can lead to a situation where similar images may not get mapped to the same hash bucket every time the process is run. To reduce this effect, we'll maintain multiple hash tables. The number of hash tables and the reduction dimensionality are the two key hyperparameters here.
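A sketch of how the multiple tables can be managed, assuming one pre-computed hash string per table (as produced by the hash_func sketch above):

```python
class LSH:
    """Maintains several hash tables to smooth out the randomness of a single projection."""

    def __init__(self, hash_size: int, num_tables: int):
        self.num_tables = num_tables
        self.tables = [Table(hash_size) for _ in range(num_tables)]

    def add(self, id: int, hashes: list, label: int):
        # hashes holds one bit-string hash per table.
        for table, h in zip(self.tables, hashes):
            table.add(id, h, label)

    def query(self, hashes: list):
        results = []
        for table, h in zip(self.tables, hashes):
            results.extend(table.query(h))
        return results
```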
The BuildLSHTable class
It lets us:
- build(): build the hash tables.
- query(): query with an input image, a.k.a. the query image.
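A minimal sketch building on the Table, LSH, and hash_func sketches above; the similarity score it returns is the fraction of tables in which a candidate collides with the query:

```python
import torch

class BuildLSHTable:
    """Builds the LSH index from a hashed dataset and answers image queries."""

    def __init__(self, hash_size: int = 8, num_tables: int = 10):
        self.lsh = LSH(hash_size, num_tables)

    def build(self, ds):
        # ds is expected to carry "hashes" and "labels" columns (assumption).
        for id, example in enumerate(ds):
            self.lsh.add(id, example["hashes"], example["labels"])

    def query(self, image):
        # Hash the query image with the same projection used for the candidates.
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embedding = model(**inputs).last_hidden_state[:, 0].cpu().numpy()
        matches = self.lsh.query(hash_func(embedding))

        # Score each candidate by the fraction of tables it shares with the query.
        counts = {}
        for entry in matches:
            counts[entry] = counts.get(entry, 0) + 1
        return {entry: count / self.lsh.num_tables for entry, count in counts.items()}
```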
Notes on quantifying similarity:
We're using the Jaccard index to quantify the similarity between the query image and the candidate images. As per scikit-learn's documentation:
it is defined as the size of the intersection divided by the size of the union of two label sets.
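In symbols, for two sets A and B:

$$J(A, B) = \frac{|A \cap B|}{|A \cup B|}$$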
Since we're using LSH to build the similarity system and the hashes are effectively sets, the Jaccard index is a good metric to use here.
Building the LSH tables
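With the sketches above, building the index could look like (the hyperparameters are arbitrary):

```python
lsh_builder = BuildLSHTable(hash_size=8, num_tables=10)
lsh_builder.build(candidate_dataset_hashed)
```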
To get a better idea of how the tables are represented internally within lsh_builder, let's investigate the contents of a single table.
We notice that, for a given hash value, we have entries with the same labels. Because of the randomness induced in the process, we may also notice some entries coming from different labels. This can happen for various reasons:
The reduction dimensionality is too small for compression.
The underlying images may be visually quite similar to one another yet have different labels.
In both of the above cases, experimentation is really the key to improving the results.
Now that the LSH tables have been built, we can query them with images.
Inference
In this section, we'll take query images from the test split of our dataset and retrieve similar images from the set of candidate images we have.
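A sketch of querying with a few test images, assuming the same dataset identifier and column names as before:

```python
from datasets import load_dataset

test_dataset = load_dataset("beans", split="test")

# Query with a handful of test images and inspect the retrieved identifiers.
for idx in range(5):
    sample = test_dataset[idx]
    scores = lsh_builder.query(sample["image"])
    top_matches = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(f"True label: {sample['labels']} -> retrieved (id_label, score): {top_matches}")
```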
Not bad! Looks like our similarity system is fetching the correct images.
Storage-wise, we'd just have to store the lsh attribute of lsh_builder that has all the LSH tables:
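For example, with pickle (the file name is arbitrary):

```python
import pickle

with open("lsh.pickle", "wb") as handle:
    pickle.dump(lsh_builder.lsh, handle, protocol=pickle.HIGHEST_PROTOCOL)
```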
After this, we can use it like so:
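Continuing the sketch, one way to restore and reuse the tables:

```python
import pickle

with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

# Attach the restored tables to a fresh builder and query as before.
lsh_builder = BuildLSHTable()
lsh_builder.lsh = loaded_lsh
```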
This way, instead of storing 768-d floating-point embedding vectors, we're just storing 8-bit integers, which are much more lightweight. Needless to say, this helps reduce the computation costs too.
Conclusion
That was a lot of content covered in this notebook. Be sure to take it step by step. In this section, we want to leave you with some extensions regarding similarity systems.
🤗 Datasets offers direct integrations with FAISS which further simplifies the process of building similarity systems. To know more, you can check out the official documentation and this notebook. Additionally, we have created this Space application that lets you easily demo an image similarity system with more interactivity.
We encourage you to try these tools out and rebuild your own similarity system.