Path: blob/master/notebooks/book1/15/entailment_attention_mlp_torch.ipynb
Please find jax implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/entailment_attention_mlp_jax.ipynb
Textual entailment classifier using an MLP plus attention
In textual entailment, the input is two sentences, a premise (P) and a hypothesis (H), and the output is a label specifying whether P entails H, P contradicts H, or neither. (This task is also called "natural language inference".) We use attention to align the hypothesis to the premise and vice versa, then compare the aligned words to estimate the similarity between the sentences, and pass the pooled comparisons to an MLP.
Based on sec 15.5 of http://d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-attention.html
Data
We use SNLI (Stanford Natural Language Inference) dataset described in sec 15.4 of http://d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-and-dataset.html.
We show the first 3 training examples and their labels ("0", "1", and "2" correspond to "entailment", "contradiction", and "neutral", respectively).
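A minimal sketch of loading and inspecting the data, assuming the d2l package; `read_snli` and the `SNLI` entry in `d2l.DATA_HUB` are defined in the linked sec 15.4.

```python
from d2l import torch as d2l

# Download SNLI and read the raw (premise, hypothesis, label) triples.
data_dir = d2l.download_extract('SNLI')
premises, hypotheses, labels = d2l.read_snli(data_dir, is_train=True)

for premise, hypothesis, label in zip(premises[:3], hypotheses[:3], labels[:3]):
    print('premise:', premise)
    print('hypothesis:', hypothesis)
    print('label:', label)  # 0=entailment, 1=contradiction, 2=neutral
```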
Model
The model is described in the book. Below we just give the code.
Attending
We define attention weights

$$e_{ij} = f(\mathbf{a}_i)^\top f(\mathbf{b}_j),$$

where $\mathbf{a}_i$ is the embedding of the $i$'th token from the premise, $\mathbf{b}_j$ is the embedding of the $j$'th token from the hypothesis, and $f$ is an MLP that maps from the embedding space to another hidden space.
The $i$'th word in A computes a weighted average of "relevant" words in B, and vice versa, as follows:

$$\boldsymbol{\beta}_i = \sum_{j=1}^{n} \frac{\exp(e_{ij})}{\sum_{k=1}^{n} \exp(e_{ik})} \mathbf{b}_j,
\qquad
\boldsymbol{\alpha}_j = \sum_{i=1}^{m} \frac{\exp(e_{ij})}{\sum_{k=1}^{m} \exp(e_{kj})} \mathbf{a}_i,$$

where $m$ and $n$ are the lengths of the premise A and the hypothesis B.
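A minimal PyTorch sketch of this attend step, following the d2l reference above. The `mlp` helper name and the 0.2 dropout rate are assumptions taken from that code.

```python
import torch
from torch import nn
from torch.nn import functional as F

def mlp(num_inputs, num_hiddens, flatten):
    """Two-layer feedforward net with ReLU and dropout (helper, per d2l)."""
    net = [nn.Dropout(0.2), nn.Linear(num_inputs, num_hiddens), nn.ReLU()]
    if flatten:
        net.append(nn.Flatten(start_dim=1))
    net += [nn.Dropout(0.2), nn.Linear(num_hiddens, num_hiddens), nn.ReLU()]
    if flatten:
        net.append(nn.Flatten(start_dim=1))
    return nn.Sequential(*net)

class Attend(nn.Module):
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.f = mlp(num_inputs, num_hiddens, flatten=False)

    def forward(self, A, B):
        # A: (batch, m, E) premise embeddings; B: (batch, n, E) hypothesis embeddings.
        f_A = self.f(A)  # (batch, m, H)
        f_B = self.f(B)  # (batch, n, H)
        # e[i, j] = f(a_i)^T f(b_j): (batch, m, n)
        e = torch.bmm(f_A, f_B.permute(0, 2, 1))
        # beta_i: each premise word's soft alignment to the hypothesis.
        beta = torch.bmm(F.softmax(e, dim=-1), B)                     # (batch, m, E)
        # alpha_j: each hypothesis word's soft alignment to the premise.
        alpha = torch.bmm(F.softmax(e.permute(0, 2, 1), dim=-1), A)   # (batch, n, E)
        return beta, alpha
```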
Comparing
We concatenate word $i$ in A, $\mathbf{a}_i$, with its "soft counterpart" in B, $\boldsymbol{\beta}_i$, and vice versa, and then pass this through another MLP $g$ to get a "comparison vector" for each input location:

$$\mathbf{v}_{A,i} = g([\mathbf{a}_i, \boldsymbol{\beta}_i]), \qquad \mathbf{v}_{B,j} = g([\mathbf{b}_j, \boldsymbol{\alpha}_j]).$$
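A sketch of the compare step under the same assumptions, reusing the `mlp` helper from the attend cell above.

```python
class Compare(nn.Module):
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.g = mlp(num_inputs, num_hiddens, flatten=False)

    def forward(self, A, B, beta, alpha):
        # Concatenate each word with its soft counterpart along the feature dim,
        # then map to a comparison vector per input location.
        V_A = self.g(torch.cat([A, beta], dim=2))   # (batch, m, H)
        V_B = self.g(torch.cat([B, alpha], dim=2))  # (batch, n, H)
        return V_A, V_B
```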
Aggregation
We sum-pool the "comparison vectors" for each input sentence, and then pass the pair of poolings to yet another MLP to generate the final classification.
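That is, $\mathbf{v}_A = \sum_{i=1}^{m} \mathbf{v}_{A,i}$, $\mathbf{v}_B = \sum_{j=1}^{n} \mathbf{v}_{B,j}$, and $\hat{\mathbf{y}} = h([\mathbf{v}_A, \mathbf{v}_B])$. A sketch of this aggregation step, again reusing the `mlp` helper from the attend cell:

```python
class Aggregate(nn.Module):
    def __init__(self, num_inputs, num_hiddens, num_outputs):
        super().__init__()
        self.h = mlp(num_inputs, num_hiddens, flatten=True)
        self.linear = nn.Linear(num_hiddens, num_outputs)

    def forward(self, V_A, V_B):
        V_A = V_A.sum(dim=1)  # sum-pool over premise positions: (batch, H)
        V_B = V_B.sum(dim=1)  # sum-pool over hypothesis positions: (batch, H)
        # Concatenate the pooled vectors and map to class logits.
        return self.linear(self.h(torch.cat([V_A, V_B], dim=1)))
```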
Putting it all together
We use a pre-trained embedding of size $E=100$ and a hidden size of $H=200$ (as in the d2l reference). The $f$ (attend) function maps from $\mathbb{R}^{100}$ to $\mathbb{R}^{200}$ hiddens. The $g$ (compare) function maps from $\mathbb{R}^{200}$ (a concatenated pair of embeddings) to $\mathbb{R}^{200}$. The $h$ (aggregate) function maps from $\mathbb{R}^{400}$ (the concatenated pooled vectors) to 3 outputs.
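A sketch composing the three modules above into the full classifier; the class name and default sizes follow the d2l reference, and `vocab_size` is assumed to come from the SNLI vocabulary built during preprocessing.

```python
class DecomposableAttention(nn.Module):
    def __init__(self, vocab_size, embed_size=100, num_hiddens=200):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attend = Attend(embed_size, num_hiddens)
        # Compare sees [word, soft counterpart], hence 2 * embed_size inputs.
        self.compare = Compare(2 * embed_size, num_hiddens)
        # Aggregate sees [v_A, v_B], hence 2 * num_hiddens inputs; 3 classes out.
        self.aggregate = Aggregate(2 * num_hiddens, num_hiddens, num_outputs=3)

    def forward(self, X):
        premises, hypotheses = X  # batches of token indices
        A = self.embedding(premises)    # (batch, m, E)
        B = self.embedding(hypotheses)  # (batch, n, E)
        beta, alpha = self.attend(A, B)
        V_A, V_B = self.compare(A, B, beta, alpha)
        return self.aggregate(V_A, V_B)  # (batch, 3) logits

# Usage example with a hypothetical vocab_size and random token indices:
net = DecomposableAttention(vocab_size=10000)
premises = torch.randint(0, 10000, (2, 50))    # (batch, premise_len)
hypotheses = torch.randint(0, 10000, (2, 40))  # (batch, hypothesis_len)
logits = net((premises, hypotheses))           # shape (2, 3)
```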