Sentence Transformer: Training Bi-Encoder via Contrastive Loss
In modern recommendation systems, it is common to have two major distinct stages: a recall/retrieval stage and a ranking stage [2] [3]. In this article, we'll discuss these two stages in the context of pre-trained encoder models, where different architectures are commonly used at different stages to play to their respective strengths.

Recall/retrieval focuses on candidate generation: given millions of items, be it inventories, passages, etc., stored in our database, we need a way to efficiently retrieve a subset of them that is most relevant for the given context. In the encoder model paradigm, this means we need a way to represent our incoming context, be it a query, a question, etc., as a vector representation, often referred to as a sentence embedding, and perform a similarity search against entities that are also already represented as sentence embeddings. This step is often done using a bi-encoder architecture, where we pass individual sentences/entities through our encoder and then perform a similarity search using metrics such as cosine similarity. Being able to feed individual entities through our encoder is the key to an efficient retrieval stage, where we need to quickly scan through our database, as it means we can often pre-compute these embeddings a priori.
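As a rough sketch of what this looks like in code (using the sentence-transformers library and an example off-the-shelf checkpoint, purely for illustration, not the model trained in this notebook), the corpus embeddings can be computed once and cached, and only the incoming query needs to be encoded at request time:

```python
# Bi-encoder retrieval sketch: corpus embeddings can be pre-computed offline,
# only the query is encoded at request time.
from sentence_transformers import SentenceTransformer, util

# example off-the-shelf checkpoint, swap in any bi-encoder
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

corpus = [
    "A man is eating food.",
    "A monkey is playing drums.",
    "A cheetah is running behind its prey.",
]
# pre-computable and storable in a vector index
corpus_embeddings = model.encode(corpus, convert_to_tensor=True)

query_embedding = model.encode("Which animal is chasing something?", convert_to_tensor=True)
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=2)
print(hits)  # list of {corpus_id, score} sorted by cosine similarity
```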

This subset is then passed to the ranker: now that we have (context, candidate) pairs, we need to score them so the system can rank the retrieved items in sorted order. This step is often done using a cross encoder, a.k.a. cross attention architecture. A cross encoder doesn't produce sentence embeddings for our data but instead relies on a classification mechanism over the input pairs. Hence, whenever we have a fixed set of data pairs we wish to score, cross encoders often achieve better performance than their bi-encoder counterparts, even when using less training data.
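For contrast, a minimal cross-encoder sketch (again with an example checkpoint, not the model trained in this notebook): each (query, candidate) pair is scored jointly, so nothing can be pre-computed, but the scores are typically sharper:

```python
# Cross-encoder re-ranking sketch: each (query, candidate) pair is scored
# jointly with full cross attention, so there is no per-sentence embedding.
from sentence_transformers import CrossEncoder

# example checkpoint trained for passage re-ranking
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "Which animal is chasing something?"
candidates = ["A cheetah is running behind its prey.", "A man is eating food."]
scores = cross_encoder.predict([(query, candidate) for candidate in candidates])
print(scores)  # higher score = more relevant; used to sort the retrieved subset
```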

In this document, our focus will be on how to train a bi-encoder style architecture using contrastive loss [1] [4].
Contrastive Loss
For a bi-encoder, our input data at the bare minimum needs to contain positive pairs, where we need to be creative and define what constitutes a positive pair based on our use case. These pairs are often referred to as anchors and positives, $(a_i, p_i)$. Given our dataset, the standard architecture for this type of task is a siamese network. Each base arm consists of a pre-trained encoder followed by a pooling operator for deriving a fixed-size sentence embedding vector. During each step we feed both our anchor and positive, one at a time, through the same encoder plus pooling arm, whose weights are tied for our pairs. Tying weights ensures similar entities are mapped to similar locations in the representation space.
For training this network, what we wish to accomplish is to have our anchor and positive pair $a_i$ and $p_i$ become closer in the vector space, whereas $a_i$ and some random negative example $n_i$ become distant in the vector space:

$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp\big(\mathrm{sim}(a_i, p_i)\big)}{\sum_{j=1}^{N} \exp\big(\mathrm{sim}(a_i, p_j)\big)}$$

where $\mathrm{sim}(\cdot, \cdot)$ is a similarity function such as cosine similarity or dot product. This can then be treated as a classification task using cross entropy loss. The negatives $n_i$ are often sampled using in-batch negatives, meaning for each example in a batch, the other examples' positives are taken as its negatives. This effectively re-uses the already sampled data in our current batch without the need to implement additional negative sampling logic.
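A minimal PyTorch sketch of this loss with in-batch negatives might look like the following (the scale factor is an assumption, mirroring the temperature commonly paired with cosine similarity, not necessarily this notebook's exact value):

```python
import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(anchor_emb, positive_emb, scale: float = 20.0):
    """Cross entropy over in-batch negatives: for row i, column i is the true
    positive, every other column in the batch is treated as a negative."""
    anchor_norm = F.normalize(anchor_emb, dim=-1)
    positive_norm = F.normalize(positive_emb, dim=-1)
    # [batch_size, batch_size] cosine similarity matrix, scaled
    scores = anchor_norm @ positive_norm.T * scale
    # the correct "class" for anchor i is positive i, i.e. the diagonal
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)

# toy usage with random embeddings of dimension 384
print(in_batch_contrastive_loss(torch.randn(8, 384), torch.randn(8, 384)))
```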
Dataset
For our dataset, we'll be using Stanford Natural Language Inference (SNLI) and Multi-Genre NLI (MNLI); refer to their links for more details on these two datasets. For contrastive loss, we will keep only the premise and hypothesis pairs that have label 0 (entailment) assigned to them. These are positive samples where the premise entails the hypothesis.
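A sketch of how these positive pairs might be constructed with the Hugging Face datasets library (the dataset names and the label convention 0 = entailment refer to the hub copies and are not necessarily this notebook's exact code):

```python
from datasets import load_dataset, concatenate_datasets

# load the two NLI datasets from the Hugging Face hub
snli = load_dataset("snli", split="train")
mnli = load_dataset("multi_nli", split="train")

# keep only the columns shared by both datasets before merging
extra_columns = [c for c in mnli.column_names if c not in ("premise", "hypothesis", "label")]
mnli = mnli.remove_columns(extra_columns)
nli = concatenate_datasets([snli, mnli])

# label 0 = entailment, i.e. the premise supports the hypothesis -> positive pair
positive_pairs = nli.filter(lambda example: example["label"] == 0)
print(positive_pairs)
```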
We'll also use the Semantic Textual Similarity Benchmark (STSB) dataset, which contains sentence pairs with human-annotated similarity scores from 0 to 5, for evaluating our model's performance.
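For reference, the GLUE copy of STSB can be loaded the same way; its labels are floats between 0 and 5 (a sketch, assuming the Hugging Face hub version):

```python
from datasets import load_dataset

# sentence1 / sentence2 pairs with a float similarity label in [0, 5]
stsb = load_dataset("glue", "stsb", split="validation")
print(stsb[0])
```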
As usual, we tokenize our anchor and positive text.
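A tokenization sketch (the checkpoint name and column names are illustrative; `positive_pairs` follows from the filtering step above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")  # example checkpoint

def tokenize_pairs(batch):
    # tokenize anchor (premise) and positive (hypothesis) separately
    anchor = tokenizer(batch["premise"], truncation=True, max_length=128)
    positive = tokenizer(batch["hypothesis"], truncation=True, max_length=128)
    return {
        "anchor_input_ids": anchor["input_ids"],
        "anchor_attention_mask": anchor["attention_mask"],
        "positive_input_ids": positive["input_ids"],
        "positive_attention_mask": positive["attention_mask"],
    }

tokenized = positive_pairs.map(
    tokenize_pairs, batched=True, remove_columns=positive_pairs.column_names
)
```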
Here we also define a custom collate function for batching our dataset.
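One way such a collate function could look, padding anchors and positives independently to the longest sequence in the batch (field names follow the tokenization sketch above):

```python
def contrastive_collate_fn(batch, tokenizer):
    def pad(prefix):
        features = [
            {
                "input_ids": example[f"{prefix}_input_ids"],
                "attention_mask": example[f"{prefix}_attention_mask"],
            }
            for example in batch
        ]
        # dynamic padding to the longest sequence in this batch
        return tokenizer.pad(features, return_tensors="pt")

    return {"anchor": pad("anchor"), "positive": pad("positive")}
```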
Model
This section defines our transformer-plus-pooling encoder, siamese network, and evaluation metric, then proceeds to use the transformers library's Trainer class for training our network.
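A minimal sketch of the encoder arm, a pre-trained transformer followed by masked mean pooling (the exact checkpoint and pooling choice in the notebook may differ):

```python
import torch
import torch.nn as nn
from transformers import AutoModel

class SentenceEncoder(nn.Module):
    def __init__(self, model_name: str = "distilroberta-base"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        token_embeddings = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # mean pooling that ignores padded positions
        mask = attention_mask.unsqueeze(-1).float()
        return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

# the same instance encodes both arms, so the siamese weights are shared:
#   anchor_emb   = encoder(**batch["anchor"])
#   positive_emb = encoder(**batch["positive"])
```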
For our evaluation metric, we will first compute the cosine similarity of our pairs' embeddings, then calculate the Pearson and Spearman correlation with our ground truth similarity scores.
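In code, this metric might look like the following (a sketch using scipy, not necessarily the notebook's exact implementation):

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr

def similarity_metrics(emb1: np.ndarray, emb2: np.ndarray, gold_scores: np.ndarray):
    # row-wise cosine similarity between the two sentence embeddings of each pair
    cosine_scores = np.sum(emb1 * emb2, axis=1) / (
        np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1)
    )
    return {
        "pearson": pearsonr(cosine_scores, gold_scores)[0],
        "spearman": spearmanr(cosine_scores, gold_scores)[0],
    }
```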
Evaluation
A quantitative benchmark on our test set shows we are capable of achieving an 80+% Spearman correlation. This section performs a manual inspection of some sample input sentences, where some are more similar than others.
Given our sentence embeddings, we then perform a cosine similarity search to find semantically similar sentences. Notice our trained model was able to capture sentences that have very few matching words but are semantically similar.
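An illustrative version of this check, reusing the `SentenceEncoder` and `tokenizer` names from the sketches above with made-up sentences (in the notebook the encoder would be the contrastively trained one, not a fresh instance):

```python
import torch
import torch.nn.functional as F

sentences = [
    "A man is eating a piece of bread.",
    "A cheetah is running behind its prey.",
    "The quick animal chases a gazelle across the plain.",
]
query = "A fast predator hunts its dinner."

encoder = SentenceEncoder()  # would be the trained encoder in practice
encoder.eval()
with torch.no_grad():
    batch = tokenizer(sentences + [query], padding=True, truncation=True, return_tensors="pt")
    embeddings = encoder(batch["input_ids"], batch["attention_mask"])

query_emb, sentence_embs = embeddings[-1], embeddings[:-1]
scores = F.cosine_similarity(query_emb.unsqueeze(0), sentence_embs)
for sentence, score in sorted(zip(sentences, scores.tolist()), key=lambda pair: -pair[1]):
    print(f"{score:.3f}  {sentence}")
```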
The comparison is not shown here, but we can compare this with results from out-of-the-box pre-trained models and see that further training these networks with contrastive loss makes a night and day difference. i.e. if we were to directly derive sentence embeddings by either pooling the last hidden layer's output or extracting the first token's ([CLS] token) embedding, it often yields rather poor sentence embeddings, as mentioned in Sentence-BERT [5].
In this document, we walked through a modern baseline approach for training sentence transformers. Modern in the sense that we are building on top of pre-trained language models instead of training from scratch; baseline meaning there's a lot more we can do to improve these models' performance. We'll take a look at those tricks in future documents.
Reference
[1] Blog: Next-Gen Sentence Embeddings with Multiple Negatives Ranking Loss
[2] Blog: Real World Recommendation System – Part 1
[3] Sentence Transformers Documentation: Retrieve & Re-Rank
[4] Github: Sentence Transformers Training Natural Language Inference
[5] Paper: Nils Reimers, Iryna Gurevych - Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks - 2019