Learning to Rank 101
Suppose we have a query denoted as $q$, and its corresponding set of documents denoted as $D = \{d_1, d_2, \ldots, d_n\}$. Our objective is to learn a function $f$ such that $f(q, D)$ will produce an ordered collection of documents, $D^*$, in descending order of relevance, where the exact definition of relevance can vary between different applications.
In general, there are three main types of loss functions for training this function: pointwise, pairwise, and listwise. In this article, we'll give a 101 introduction to each of these variants, list their pros and cons, and implement these loss functions ourselves, training a tabular deep learning model using the huggingface `Trainer`.
Pointwise
In the pointwise approach, the aforementioned ranking task is formulated as a classic regression or classification task. The function is simplified to $f(q, d)$, treating the relevance assessment of each query-document pair independently. Suppose we have two queries that yield 2 and 3 corresponding documents respectively:

$$q_1 \rightarrow \{d_1, d_2\} \qquad q_2 \rightarrow \{d_3, d_4, d_5\}$$

The training examples are created by pairing each query with each of its associated documents: $(q_1, d_1), (q_1, d_2), (q_2, d_3), (q_2, d_4), (q_2, d_5)$.
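As a minimal sketch of this framing (the tensors and values below are made up for illustration), each query-document row is treated as an independent example and a standard regression loss is applied per row:

```python
import torch
import torch.nn.functional as F

# hypothetical toy data: query 1 yields 2 documents, query 2 yields 3,
# and each row is an independent (query, document) training example
query_ids = torch.tensor([1, 1, 2, 2, 2])
labels = torch.tensor([0.0, 1.0, 0.0, 1.0, 2.0])  # graded relevance labels
scores = torch.tensor([0.2, 0.7, 0.1, 0.5, 1.4])  # model outputs f(q, d)

# pointwise: a plain per-row regression loss that ignores the query grouping
loss = F.mse_loss(scores, labels)
```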
Pros:
Simplicity: Existing machine learning algorithms and loss functions we might be more familiar with can be directly applied in the pointwise setting.
Cons:
Sub-Optimal Results: This approach may not fully capitalize on the complete information available across the entire document list for a given query, potentially leading to sub-optimal outcomes.
Pairwise
In the pairwise approach, the goal remains identical to pointwise, in that we're still learning a pointwise scoring function $f(q, d)$, but the training instances are constructed using pairs of documents from the same query:
This approach introduces a new set of binary pairwise labels, derived by comparing the individual relevance labels within each pair. For instance, considering the first query $q_1$: if document $d_1$ has a relevance label of 0 (totally irrelevant) and document $d_2$ has a higher relevance label (highly relevant), a new label of 1 is assigned to the document pair $(d_2, d_1)$. This transforms the task into a binary classification learning problem.
To learn the pointwise function $f$ in a pairwise manner, RankNet [1] proposed modeling the score difference probabilistically using the logistic function:

$$P_{ij} = P(d_i \succ d_j) = \frac{1}{1 + e^{-\sigma (s_i - s_j)}}$$

Where $s_i = f(q, d_i)$ and $s_j = f(q, d_j)$ are the predicted scores: if document $d_i$ is deemed a better match than document $d_j$ ($y_i > y_j$), the probability of the scoring function assigning a higher score to $d_i$ than to $d_j$ should be close to 1. This reflects the model's effort to learn how to score document pairs based on query information, effectively learning to rank.
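To make this concrete, here's a tiny numeric sketch (the scores are made up, and $\sigma = 1$ is assumed):

```python
import torch

# hypothetical predicted scores for the pair (d_2, d_1) of query q_1
s_i, s_j = torch.tensor(1.4), torch.tensor(0.2)
# RankNet's probability that d_i should rank above d_j, with sigma = 1
p_ij = torch.sigmoid(s_i - s_j)  # ~0.77
# binary cross entropy against the pairwise label of 1
loss = -torch.log(p_ij)
```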
Pros:
Pairwise Ranking Learning: compared with a pointwise model, a pairwise model learns how to rank in a pairwise context; by focusing on the correct classification of ordered pairs, it potentially approximates the ultimate ranking task over a list of documents.
Cons:
Pointwise Scoring: the scoring function remains pointwise, implying relative information among different documents for the same query is still not fully harnessed.
Uneven Pairs: if we're not careful with data curation and the number of documents varies widely from query to query, the trained model may be biased towards queries with more document pairs.
Listwise
The listwise approach addresses the ranking problem in its natural form: it takes in a list of instances during training, so the group structure is maintained.
One of the first proposed approaches is ListNet [2], where the loss is calculated between a predicted probability distribution and a target probability distribution:

$$L(y, s) = -\sum_{i=1}^{n} P_y(i) \log P_s(i)$$

Where:
$x_i$ denotes the features representing a particular query-document pair, and $s_i = f(x_i)$ is its predicted score.
$y_i$ represents each document's non-negative relevance label.
$P_y(i) = \frac{e^{y_i}}{\sum_{j=1}^{n} e^{y_j}}$ (and analogously $P_s(i)$ for the predicted scores) encodes the probability of document $i$ appearing at the top of the ranked list, referred to as the top one probability. Given these two distributions, their difference can be measured by a standard cross entropy loss.
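A minimal sketch of this loss for a single toy query (the labels and scores below are made up):

```python
import torch
import torch.nn.functional as F

# hypothetical single query with 3 documents
labels = torch.tensor([0.0, 3.0, 1.0])  # relevance labels y_i
scores = torch.tensor([0.2, 1.5, 0.4])  # predicted scores s_i = f(x_i)

# top one probabilities: a softmax over the list, for target and prediction
target_probs = F.softmax(labels, dim=-1)
log_pred_probs = F.log_softmax(scores, dim=-1)

# ListNet loss: cross entropy between the two distributions
loss = -(target_probs * log_pred_probs).sum()
```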
Pros:
Direct Ranking Learning: by formulating the problem in its native form instead of relying on proxies, it is a theoretically sound way to approach a ranking task, i.e. it minimizes errors in ranking the entire document list as opposed to document pairs.
Cons:
Pointwise Scoring: the scoring function is still pointwise, which could be sub-optimal.
Note that:
Different from a pointwise approach that also uses a softmax function and cross entropy loss, in the listwise loss function both of these are computed over all items within the same list.
There is subsequent work [3] that provides theoretical justification for ListNet's softmax cross entropy loss. In particular, it shows that in a binary labeled setup, the loss bounds two popular learning to rank evaluation metrics: Mean Reciprocal Rank and Normalized Discounted Cumulative Gain.
Data
Each row represents a query-document pair in the dataset, with columns structured as follows:
The first column contains the relevance label for this specific pair, where a higher relevance label signifies greater relevance between the query and the document.
The second column contains the query ID.
Subsequent columns contain the various features.
The row concludes with a comment about the pair, which includes the document's ID.
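For illustration, a hypothetical row in this kind of format might look like the following (the label, IDs, and feature values here are made up):

```
2 qid:10 1:0.031310 2:0.666667 3:0.500000 ... 46:0.076923 #docid = GX008-86-4444840
```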
Model
Learning to rank approaches, regardless of whether they are pairwise or listwise, require data from the same context/group to reside in the same mini-batch. Given our input data is in a pointwise format, where each row represents a query-document pair, some additional transformations are necessary. We'll use some toy examples to illustrate these points before showing the full-blown implementation.
In the pairwise loss, the trick is to expand the original 1d tensor to compute a pairwise difference; a sketch follows the list below.
The context's pairwise difference signals which pairs belong to the same context. Pairs belonging to different contexts should be masked out during the loss calculation.
The label's pairwise difference is symmetric, and we only need to consider pairs where the difference is positive, converting it to a binary label.
The prediction score's pairwise difference will be the input to our loss function.
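Here is a minimal sketch of these three steps combined into a RankNet-style pairwise loss (the function name and toy tensors are made up for illustration):

```python
import torch
import torch.nn.functional as F

def pairwise_ranknet_loss(scores, labels, contexts):
    # expand the 1d tensors via broadcasting into (n, n) pairwise differences
    score_diff = scores.unsqueeze(1) - scores.unsqueeze(0)
    label_diff = labels.unsqueeze(1) - labels.unsqueeze(0)
    # pairs belonging to different contexts are masked out of the loss
    same_context = contexts.unsqueeze(1) == contexts.unsqueeze(0)
    # keep only pairs where the row item is strictly more relevant,
    # i.e. pairs whose binary pairwise label is 1
    mask = (label_diff > 0) & same_context
    # RankNet loss with binary label 1: -log(sigmoid(s_i - s_j))
    return F.softplus(-score_diff[mask]).mean()

scores = torch.tensor([0.5, 1.2, 0.3, 0.8, 0.1])
labels = torch.tensor([1.0, 2.0, 0.0, 1.0, 0.0])
contexts = torch.tensor([1, 1, 1, 2, 2])
loss = pairwise_ranknet_loss(scores, labels, contexts)
```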
In the listwise loss, the loss is calculated once over all data within the same context/group. Hence, apart from the predicted scores and targets/labels, we also need to know which examples belong to the same context/group. One common way to do this is to assume the examples are already sorted by context, and keep a group length variable that stores each group's instance count.
In the example below, we have 5 observations belonging to 2 contexts/groups. `[3, 2]` means the first 3 items belong to the first group, whereas the next 2 items belong to the second group. `torch.split` then splits the original single tensor into grouped chunks; in a vanilla implementation, we can loop through each group and compute the cross entropy loss.
A cleaner solution is to pad these grouped chunks and perform the calculation in a batched manner. The padding values do matter: we'll use an extremely small prediction score, with 0 as its corresponding label.
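A minimal sketch of this padded, batched version (the function name is made up). It uses the raw relevance labels as target weights, i.e. the softmax cross entropy form analyzed in [3]; normalizing the padded labels per list would recover ListNet's top one probability target:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def listwise_softmax_loss(scores, labels, group_lengths):
    # split the flat tensors into per-group chunks, e.g. lengths [3, 2]
    score_chunks = torch.split(scores, group_lengths)
    label_chunks = torch.split(labels, group_lengths)
    # pad scores with an extremely small value so padded positions get
    # ~0 probability after softmax; pad labels with 0 so they carry no weight
    padded_scores = pad_sequence(score_chunks, batch_first=True, padding_value=-1e9)
    padded_labels = pad_sequence(label_chunks, batch_first=True, padding_value=0.0)
    # softmax cross entropy over each list, averaged across groups
    log_probs = F.log_softmax(padded_scores, dim=-1)
    return -(padded_labels * log_probs).sum(dim=-1).mean()

scores = torch.tensor([0.5, 1.2, 0.3, 0.8, 0.1])
labels = torch.tensor([1.0, 2.0, 0.0, 1.0, 0.0])
loss = listwise_softmax_loss(scores, labels, [3, 2])
```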
When training a learning to rank model, an important detail is to prevent data shuffling in our data loader, so data from the same context can be grouped together in a mini-batch. At the time of writing, huggingface transformers' `Trainer` will by default enable shuffling on our train dataset. We can quickly override that behaviour by using `get_test_dataloader` even for our train dataloader, as sketched below. This addresses the issue with the least amount of code, with the quirk that `per_device_eval_batch_size` will now also be used in place of `per_device_train_batch_size`, which can be a bit confusing.
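A minimal sketch of this override (the subclass name is made up):

```python
from transformers import Trainer

class RankingTrainer(Trainer):
    def get_train_dataloader(self):
        # reuse the sequential, non-shuffling test dataloader for training so
        # rows from the same query stay contiguous within a mini-batch; note
        # that per_device_eval_batch_size now also applies during training
        return self.get_test_dataloader(self.train_dataset)
```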
In computational advertising, particularly its click-through rate application, the pointwise loss function still remains the dominant approach due to:
Calibrated Score: for an ad auction to take place properly, a model's prediction score needs to be treated as a click probability, instead of a score that only denotes ordering or preference.
Data Sparsity: pairwise/listwise approaches rely on events that have positive outcomes, comparing records with positive events to those without in order to build their loss functions. In practice, these positive events can be sparse, meaning there are far fewer instances of user engagement (clicks) than non-engagement. This sparsity implies that using pairwise or listwise methods would discard a significant portion of the available data and might hinder downstream performance. The pointwise approach doesn't have this limitation and can make better use of the available data.
To preserve the benefits of both the pointwise and pairwise/listwise approaches, an intuitive way is to calculate a weighted average of the two loss functions to take advantage of both sides [4] [5]. Given the sparsity of pairwise data, it can also be beneficial to create pseudo pairs to prevent the model from being biased towards the classification loss, e.g. we can form more pairs artificially by grouping impressions from different requests but under the same session and user.
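As a minimal sketch of this idea (the function name and the weight `alpha` are made up; it reuses the `listwise_softmax_loss` sketch from earlier):

```python
import torch.nn.functional as F

def hybrid_loss(scores, labels, group_lengths, alpha=0.5):
    # pointwise term keeps the score usable as a calibrated click probability
    pointwise = F.binary_cross_entropy_with_logits(scores, (labels > 0).float())
    # listwise term injects the ranking signal, in the spirit of [4] [5]
    listwise = listwise_softmax_loss(scores, labels, group_lengths)
    return alpha * pointwise + (1.0 - alpha) * listwise
```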
Reference
[1] Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, Greg Hullender - Learning to Rank using Gradient Descent - 2005
[2] Zhe Cao, Tao Qin, Ming-Feng Tsai, et al. - Learning to Rank: From Pairwise Approach to Listwise Approach - 2007
[3] Sebastian Bruch, Xuanhui Wang, Michael Bendersky, Marc Najork - An Analysis of the Softmax Cross Entropy Loss for Learning-to-Rank with Binary Relevance - 2019
[4] Cheng Li, Yue Lu, Qiaozhu Mei, Dong Wang, Sandeep Pandey - Click-through Prediction for Advertising in Twitter Timeline - 2015
[5] Shuguang Han et al. - Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model - 2022
[6] Tao Qin, Tie-Yan Liu - Introducing LETOR 4.0 Datasets - 2013
[7] LETOR: Learning to Rank for Information Retrieval - LETOR 4.0
[8] Microsoft Learning to Rank Datasets
[9] Blog: Istella Learning to Rank dataset
[10] Blog: Learning to rank is good for your ML career - Part 2: let’s implement ListNet!