BERT CTR
In many real world applications, our dataset is oftentimes multi-modal [3]: it may contain textual information, structured/tabular data, or even data represented as images, video, or audio. In this article, we'll look at the problem of multimodal fusion, where our goal is to jointly model information from two or more modalities to make a prediction, specifically text and structured/tabular data.
Data
KDD 2012 Track 2 [1] focuses on predicting click through rate (CTR). This task involves not only textual data related to the search query and ad, but also additional structured/tabular data such as the user's demographic information (gender, age, etc.).
Note, while the dataset we're using is slightly more representative of real world applications, it unfortunately still has some limitations. The most notable one is that, for anonymity reasons, the dataset provides tokens represented as hash values instead of actual raw text. This means we can't directly leverage a pre-trained tokenizer/model.
Another dataset with similar characteristics is Baidu-ULTR [2]. This dataset is advantageous in terms of its larger scale, featuring 1.2B search sessions, and its richness: it offers diverse display information such as position, display height, and so on for studying unbiased learning to rank, as well as diversified user behavior, including click, skip, and dwell time, for exploring multi-task learning. Though, again, it has a similar limitation in that raw text is provided as hashed ids.
We'll be working with a sampled version of the original raw dataset. Readers can refer to kdd2012_track2_preprocess_data.py for a sample pre-processing script.
The subsequent functions/classes for using deep learning on tabular data closely follow the ones introduced in Deep Learning for Tabular Data - PyTorch. [nbviewer][html]
We'll specify a config mapping for our input features, which will be utilized in both our batch collate function as well as the model itself. This config mapping stores the features we want to use as keys, with values (enums) specifying whether each field is of text/token, numerical, or categorical type. This informs our pipeline of the embedding size required for each categorical field, how many numerical fields there are when initializing the dense/feed forward layers, and whether to apply tokenization/padding for text/token fields.
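A minimal sketch of what this config mapping might look like; the FieldType enum and the feature names below are illustrative placeholders, not necessarily the exact fields used in the notebook:

```python
from enum import Enum


class FieldType(Enum):
    """Illustrative enum for distinguishing input feature types."""
    TOKEN = "token"              # hashed text tokens, needs tokenization/padding
    CATEGORICAL = "categorical"  # mapped to an embedding lookup
    NUMERICAL = "numerical"      # packed into a single dense vector


# hypothetical feature config; actual feature names come from the
# pre-processed KDD 2012 Track 2 dataset
feature_config = {
    "query_tokens": FieldType.TOKEN,
    "ad_title_tokens": FieldType.TOKEN,
    "gender": FieldType.CATEGORICAL,
    "age": FieldType.CATEGORICAL,
    "depth": FieldType.NUMERICAL,
    "position": FieldType.NUMERICAL,
}
```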
Model
The model architecture we'll be implementing looks something along the lines of:
Feeding textual input through a BERT style transformer block, converting categorical features into low dimensional embeddings, and packing all of our numerical features together.
The outputs from the 3 different field groups are concatenated before feeding them into subsequent feed forward layers.
The next code block involves defining a config and model class following huggingface transformers' class structure [3]. This allows us to leverage its Trainer class for training and evaluating our models instead of writing custom training loops.
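A rough sketch of what these two classes could look like; the names BertCtrConfig / BertCtrModel, the encoder size, and the fusion head below are assumptions for illustration, not the notebook's exact implementation:

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class BertCtrConfig(PretrainedConfig):
    """Hypothetical config bundling text encoder and tabular feature settings."""

    model_type = "bert_ctr"

    def __init__(self, vocab_size=50_000, text_hidden_size=128,
                 cat_cardinalities=None, cat_embed_dim=16,
                 num_numerical=0, fusion_hidden_dim=128, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.text_hidden_size = text_hidden_size
        self.cat_cardinalities = cat_cardinalities or []
        self.cat_embed_dim = cat_embed_dim
        self.num_numerical = num_numerical
        self.fusion_hidden_dim = fusion_hidden_dim


class BertCtrModel(PreTrainedModel):
    """Late fusion: BERT text encoding + categorical embeddings + numerical features."""

    config_class = BertCtrConfig

    def __init__(self, config):
        super().__init__(config)
        # trained from scratch, since the dataset only provides hashed tokens
        self.bert = BertModel(BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.text_hidden_size,
            num_hidden_layers=2,
            num_attention_heads=4,
        ))
        self.cat_embeddings = nn.ModuleList(
            [nn.Embedding(cardinality, config.cat_embed_dim)
             for cardinality in config.cat_cardinalities]
        )
        fused_dim = (config.text_hidden_size
                     + config.cat_embed_dim * len(config.cat_cardinalities)
                     + config.num_numerical)
        self.head = nn.Sequential(
            nn.Linear(fused_dim, config.fusion_hidden_dim),
            nn.ReLU(),
            nn.Linear(config.fusion_hidden_dim, 1),
        )

    def forward(self, input_ids, attention_mask, cat_features, num_features, labels=None):
        # pooled [CLS] representation of the hashed text tokens
        text_repr = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        cat_repr = torch.cat(
            [embed(cat_features[:, i]) for i, embed in enumerate(self.cat_embeddings)], dim=-1
        )
        fused = torch.cat([text_repr, cat_repr, num_features], dim=-1)
        logits = self.head(fused).squeeze(-1)
        loss = None
        if labels is not None:
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        return SequenceClassifierOutput(loss=loss, logits=logits)
```

Returning a SequenceClassifierOutput, with the loss populated whenever labels are passed in, is what lets the huggingface Trainer handle this custom model without modification.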
The rest of the code block defines boilerplate for leveraging huggingface transformers' Trainer, as well as a compute_metrics function for calculating standard binary classification metrics.
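A minimal sketch of such a compute_metrics function, assuming the model emits a single logit per example (the exact metrics tracked in the notebook may differ):

```python
import numpy as np
from sklearn.metrics import log_loss, precision_score, recall_score, roc_auc_score


def compute_metrics(eval_pred):
    """Trainer passes (predictions, label_ids); convert logits to probabilities first."""
    logits, labels = eval_pred
    probs = 1.0 / (1.0 + np.exp(-logits.squeeze()))
    preds = (probs >= 0.5).astype(int)
    return {
        "auc": roc_auc_score(labels, probs),
        "log_loss": log_loss(labels, probs),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
    }
```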
We'll also train an AutoModelForSequenceClassification model as our baseline. This allows us to compare whether incorporating tabular features on top of text features results in performance gains.
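One hedged way to set up this baseline; since the tokens are hashed ids, the sketch below builds a BERT config from scratch rather than loading a public checkpoint, and the vocabulary size, layer counts, and dataset variables are placeholders:

```python
from transformers import AutoModelForSequenceClassification, BertConfig, Trainer, TrainingArguments

# illustrative config sized to the hashed-token vocabulary
baseline_config = BertConfig(
    vocab_size=50_000,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_labels=2,
)
baseline_model = AutoModelForSequenceClassification.from_config(baseline_config)

trainer = Trainer(
    model=baseline_model,
    args=TrainingArguments(
        output_dir="bert_ctr_text_baseline",
        per_device_train_batch_size=128,
        num_train_epochs=1,
        evaluation_strategy="epoch",
    ),
    train_dataset=dataset_train,  # assumed tokenized text-only datasets from earlier steps
    eval_dataset=dataset_val,
    compute_metrics=compute_metrics,
)
# trainer.train()
```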
There are many tips and tricks for training our model in a transfer learning setting [7] [8] [9], e.g.
Two stage training. The model is first trained to convergence with a frozen encoder, then the encoder is unfrozen and fine-tuned along with the rest of the model.
Differential (a.k.a. layer-wise) learning rate. The idea is to have different learning rates for different parts/layers of our model. The intuition is that the initial layers of a pre-trained model have likely learned general features which we don't want to modify as much, hence we should set a smaller learning rate for those parts of the model. A sketch of both techniques is shown after this list.
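A minimal sketch of both techniques, assuming the attribute names (bert, cat_embeddings, head) from the earlier hypothetical BertCtrModel; when using the Trainer, the resulting optimizer can be passed in via its optimizers argument:

```python
import torch

# stage 1: freeze the text encoder and train only the newly initialized parts
for param in model.bert.parameters():
    param.requires_grad = False
stage_one_optimizer = torch.optim.AdamW(
    list(model.cat_embeddings.parameters()) + list(model.head.parameters()), lr=1e-3
)
# ... train to convergence with stage_one_optimizer ...

# stage 2: unfreeze everything and fine-tune with layer-wise learning rates,
# giving the (pre-trained) encoder a smaller learning rate than the new layers
for param in model.bert.parameters():
    param.requires_grad = True
stage_two_optimizer = torch.optim.AdamW([
    {"params": model.bert.parameters(), "lr": 2e-5},
    {"params": model.cat_embeddings.parameters(), "lr": 1e-3},
    {"params": model.head.parameters(), "lr": 1e-3},
])
```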
From the experiment above, including some tabular/structured data in the model outperforms the text only variant by 300 basis points (3 absolute percentage points), 0.709 versus 0.679, on this particular preprocessed dataset. The gap will most likely widen even more if we devote more effort to engineering additional features out of the raw data.
End Notes
Recent works such as DeText [4], CTR-BERT [5], TwinBERT [6] also explored leveraging BERT for click through rate (CTR) problems. Some common themes from these works include:
Cross architecture knowledge distillation, where a large cross attention teacher is distilled into a smaller bi-encoder student. Both teacher and student incorporate tabular features in a late fusion manner. One of the considerations in building an industry scale CTR prediction model is achieving fast real time inference. Bi-encoder models address this by decoupling the encoding process for queries and documents. This enables the computationally expensive BERT-based document encoder to be pre-computed offline and cached (the cache can be refreshed on a periodic basis, e.g. daily). At run time, the inference cost reduces to the query encoder, tabular features, as well as a late fusion cross layer such as an MLP (multi-layer perceptron).
N-pass ranker. To leverage a cross encoder in a production setting, we can use an N-pass ranker, i.e. a more lightweight ranker quickly weeds out a large amount of irrelevant documents, while the more performant but computationally expensive ranker is applied to a smaller subset of the documents.
In domain pre-training. A common practice when using a pre-trained BERT model is to directly leverage a public checkpoint that has been trained on general domains such as Wikipedia and fine-tune it on our specific task. Continuing to pre-train on in-domain data using an un/self-supervised learning loss like MLM (Masked Language Modeling) proves to be effective for vertical specific applications. A sketch of such an MLM pre-training step is shown after this list.
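A hedged sketch of what such in-domain MLM pre-training could look like with huggingface transformers; the tokenizer file, dataset variable, and hyperparameters below are placeholders (the notebook's hashed tokens would require building a custom vocabulary first):

```python
from transformers import (
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

# placeholder tokenizer over the hashed-token vocabulary; a public checkpoint's
# tokenizer can't be reused since raw text isn't available in this dataset
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="hashed_token_tokenizer.json",  # hypothetical file
    unk_token="[UNK]", pad_token="[PAD]", mask_token="[MASK]",
)
mlm_model = BertForMaskedLM(BertConfig(vocab_size=tokenizer.vocab_size))

# randomly masks 15% of input tokens and builds the corresponding MLM labels
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

trainer = Trainer(
    model=mlm_model,
    args=TrainingArguments(output_dir="bert_ctr_mlm", num_train_epochs=1),
    train_dataset=unlabeled_dataset,  # assumed in-domain corpus of tokenized queries/ads
    data_collator=data_collator,
)
# trainer.train()  # the resulting encoder weights can then seed the CTR model
```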
Reference
[1] KDD Cup 2012, Track 2 - Predict the click-through rate of ads given the query and user information.
[2] Haitao Mao, Dawei Yin, Xiaokai Chu, et al. - A Large Scale Search Dataset for Unbiased Learning to Rank (2022)
[3] How to Incorporate Tabular Data with HuggingFace Transformers
[4] Weiwei Guo, Xiaowei Liu, Sida Wang, Huiji Gao, Ananth Sankar, Zimeng Yang, Qi Guo, Liang Zhang, Bo Long, Bee-Chung Chen, Deepak Agarwal - DeText: A Deep Text Ranking Framework with BERT (2020)
[5] Aashiq Muhamed, et al. - CTR-BERT: Cost-effective knowledge distillation for billion-parameter teacher models (2021)
[6] Wenhao Lu, Jian Jiao, Ruofei Zhang - TwinBERT: Distilling Knowledge to Twin-Structured BERT Models for Efficient Retrieval (2022)
[7] Chi Sun, Xipeng Qiu, Yige Xu, Xuanjing Huang - How to Fine-Tune BERT for Text Classification? (2019)
[8] Getting the most out of Transfer Learning
[9] Going the extra mile, lessons learnt from Kaggle on how to train better NLP models (Part II)