ethen8181
GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/tabular/deep_learning_tabular.ipynb
```python
# code for loading notebook's format
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
```
```python
os.chdir(path)

%load_ext watermark
%load_ext autoreload
%autoreload 2

import os
import yaml
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import sklearn.metrics as metrics
from torch.nn import functional as F
from torch.utils.data import DataLoader
from datasets import (
    Dataset,
    load_dataset,
    disable_progress_bar
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, MinMaxScaler
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    TrainingArguments,
    Trainer
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%watermark -a 'Ethen' -d -t -v -u -p transformers,datasets,torch,numpy,pandas,sklearn
```
Author: Ethen

Last updated: 2023-08-23 21:35:27

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 8.13.2

transformers: 4.31.0
datasets    : 2.14.4
torch       : 2.0.1
numpy       : 1.23.2
pandas      : 2.0.1
sklearn     : 1.3.0

Deep Learning for Tabular Data

While deep learning's achievements are often highlighted in areas like computer vision and natural language processing, a less-discussed yet powerful application is deep learning on tabular data.

A key technique for getting the most out of deep learning on tabular data is learning embeddings for categorical variables [4]. Each category is represented as a dense vector in a lower-dimensional numeric space, which lets the model capture relationships between categories. For instance, the embeddings of a high-cardinality feature such as zip code can end up reflecting geographic proximity between zip codes without the model ever being told about geography. Even numerical features with only a handful of distinct values, such as day of the week, can be worth treating as categorical features with their own embeddings.
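To make the idea concrete, here is a minimal NumPy sketch, assuming a day-of-week feature already encoded as integer ids 0-6 and a randomly initialized 7 x 4 embedding matrix (in a real model, e.g. PyTorch's `nn.Embedding`, these weights are learned during training). It shows that an embedding lookup is simply row indexing, which is equivalent to multiplying a one-hot encoding by the embedding matrix.

```python
import numpy as np

# hypothetical setup: day-of-week treated as a categorical feature with ids 0-6
vocab_size, embedding_size = 7, 4
rng = np.random.default_rng(0)
# randomly initialized here; during training these weights would be learned
embedding_matrix = rng.normal(size=(vocab_size, embedding_size))

day_of_week_ids = np.array([0, 5, 6, 5])  # a batch of 4 examples

# an embedding lookup is just row indexing into the matrix
embedded = embedding_matrix[day_of_week_ids]

# equivalent formulation: one-hot encode, then multiply by the embedding matrix
one_hot = np.eye(vocab_size)[day_of_week_ids]
embedded_via_matmul = one_hot @ embedding_matrix

print(embedded.shape)  # (4, 4): one dense vector per example
print(np.allclose(embedded, embedded_via_matmul))  # True
```

Frameworks use the lookup form in practice because it avoids materializing a huge, mostly-zero one-hot matrix for high-cardinality features such as zip codes.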

Embeddings are also useful beyond the network that learned them. Once trained, they can be reused in other contexts, for example as input features for tree-based models, which then benefit from the representations learned by the neural network.
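A minimal sketch of that hand-off, assuming a trained embedding matrix is available as a NumPy array. With this notebook's model it could be pulled out with something like `model.embeddings["C1_embedding"].weight.detach().cpu().numpy()`; here the matrix is random and the helper `add_embedding_features` and its column names are hypothetical. Each categorical id is replaced by its embedding vector, producing dense numeric columns that a tree-based model can consume alongside the original features.

```python
import numpy as np
import pandas as pd

# hypothetical stand-in for a trained embedding table: 5 category ids x 3 dims
rng = np.random.default_rng(42)
embedding_weights = rng.normal(size=(5, 3))

# a toy frame with one encoded categorical column and one numerical column
df = pd.DataFrame({"C1": [0, 3, 3, 1], "I1": [0.2, 0.5, 0.1, 0.9]})


def add_embedding_features(df: pd.DataFrame, column: str, weights: np.ndarray) -> pd.DataFrame:
    """Append one numerical column per embedding dimension for `column`."""
    vectors = weights[df[column].to_numpy()]
    embedding_df = pd.DataFrame(
        vectors,
        columns=[f"{column}_emb_{i}" for i in range(weights.shape[1])],
        index=df.index,
    )
    return pd.concat([df, embedding_df], axis=1)


features = add_embedding_features(df, "C1", embedding_weights)
print(features.columns.tolist())
# ['C1', 'I1', 'C1_emb_0', 'C1_emb_1', 'C1_emb_2']
```

Whether these extra columns actually help a tree-based model is an empirical question and is best checked on a validation set.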

In this article, we'll walk through the bare minimum steps for defining a custom deep learning model for tabular data and training it with the Hugging Face Trainer.

Data Preprocessing

We'll be using a sampled version of the Criteo dataset, which originated from a Kaggle competition [2]. After the competition ended, the original data files became unavailable on the platform, so we turned to an alternative source for a similar dataset [1]; the code below keeps the first 1,000,000 records of its day_0 file. Each row corresponds to a display ad served by Criteo. Positive (clicked) and negative (non-clicked) examples have both been subsampled at different rates to reduce the dataset size. Fields in this dataset include:

  • Label: Target variable that indicates if an ad was clicked (1) or not (0).

  • I1-I13: A total of 13 columns of integer features (mostly count features).

  • C1-C26: A total of 26 columns of categorical features. The values of these features have been hashed onto 32 bits for anonymization purposes.

Unfortunately, the meanings of these features aren't disclosed.

There are many ways to implement data preprocessing. The baseline approach we'll take here is to:

  • Fill missing values: '-1' for categorical columns and 0 for numerical columns.

  • Encode categorical columns as distinct numerical ids.

  • Scale numerical columns to the [0, 1] range.

  • Given the imbalanced dataset, randomly downsample the negative class in the training set, while keeping the test set imbalanced.
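One consequence of downsampling only the training negatives is that the model's predicted click probabilities will tend to be inflated relative to the true click rate in the untouched test set. This notebook does not correct for that, but a standard adjustment, sketched below assuming negatives were kept at rate w (0.5 with `downsample_negative`'s default `frac`), maps a prediction p back to q = p / (p + (1 - p) / w). The helper name `recalibrate` is hypothetical. Ranking metrics such as ROC-AUC are unaffected because the mapping is monotonic, whereas calibration-sensitive metrics such as log loss are affected.

```python
import numpy as np


def recalibrate(p, negative_sampling_rate: float):
    """Map probabilities from a model trained on negatively downsampled data
    back to the original class distribution.

    Keeping negatives at rate w multiplies the odds of the positive class
    by 1 / w, so the predicted odds are multiplied by w to undo it.
    """
    p = np.asarray(p, dtype=float)
    w = negative_sampling_rate
    return p / (p + (1.0 - p) / w)


# with w = 0.5, a predicted 0.5 maps back to 1/3; the endpoints 0 and 1 are unchanged
print(recalibrate([0.0, 0.5, 1.0], negative_sampling_rate=0.5))
```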

```python
# one time code for creating a sampled criteo dataset
import gzip
import pandas as pd


def parse_criteo_data(gzip_file: str, num_records: int, output_path: str):
    """
    Parse gzipped criteo dataset and save it into a tabular parquet format.
    """
    columns = [
        'label',
        'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11', 'I12', 'I13',
        'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13',
        'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26'
    ]
    dtype = {}
    for col in columns:
        if "C" in col:
            dtype[col] = "string"
        elif col == "label":
            dtype[col] = "int"
        else:
            dtype[col] = "float"

    lines = []
    with gzip.open(gzip_file, 'r') as f_in:
        for i in range(num_records):
            line = f_in.readline()
            line = str(line, encoding="utf-8")
            # strip only the line terminator, so trailing empty fields
            # (consecutive tabs at the end of the line) are preserved
            line = line.rstrip("\r\n").split("\t")
            lines.append(line)

    df = pd.DataFrame(lines, columns=columns)
    # empty strings denote missing values
    df = df.replace("", None)
    df = df.astype(dtype)
    df.to_parquet(output_path, index=False)


gzip_file = "day_0.gz"
num_records = 1000000
output_path = "criteo_sampled.parquet"
parse_criteo_data(gzip_file, num_records, output_path)
```
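Each record in the raw file is one tab-separated line: the label, 13 integer fields, then 26 hashed categorical fields, with missing values left as empty strings. The hypothetical helper `parse_criteo_line` below, which is not part of the original notebook, parses one such line into typed values following the same dtype mapping as `parse_criteo_data`. The sample line is synthetic.

```python
from typing import List, Optional, Union

NUM_INTEGER_FEATURES = 13
NUM_CATEGORICAL_FEATURES = 26


def parse_criteo_line(line: str) -> List[Optional[Union[int, float, str]]]:
    """Hypothetical helper: parse one raw Criteo line into typed values.

    Empty fields become None, matching df.replace("", None) in parse_criteo_data.
    """
    # strip only the line terminator so trailing empty fields are preserved
    fields = line.rstrip("\r\n").split("\t")
    expected = 1 + NUM_INTEGER_FEATURES + NUM_CATEGORICAL_FEATURES
    if len(fields) != expected:
        raise ValueError(f"expected {expected} fields, got {len(fields)}")

    label = int(fields[0])
    integers = [float(v) if v != "" else None for v in fields[1:1 + NUM_INTEGER_FEATURES]]
    categoricals = [v if v != "" else None for v in fields[1 + NUM_INTEGER_FEATURES:]]
    return [label] + integers + categoricals


# synthetic example: 13 integer fields (first one missing), 26 categorical fields (last one missing)
sample = "\t".join(["1", ""] + ["3"] * 12 + ["aaaaaaaa"] * 25 + [""]) + "\n"
parsed = parse_criteo_line(sample)
print(len(parsed), parsed[0], parsed[1], parsed[-1])  # 40 1 None None
```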
```python
input_path = "criteo_data/criteo_sampled.parquet"
df = pd.read_parquet(input_path)
print(df.shape)
df.head()
```
(1000000, 40)
```python
sparse_features = ['C' + str(i) for i in range(1, 27)]
dense_features = ['I' + str(i) for i in range(1, 14)]
feature_names = dense_features + sparse_features

df[sparse_features] = df[sparse_features].fillna('-1')
df[dense_features] = df[dense_features].fillna(0)
```
```python
# label encoding for categorical/sparse features
# and scaling for numerical/dense features
ordinal_encoder = OrdinalEncoder(min_frequency=30)
df[sparse_features] = ordinal_encoder.fit_transform(df[sparse_features])

min_max_scaler = MinMaxScaler(feature_range=(0, 1))
df[dense_features] = min_max_scaler.fit_transform(df[dense_features])
df.head()
```
```python
df["label"].value_counts()
```
label
0    970960
1     29040
Name: count, dtype: int64
```python
def downsample_negative(df: pd.DataFrame, frac: float = 0.5, random_state: int = 1234):
    """Given a binary classification task with 0/1 labels, downsample
    negative class (class 0) with the specified fraction parameter.
    """
    df_majority = df[df["label"] == 0]
    df_minority = df[df["label"] == 1]
    df_downsampled_majority = df_majority.sample(frac=frac, random_state=random_state)
    df_downsampled = pd.concat([df_downsampled_majority, df_minority])

    # shuffle the combined data frame
    df_downsampled = df_downsampled.sample(frac=1, random_state=random_state).reset_index(drop=True)
    return df_downsampled
```
```python
df_train, df_test = train_test_split(df, test_size=0.1, random_state=1234, stratify=df["label"])
df_test = df_test.reset_index(drop=True)
df_train_downsampled = downsample_negative(df_train)

dataset_train = Dataset.from_pandas(df_train_downsampled)
dataset_test = Dataset.from_pandas(df_test)
print(dataset_train)
dataset_train[0]
```
Dataset({ features: ['label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11', 'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26'], num_rows: 463068 })
{'label': 0, 'I1': 0.0, 'I2': 0.005625, 'I3': 0.0, 'I4': 0.0, 'I5': 0.0007385524372230429, 'I6': 0.0, 'I7': 0.0, 'I8': 0.0, 'I9': 0.0005611672278338945, 'I10': 0.0, 'I11': 0.013513513513513514, 'I12': 0.000576345639540026, 'I13': 0.0, 'C1': 888.0, 'C2': 866.0, 'C3': 1134.0, 'C4': 292.0, 'C5': 379.0, 'C6': 0.0, 'C7': 1490.0, 'C8': 141.0, 'C9': 2.0, 'C10': 631.0, 'C11': 113.0, 'C12': 1406.0, 'C13': 4.0, 'C14': 400.0, 'C15': 882.0, 'C16': 11.0, 'C17': 3.0, 'C18': 138.0, 'C19': 11.0, 'C20': 114.0, 'C21': 1011.0, 'C22': 113.0, 'C23': 1209.0, 'C24': 835.0, 'C25': 2.0, 'C26': 23.0}
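As a sanity check on the `num_rows` above, the downsampled training size can be reproduced by hand from the label counts. This assumes the stratified 90/10 split keeps the class ratio exactly (it can here, since both class counts divide evenly by 10) and that `DataFrame.sample(frac=0.5)` keeps half of the training negatives.

```python
# label counts from df["label"].value_counts() above
num_negative, num_positive = 970_960, 29_040

# stratified split with test_size=0.1: 10% of each class goes to the test set
train_negative = num_negative - num_negative // 10
train_positive = num_positive - num_positive // 10

# downsample_negative keeps frac=0.5 of the negatives and all positives
downsampled_rows = round(train_negative * 0.5) + train_positive
positive_rate = train_positive / downsampled_rows

print(train_negative, train_positive)  # 873864 26136
print(downsampled_rows)                # 463068, matching num_rows above
print(f"{positive_rate:.3f}")          # 0.056, roughly double the 0.029 in the full data
```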

We'll specify a config mapping for tabular features that is shared by our batch collate function and our model. Its keys are the features we wish to use, and each value specifies whether the field is numerical or categorical. This tells the model the vocabulary size and embedding size for each categorical feature, as well as how many numerical fields there are, which together determine the input size of the dense/feed-forward layers.
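As a sketch of how such a mapping could be assembled programmatically instead of written by hand, the hypothetical helper `build_tabular_features_config` below takes a vocabulary size per categorical feature (e.g. from `len(ordinal_encoder.categories_[i])`) plus a list of numerical features. The embedding-size rule, `min(vocab_size, 32)`, is an assumption for illustration only; it is not how the sizes in this notebook's features_config.yaml were chosen. The feature names are also hypothetical. The sketch includes the optional `share_embedding` key that `TabularModel` reads, which lets one feature reuse another feature's embedding table.

```python
from typing import Dict, List, Optional


def build_tabular_features_config(
    categorical_vocab_sizes: Dict[str, int],
    numerical_features: List[str],
    max_embedding_size: int = 32,
    shared: Optional[Dict[str, str]] = None,
) -> Dict[str, dict]:
    """Hypothetical helper producing a tabular_features_config-style dict.

    `shared` maps a feature name to the feature whose embedding table it reuses.
    """
    shared = shared or {}
    config = {}
    for name in numerical_features:
        config[name] = {"dtype": "numerical"}

    for name, vocab_size in categorical_vocab_sizes.items():
        if name in shared:
            # lookups go through the shared feature's table; vocab_size kept for reference
            config[name] = {
                "dtype": "categorical",
                "vocab_size": vocab_size,
                "share_embedding": shared[name],
            }
        else:
            config[name] = {
                "dtype": "categorical",
                "vocab_size": vocab_size,
                # illustrative rule: small vocabularies get smaller embeddings
                "embedding_size": min(vocab_size, max_embedding_size),
            }
    return config


example_config = build_tabular_features_config(
    categorical_vocab_sizes={"video_id": 1000, "last_watched_video_id": 1000, "day_of_week": 7},
    numerical_features=["I1"],
    shared={"last_watched_video_id": "video_id"},
)
print(example_config["day_of_week"])
# {'dtype': 'categorical', 'vocab_size': 7, 'embedding_size': 7}
```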

```python
# demonstrating the functionality with 1 numerical and 1 categorical feature
tabular_features_config = {
    "I1": {
        "dtype": "numerical",
    },
    "C1": {
        "dtype": "categorical",
        "vocab_size": len(ordinal_encoder.categories_[0]),
        "embedding_size": 32
    }
}


def tabular_collate_fn(batch):
    """
    Use in conjunction with Dataloader's collate_fn for tabular data.

    Returns
    -------
    batch : dict
        Dictionary with two primary keys: tabular_inputs, and labels.
        Tabular inputs is a nested field, where each element is a feature_name
        -> float tensor mapping. e.g.
        {
            'tabular_inputs': {'I1': tensor([0., 0.]), 'C1': tensor([ 888., 1313.])},
            'labels': tensor([0, 0])
        }
    """
    labels = []
    tabular_inputs = {}
    for example in batch:
        label = example["label"]
        labels.append(label)

        for name in tabular_features_config:
            feature = example[name]
            if name not in tabular_inputs:
                tabular_inputs[name] = [feature]
            else:
                tabular_inputs[name].append(feature)

    for name in tabular_inputs:
        tabular_inputs[name] = torch.FloatTensor(tabular_inputs[name])

    batch = {
        "tabular_inputs": tabular_inputs,
        "labels": torch.LongTensor(labels)
    }
    return batch
```
```python
data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
```
{'tabular_inputs': {'I1': tensor([0., 0.]), 'C1': tensor([ 888., 1313.])}, 'labels': tensor([0, 0])}
```python
# specify all the features in features_config.yaml to prevent clunky display
# we'll need to update the vocabulary size for each categorical feature
# if we were to use a different dataset, e.g. via
# for category in ordinal_encoder.categories_:
#     print(len(category))
with open("features_config.yaml", "r") as f_in:
    config = yaml.safe_load(f_in)

tabular_features_config = config["tabular_features_config"]

data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
```
{'tabular_inputs': {'I1': tensor([0., 0.]), 'I2': tensor([0.0056, 0.0180]), 'I3': tensor([0., 0.]), 'I4': tensor([0.0000, 0.0012]), 'I5': tensor([0.0007, 0.0002]), 'I6': tensor([0., 0.]), 'I7': tensor([0., 0.]), 'I8': tensor([0.0000, 0.0217]), 'I9': tensor([0.0006, 0.0017]), 'I10': tensor([0., 0.]), 'I11': tensor([0.0135, 0.0045]), 'I12': tensor([0.0006, 0.0168]), 'I13': tensor([0., 0.]), 'C1': tensor([ 888., 1313.]), 'C2': tensor([866., 276.]), 'C3': tensor([1134., 3660.]), 'C4': tensor([292., 213.]), 'C5': tensor([379., 963.]), 'C6': tensor([0., 2.]), 'C7': tensor([1490., 2072.]), 'C8': tensor([141., 401.]), 'C9': tensor([2., 4.]), 'C10': tensor([ 631., 1574.]), 'C11': tensor([ 113., 1499.]), 'C12': tensor([1406., 2965.]), 'C13': tensor([4., 8.]), 'C14': tensor([400., 198.]), 'C15': tensor([882., 36.]), 'C16': tensor([11., 26.]), 'C17': tensor([3., 2.]), 'C18': tensor([138., 75.]), 'C19': tensor([11., 3.]), 'C20': tensor([ 114., 1255.]), 'C21': tensor([1011., 1711.]), 'C22': tensor([ 113., 1374.]), 'C23': tensor([1209., 1466.]), 'C24': tensor([835., 38.]), 'C25': tensor([ 2., 32.]), 'C26': tensor([23., 18.])}, 'labels': tensor([0, 0])}

Model

Our model architecture mainly involves converting categorical features into low-dimensional embeddings. These embedding outputs are then concatenated with the numerical (dense) features before being fed into subsequent feed-forward layers, followed by a linear classification head.
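The shape bookkeeping behind that description can be sketched in NumPy, assuming a hypothetical batch of 8 examples with two categorical features (embedding sizes 32 and 16) and three numerical features. Each embedding lookup yields a (batch, embedding_size) block, each numerical feature contributes one column, and concatenating along the last axis gives the input width of the first feed-forward layer (32 + 16 + 3 = 51 here). `TabularModel.concatenate_tabular_inputs` below performs the same steps with `torch.cat`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 8

# hypothetical embedding tables: (vocab_size, embedding_size)
embedding_tables = {
    "C1": rng.normal(size=(100, 32)),
    "C2": rng.normal(size=(10, 16)),
}
# hypothetical batch: integer ids for categorical features, floats for numerical ones
categorical_ids = {
    "C1": rng.integers(0, 100, size=batch_size),
    "C2": rng.integers(0, 10, size=batch_size),
}
numerical_values = rng.random(size=(batch_size, 3))

# embedding lookup per categorical feature -> (batch_size, embedding_size)
embedded = [embedding_tables[name][ids] for name, ids in categorical_ids.items()]

# concatenate embeddings with the numerical columns along the feature axis
mlp_input = np.concatenate(embedded + [numerical_values], axis=-1)
print(mlp_input.shape)  # (8, 51)
```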

```python
def get_mlp_layers(input_dim: int, mlp_config):
    """
    Construct MLP, a.k.a. Feed forward layers based on input config.

    Parameters
    ----------
    input_dim :
        Input dimension for the first layer.

    mlp_config : list of dictionary with mlp spec.
        An example is shown below, the only mandatory parameter is hidden size.

            [
                {
                    "hidden_size": 1024,
                    "dropout_p": 0.1,
                    "activation_function": "ReLU",
                    "activation_function_kwargs": {},
                    "normalization_function": "LayerNorm",
                    "normalization_function_kwargs": {"eps": 1e-05}
                }
            ]

    Returns
    -------
    nn.Sequential :
        Sequential layer converted from input mlp_config. If mlp_config
        is None, then this returned value will also be None.

    current_dim :
        Dimension for the last layer.
    """
    if mlp_config is None:
        return None, input_dim

    layers = []
    current_dim = input_dim
    for config in mlp_config:
        hidden_size = config["hidden_size"]
        dropout_p = config.get("dropout_p", 0.0)
        activation_function = config.get("activation_function")
        activation_function_kwargs = config.get("activation_function_kwargs", {})
        normalization_function = config.get("normalization_function")
        normalization_function_kwargs = config.get("normalization_function_kwargs", {})

        # each block: Linear -> (optional) normalization -> (optional) activation -> Dropout
        linear = nn.Linear(current_dim, hidden_size)
        layers.append(linear)

        if normalization_function:
            # resolve the class by name from torch.nn, e.g. "LayerNorm" -> nn.LayerNorm
            normalization = getattr(nn, normalization_function)(hidden_size, **normalization_function_kwargs)
            layers.append(normalization)

        if activation_function:
            activation = getattr(nn, activation_function)(**activation_function_kwargs)
            layers.append(activation)

        dropout = nn.Dropout(p=dropout_p)
        layers.append(dropout)
        current_dim = hidden_size

    return nn.Sequential(*layers), current_dim
```

The next code block defines a config and a model class following the Hugging Face transformers class structure for custom models [3]. This allows us to leverage its Trainer class for training and evaluating our model instead of writing a custom training loop.

```python
class TabularModelConfig(PretrainedConfig):

    model_type = "tabular"

    def __init__(
        self,
        tabular_features_config=None,
        mlp_config=None,
        num_labels=2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tabular_features_config = tabular_features_config
        self.mlp_config = mlp_config
        self.num_labels = num_labels
```
```python
class TabularModel(PreTrainedModel):

    config_class = TabularModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings, output_dim = self.init_tabular_parameters(config.tabular_features_config)
        self.mlp, output_dim = get_mlp_layers(output_dim, config.mlp_config)
        self.head = nn.Linear(output_dim, config.num_labels)

    def forward(self, tabular_inputs, labels=None):
        concatenated_inputs = self.concatenate_tabular_inputs(
            tabular_inputs,
            self.config.tabular_features_config
        )
        # get_mlp_layers returns None when mlp_config is None,
        # in which case the head is applied to the concatenated inputs directly
        if self.mlp is not None:
            mlp_outputs = self.mlp(concatenated_inputs)
        else:
            mlp_outputs = concatenated_inputs

        logits = self.head(mlp_outputs)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        # at the bare minimum, we need to return the loss as well as predictions
        # for both training and evaluation. softmax probabilities are returned
        # instead of raw logits, so compute_metrics can use them as scores directly
        return loss, F.softmax(logits, dim=-1)

    def concatenate_tabular_inputs(self, tabular_inputs, tabular_features_config):
        numerical_inputs = []
        categorical_inputs = []
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                feature_name = f"{name}_embedding"
                share_embedding = config.get("share_embedding")
                if share_embedding:
                    feature_name = f"{share_embedding}_embedding"

                embedding = self.embeddings[feature_name]
                features = tabular_inputs[name].type(torch.long)
                embed = embedding(features)
                categorical_inputs.append(embed)
            elif config["dtype"] == "numerical":
                features = tabular_inputs[name].type(torch.float32)
                if len(features.shape) == 1:
                    features = features.unsqueeze(dim=1)

                numerical_inputs.append(features)

        if len(numerical_inputs) > 0:
            numerical_inputs = torch.cat(numerical_inputs, dim=-1)
            categorical_inputs.append(numerical_inputs)

        concatenated_inputs = torch.cat(categorical_inputs, dim=-1)
        return concatenated_inputs

    def init_tabular_parameters(self, tabular_features_config):
        embeddings = {}
        output_dim = 0
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                feature_name = f"{name}_embedding"
                share_embedding = config.get("share_embedding")
                if share_embedding:
                    # reuse the embedding table of the shared feature,
                    # so no new embedding layer is created for this feature
                    share_embedding_config = tabular_features_config[share_embedding]
                    embedding_size = share_embedding_config["embedding_size"]
                else:
                    # create new embedding layer for categorical features if share_embedding is None
                    embedding_size = config["embedding_size"]
                    embedding = nn.Embedding(config["vocab_size"], embedding_size)
                    embeddings[feature_name] = embedding

                output_dim += embedding_size
            elif config["dtype"] == "numerical":
                output_dim += 1

        return nn.ModuleDict(embeddings), output_dim
```
```python
mlp_config = [
    {
        "hidden_size": 1024,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 512,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 256,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    }
]

config = TabularModelConfig(tabular_features_config, mlp_config)
model = TabularModel(config)
print("# of parameters: ", model.num_parameters())
model
```
# of parameters: 51308904
TabularModel(
  (embeddings): ModuleDict(
    (C1_embedding): Embedding(179728, 64)
    (C2_embedding): Embedding(12325, 32)
    (C3_embedding): Embedding(11780, 32)
    (C4_embedding): Embedding(4156, 32)
    (C5_embedding): Embedding(10576, 32)
    (C6_embedding): Embedding(3, 3)
    (C7_embedding): Embedding(5850, 32)
    (C8_embedding): Embedding(1139, 32)
    (C9_embedding): Embedding(38, 16)
    (C10_embedding): Embedding(136421, 64)
    (C11_embedding): Embedding(33820, 32)
    (C12_embedding): Embedding(34916, 32)
    (C13_embedding): Embedding(10, 10)
    (C14_embedding): Embedding(1841, 32)
    (C15_embedding): Embedding(5445, 32)
    (C16_embedding): Embedding(56, 16)
    (C17_embedding): Embedding(4, 4)
    (C18_embedding): Embedding(615, 32)
    (C19_embedding): Embedding(14, 14)
    (C20_embedding): Embedding(187780, 64)
    (C21_embedding): Embedding(80020, 32)
    (C22_embedding): Embedding(165496, 64)
    (C23_embedding): Embedding(29741, 9)
    (C24_embedding): Embedding(7693, 32)
    (C25_embedding): Embedding(54, 16)
    (C26_embedding): Embedding(33, 16)
  )
  (mlp): Sequential(
    (0): Linear(in_features=789, out_features=1024, bias=True)
    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): Linear(in_features=512, out_features=256, bias=True)
    (9): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (10): ReLU()
    (11): Dropout(p=0.1, inplace=False)
  )
  (head): Linear(in_features=256, out_features=2, bias=True)
)
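The printed architecture accounts for both numbers reported above. The arithmetic below uses the (vocab_size, embedding_size) pairs copied from the printout: the 26 embedding widths plus the 13 numerical features give the 789 input features of the first Linear layer. Summing weights and biases for the embeddings, the Linear layers, the LayerNorms (one scale and one shift per unit) and the head reproduces the 51,308,904 parameters, about 97% of which sit in the embedding tables.

```python
# (vocab_size, embedding_size) for C1..C26, copied from the model printout
embedding_shapes = [
    (179728, 64), (12325, 32), (11780, 32), (4156, 32), (10576, 32), (3, 3),
    (5850, 32), (1139, 32), (38, 16), (136421, 64), (33820, 32), (34916, 32),
    (10, 10), (1841, 32), (5445, 32), (56, 16), (4, 4), (615, 32), (14, 14),
    (187780, 64), (80020, 32), (165496, 64), (29741, 9), (7693, 32), (54, 16),
    (33, 16),
]
num_numerical = 13
hidden_sizes = [1024, 512, 256]
num_labels = 2

mlp_input_dim = sum(size for _, size in embedding_shapes) + num_numerical
embedding_params = sum(vocab * size for vocab, size in embedding_shapes)

mlp_params = 0
in_dim = mlp_input_dim
for hidden in hidden_sizes:
    mlp_params += in_dim * hidden + hidden  # Linear weight + bias
    mlp_params += 2 * hidden                # LayerNorm weight + bias
    in_dim = hidden
head_params = in_dim * num_labels + num_labels

total_params = embedding_params + mlp_params + head_params
print(mlp_input_dim)                             # 789
print(total_params)                              # 51308904
print(f"{embedding_params / total_params:.1%}")  # 97.1% of parameters are embeddings
```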
```python
# a quick test on a sample batch to ensure model forward pass runs
output = model(**batch)
output
```
(tensor(0.4340, grad_fn=<NllLossBackward0>), tensor([[0.6561, 0.3439], [0.6399, 0.3601]], grad_fn=<SoftmaxBackward0>))

The rest of the code defines boilerplate for using the Hugging Face Trainer, along with a compute_metrics function that calculates standard binary classification metrics: ROC-AUC, PR-AUC (average precision) and log loss.

```python
def compute_metrics(eval_preds, round_digits: int = 3):
    # y_pred holds the softmax probabilities returned by TabularModel.forward,
    # column 1 is the probability of the positive (clicked) class
    y_pred, y_true = eval_preds
    y_score = y_pred[:, 1]

    log_loss = round(metrics.log_loss(y_true, y_score), round_digits)
    roc_auc = round(metrics.roc_auc_score(y_true, y_score), round_digits)
    pr_auc = round(metrics.average_precision_score(y_true, y_score), round_digits)
    return {
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'log_loss': log_loss
    }
```
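Because `TabularModel.forward` returns softmax probabilities rather than raw logits, column 1 of the predictions is a valid probability, which log loss requires. ROC-AUC and PR-AUC only depend on how examples are ranked, so they come out the same for any strictly monotonic transform of the score, such as the logit margin. The toy logits and labels below are synthetic and only illustrate this distinction.

```python
import numpy as np
import sklearn.metrics as metrics

# synthetic two-class logits for 6 examples and their 0/1 labels
logits = np.array([[2.0, -1.0], [0.5, 0.3], [-0.2, 0.9], [1.0, 1.5], [0.0, -0.5], [0.3, 0.4]])
y_true = np.array([0, 0, 1, 1, 0, 1])

# softmax over the class dimension, mirroring F.softmax(logits, dim=-1)
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)
y_score = probs[:, 1]

# for two classes, y_score = sigmoid(margin), a strictly monotonic transform
margin = logits[:, 1] - logits[:, 0]

# ranking metrics agree whether computed on probabilities or margins
print(np.isclose(metrics.roc_auc_score(y_true, y_score), metrics.roc_auc_score(y_true, margin)))  # True
# log loss needs probabilities, so it is computed on y_score only
print(metrics.log_loss(y_true, y_score))
```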
```python
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"
training_args = TrainingArguments(
    output_dir="tabular",
    num_train_epochs=5,
    learning_rate=0.001,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=2,
    fp16=True,  # mixed precision training, requires a CUDA GPU
    lr_scheduler_type="constant",
    # newer transformers releases rename this argument to eval_strategy
    evaluation_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    do_train=True,
    # we are collecting all tabular features into a single entry
    # tabular_inputs during collate function, this is to prevent
    # huggingface trainer from removing these features while processing
    # our dataset
    remove_unused_columns=False
)

trainer = Trainer(
    model,
    args=training_args,
    data_collator=tabular_collate_fn,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    compute_metrics=compute_metrics
)

# for this dataset, Roc-AUC typically falls in the range of 0.725 - 0.727
train_output = trainer.train()
```
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning warnings.warn( Could not estimate the number of tokens of the input, floating-point operations will not be computed
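A rough training budget for the run above, assuming a single GPU: with `per_device_train_batch_size=128` and `gradient_accumulation_steps=2`, each optimizer step sees 256 examples, so the 463,068 downsampled training rows give roughly 1,809 optimizer steps per epoch and roughly 9,045 over 5 epochs, meaning `eval_steps=1000` triggers about 9 evaluations on the test set. Exact step accounting, e.g. how a final incomplete accumulation window is handled, can differ between transformers versions, so treat these as estimates.

```python
import math

num_train_rows = 463_068
per_device_batch_size = 128
gradient_accumulation_steps = 2
num_devices = 1  # assumption: single GPU
num_epochs = 5
eval_steps = 1000

batches_per_epoch = math.ceil(num_train_rows / (per_device_batch_size * num_devices))
optimizer_steps_per_epoch = batches_per_epoch // gradient_accumulation_steps
total_optimizer_steps = optimizer_steps_per_epoch * num_epochs
num_evaluations = total_optimizer_steps // eval_steps

effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices
print(effective_batch_size)  # 256
print(batches_per_epoch, optimizer_steps_per_epoch, total_optimizer_steps)  # 3618 1809 9045
print(num_evaluations)  # 9
```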
```python
# we can also leverage .predict for quickly performing batch prediction
# on a given input dataset
prediction_output = trainer.predict(dataset_test)
prediction_output
```
PredictionOutput(predictions=array([[0.99899954, 0.0010005 ], [0.99010193, 0.00989806], [0.9948001 , 0.00519988], ..., [0.9937588 , 0.00624126], [0.83815056, 0.16184942], [0.869991 , 0.13000904]], dtype=float32), label_ids=array([0, 0, 0, ..., 0, 0, 0]), metrics={'test_loss': 0.13327662646770477, 'test_roc_auc': 0.719, 'test_pr_auc': 0.085, 'test_log_loss': 0.133, 'test_runtime': 73.7289, 'test_samples_per_second': 1356.321, 'test_steps_per_second': 169.54})

End Notes

In this post, we walked through a baseline workflow for training deep neural networks on tabular datasets in PyTorch. Several companies have reported success with deep neural networks as a core part of their recommendation stack, e.g. YouTube Recommendations [5] and Airbnb Search [6] [7]. Beyond making models bigger or deeper, we'll briefly touch upon some of their key learnings to conclude this article.

Heterogeneous Signals

Compared to matrix factorization based collaborative filtering algorithms, it's easier to add a diverse set of signals into a neural network model.

For instance, in the context of YouTube recommendations:

  • Recommendation systems benefit particularly from specialized features that capture historical behavior, such as the user's previous interactions with the item, how many videos the user has watched from a specific channel, or the time since the user last watched a video on a particular topic. Apart from hand-crafted numerical features, we can also include a user's watch or search history as a variable-length sequence and map it into a dense embedding representation.

  • In a two-stage retrieval + ranking system, information from candidate generation can be propagated into the ranking stage as features, e.g. which sources nominated a candidate and the scores they assigned.

  • Categorical variables' embeddings can be shared, e.g. a single video id embedding can be leveraged across various features (impression video id, last video id watched by the user, seed video id for the recommendation).

  • While popular tree-based models are invariant to the scaling of individual features, neural networks are quite sensitive to it. Normalizing continuous features is therefore a must, e.g. via min/max scaling, log transformation, or standardization.

  • Recommendation systems often exhibit some form of bias towards the past, as they are trained on historical data. For YouTube, adding the age of content on the platform as a feature allows the model to represent a video's time-dependent behavior.
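For the variable-length history signals in the first bullet, one simple approach, which the YouTube paper [5] reports using for watch and search history, is to look up an embedding for every item in the history and average them, so every user gets a fixed-size vector regardless of history length. Below is a NumPy sketch with a hypothetical 6-video embedding table and made-up histories; the helper `mean_pool` is illustrative. In PyTorch, `nn.EmbeddingBag(mode="mean")` provides the same pooling.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical video embedding table: 6 video ids x 4 dims
video_embeddings = rng.normal(size=(6, 4))

# made-up watch histories of different lengths (lists of video ids)
watch_histories = [[0, 3, 5], [2], [1, 1, 4, 5]]


def mean_pool(histories, table):
    """Average the embeddings of each user's history into one fixed-size vector."""
    pooled = np.zeros((len(histories), table.shape[1]))
    for row, history in enumerate(histories):
        if history:  # users with an empty history keep an all-zero vector
            pooled[row] = table[history].mean(axis=0)
    return pooled


user_vectors = mean_pool(watch_histories, video_embeddings)
print(user_vectors.shape)  # (3, 4): one vector per user, independent of history length
```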

For Airbnb search:

  • Domain knowledge proves valuable in feature normalization. For example, when dealing with geo location represented by latitude and longitude, instead of using the raw coordinates, we can calculate the offset from the center of the map displayed to the user. This allows the model to learn distance-based global properties rather than the specifics of individual geographies. To learn local geography, a new categorical feature is created by combining the city specified in the query with the level 12 S2 cell of a listing, and a hashing function maps these two values (city and S2 cell) into an integer. For example, given the query "San Francisco" and a listing near the Embarcadero (S2 cell 539058204), hashing {"San Francisco", 539058204} -> 71829521 creates this categorical feature.

  • Position bias is also a notable topic in the literature. This bias emerges when historical logs are used to train subsequent models. Introducing position as a feature, regularized with dropout, was proposed as a strategy for mitigating it.
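The two geo features in the first Airbnb bullet can be sketched as follows. This is a hypothetical illustration, not Airbnb's implementation: the helper names and coordinates are made up, the S2 cell id is taken as given (computing real level-12 S2 cells requires an S2 library), and MD5 modulo a fixed bucket count stands in for whatever hash function was actually used. A stable hash like this is preferable to Python's built-in `hash()`, whose output for strings changes between interpreter runs.

```python
import hashlib


def offset_from_map_center(listing_lat: float, listing_lng: float,
                           center_lat: float, center_lng: float):
    """Express a listing's location relative to the center of the map shown to the user."""
    return listing_lat - center_lat, listing_lng - center_lng


def hash_city_cell(city: str, s2_cell_id: int, num_buckets: int = 2**24) -> int:
    """Hypothetical hashing of (query city, listing S2 cell) into a categorical id."""
    key = f"{city}|{s2_cell_id}".encode("utf-8")
    digest = hashlib.md5(key).hexdigest()
    return int(digest, 16) % num_buckets


# made-up coordinates for a listing and the map center
lat_offset, lng_offset = offset_from_map_center(37.7955, -122.3937, 37.7749, -122.4194)
city_cell_id = hash_city_cell("San Francisco", 539058204)
print(round(lat_offset, 4), round(lng_offset, 4))  # 0.0206 0.0257
print(0 <= city_cell_id < 2**24)  # True; the id can then index an embedding table
```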

References

  • [1] Criteo 1TB Click Logs dataset

  • [2] Kaggle Competition - Display Advertising Challenge

  • [3] Transformers Doc - Sharing custom models

  • [4] Blog: An Introduction to Deep Learning for Tabular Data

  • [5] Paul Covington, Jay Adams, Emre Sargin - Deep Neural Networks for YouTube Recommendations (2016)

  • [6] Malay Haldar, Mustafa Abdool, Prashant Ramanathan, Tao Xu, Shulin Yang, Huizhong Duan, Qing Zhang, Nick Barrow-Williams, Bradley C. Turnbull, Brendan M. Collins, Thomas Legrand - Applying Deep Learning To Airbnb Search (2018)

  • [7] Malay Haldar, Mustafa Abdool, Prashant Ramanathan, Tyler Sax, Lanbo Zhang, Aamir Mansawala, Shulin Yang, Bradley Turnbull, Junshuo Liao - Improving Deep Learning For Airbnb Search (2020)