GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/15/entailment_attention_mlp_torch.ipynb

Open In Colab

Textual entailment classifier using an MLP plus attention

In textual entailment, the input is two sentences, a premise P and a hypothesis H, and the output is a label specifying whether P entails H, P contradicts H, or neither. (This task is also called "natural language inference".) We use attention to align the hypothesis to the premise and vice versa, then compare the aligned words to estimate the similarity between the sentences, and pass the weighted similarities to an MLP.

Based on sec 15.5 of http://d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-attention.html

import numpy as np
import matplotlib.pyplot as plt
import math
from IPython import display

try:
    import torch
except ModuleNotFoundError:
    %pip install -qq torch
    import torch

from torch import nn
from torch.nn import functional as F
from torch.utils import data

import collections
import re
import random
import os
import requests
import zipfile
import tarfile
import hashlib
import time

np.random.seed(seed=1)
torch.manual_seed(1)

!mkdir figures  # for saving plots
mkdir: cannot create directory ‘figures’: File exists

Data

We use the SNLI (Stanford Natural Language Inference) dataset, described in sec 15.4 of http://d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-and-dataset.html.

# Required functions for downloading data


def download(name, cache_dir=os.path.join("..", "data")):
    """Download a file inserted into DATA_HUB, return the local filename."""
    assert name in DATA_HUB, f"{name} does not exist in {DATA_HUB}."
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Hit cache
    print(f"Downloading {fname} from {url}...")
    r = requests.get(url, stream=True, verify=True)
    with open(fname, "wb") as f:
        f.write(r.content)
    return fname


def download_extract(name, folder=None):
    """Download and extract a zip/tar file."""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        fp = zipfile.ZipFile(fname, "r")
    elif ext in (".tar", ".gz"):
        fp = tarfile.open(fname, "r")
    else:
        assert False, "Only zip/tar files can be extracted."
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
DATA_HUB = dict()
DATA_HUB["SNLI"] = ("https://nlp.stanford.edu/projects/snli/snli_1.0.zip", "9fcde07509c7e87ec61c640c1b2753d9041758e4")

data_dir = download_extract("SNLI")
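As an optional sanity check (not part of the original notebook, and assuming `data_dir` points at the folder extracted above), we can peek at the first few column names of the raw TSV file; the reader defined below uses column 0 (the gold label) and columns 1 and 2 (the binary-parse premise and hypothesis).

# Optional peek at the raw SNLI TSV header (first three column names).
# Assumes `data_dir` was set by the download cell above.
with open(os.path.join(data_dir, "snli_1.0_train.txt"), "r") as f:
    print(f.readline().rstrip().split("\t")[:3])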
def read_snli(data_dir, is_train):
    """Read the SNLI dataset into premises, hypotheses, and labels."""

    def extract_text(s):
        # Remove information that will not be used by us
        s = re.sub("\\(", "", s)
        s = re.sub("\\)", "", s)
        # Substitute two or more consecutive whitespace with space
        s = re.sub("\\s{2,}", " ", s)
        return s.strip()

    label_set = {"entailment": 0, "contradiction": 1, "neutral": 2}
    file_name = os.path.join(data_dir, "snli_1.0_train.txt" if is_train else "snli_1.0_test.txt")
    with open(file_name, "r") as f:
        rows = [row.split("\t") for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels

Show the first 3 training examples and their labels (“0”, “1”, and “2” correspond to “entailment”, “contradiction”, and “neutral”, respectively).

train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print("premise:", x0)
    print("hypothesis:", x1)
    print("label:", y)
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is training his horse for a competition .
label: 2
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is at a diner , ordering an omelette .
label: 1
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is outdoors , on a horse .
label: 0
test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
    print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]
def tokenize(lines, token="word"):
    """Split text lines into word or character tokens."""
    if token == "word":
        return [line.split() for line in lines]
    elif token == "char":
        return [list(line) for line in lines]
    else:
        print("ERROR: unknown token type: " + token)


class Vocab:
    """Vocabulary for text."""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        # Sort according to frequencies
        counter = count_corpus(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # The index for the unknown token is 0
        self.unk, uniq_tokens = 0, ["<unk>"] + reserved_tokens
        uniq_tokens += [token for token, freq in self.token_freqs if freq >= min_freq and token not in uniq_tokens]
        self.idx_to_token, self.token_to_idx = [], dict()
        for token in uniq_tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]


def count_corpus(tokens):
    """Count token frequencies."""
    # Here `tokens` is a 1D list or 2D list
    if len(tokens) == 0 or isinstance(tokens[0], list):
        # Flatten a list of token lists into a list of tokens
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)
class SNLIDataset(torch.utils.data.Dataset):
    """A customized dataset to load the SNLI dataset."""

    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = tokenize(dataset[0])
        all_hypothesis_tokens = tokenize(dataset[1])
        if vocab is None:
            self.vocab = Vocab(all_premise_tokens + all_hypothesis_tokens, min_freq=5, reserved_tokens=["<pad>"])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print("read " + str(len(self.premises)) + " examples")

    def _pad(self, lines):
        return torch.tensor([truncate_pad(self.vocab[line], self.num_steps, self.vocab["<pad>"]) for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)
def load_data_snli(batch_size, num_steps=50):
    """Download the SNLI dataset and return data iterators and vocabulary."""
    num_workers = 4
    data_dir = download_extract("SNLI")
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter, train_set.vocab


def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))
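Below is a tiny illustration (with made-up token ids, not from the dataset) of how truncate_pad behaves: short sequences are right-padded and long ones are truncated to num_steps.

# Toy example of truncate_pad with hypothetical token ids and num_steps=5
print(truncate_pad([3, 7, 2], 5, 0))           # -> [3, 7, 2, 0, 0]
print(truncate_pad([3, 7, 2, 9, 4, 8], 5, 0))  # -> [3, 7, 2, 9, 4]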
train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked))
18678
for X, Y in train_iter:
    print(X[0].shape)
    print(X[1].shape)
    print(Y.shape)
    break
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked))
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])

Model

The model is described in the book. Below we just give the code.

Attending

We define attention weights $e_{ij} = f(a_i)^T f(b_j)$, where $a_i \in R^E$ is the embedding of the $i$'th token of the premise, $b_j \in R^E$ is the embedding of the $j$'th token of the hypothesis, and $f: R^E \rightarrow R^H$ is an MLP that maps from the embedding space to another hidden space.

def mlp(num_inputs, num_hiddens, flatten):
    net = []
    net.append(nn.Dropout(0.2))
    net.append(nn.Linear(num_inputs, num_hiddens))
    net.append(nn.ReLU())
    if flatten:
        net.append(nn.Flatten(start_dim=1))
    net.append(nn.Dropout(0.2))
    net.append(nn.Linear(num_hiddens, num_hiddens))
    net.append(nn.ReLU())
    if flatten:
        net.append(nn.Flatten(start_dim=1))
    return nn.Sequential(*net)
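As a small sketch (random toy tensors, not part of the original notebook) of how the score matrix is formed from the mlp just defined: for a batch of 2 premises of length $m=4$ and hypotheses of length $n=6$, $e$ has shape (batch, m, n).

# Toy shape check for the score matrix e_{ij} = f(a_i)^T f(b_j)
f_toy = mlp(num_inputs=100, num_hiddens=200, flatten=False)
A_toy = torch.randn(2, 4, 100)  # (batch, m, embed_size)
B_toy = torch.randn(2, 6, 100)  # (batch, n, embed_size)
e_toy = torch.bmm(f_toy(A_toy), f_toy(B_toy).permute(0, 2, 1))
print(e_toy.shape)  # torch.Size([2, 4, 6])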

The $i$'th word in A computes a weighted average of "relevant" words in B, and vice versa, as follows:

$$\begin{align}
\beta_i &= \sum_{j=1}^n \frac{\exp(e_{ij})}{\sum_{k=1}^n \exp(e_{ik})} b_j \\
\alpha_j &= \sum_{i=1}^m \frac{\exp(e_{ij})}{\sum_{k=1}^m \exp(e_{kj})} a_i
\end{align}$$

class Attend(nn.Module):
    def __init__(self, num_inputs, num_hiddens, **kwargs):
        super(Attend, self).__init__(**kwargs)
        self.f = mlp(num_inputs, num_hiddens, flatten=False)

    def forward(self, A, B):
        # Shape of `A`/`B`: (`batch_size`, no. of words in sequence A/B, `embed_size`)
        # Shape of `f_A`/`f_B`: (`batch_size`, no. of words in sequence A/B, `num_hiddens`)
        f_A = self.f(A)
        f_B = self.f(B)
        # Shape of `e`: (`batch_size`, no. of words in sequence A, no. of words in sequence B)
        e = torch.bmm(f_A, f_B.permute(0, 2, 1))
        # Shape of `beta`: (`batch_size`, no. of words in sequence A, `embed_size`), where
        # sequence B is softly aligned with each word (axis 1 of `beta`) in sequence A
        beta = torch.bmm(F.softmax(e, dim=-1), B)
        # Shape of `alpha`: (`batch_size`, no. of words in sequence B, `embed_size`), where
        # sequence A is softly aligned with each word (axis 1 of `alpha`) in sequence B
        alpha = torch.bmm(F.softmax(e.permute(0, 2, 1), dim=-1), A)
        return beta, alpha
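A quick shape check of Attend on the same kind of random toy tensors (added for illustration): beta softly aligns B to each word of A, and alpha aligns A to each word of B.

# Toy shape check for Attend (random tensors; batch of 2, m=4, n=6, embed_size=100)
attend_toy = Attend(num_inputs=100, num_hiddens=200)
beta_toy, alpha_toy = attend_toy(torch.randn(2, 4, 100), torch.randn(2, 6, 100))
print(beta_toy.shape, alpha_toy.shape)  # torch.Size([2, 4, 100]) torch.Size([2, 6, 100])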

Comparing

We concatenate word $i$ in A, $a_i$, with its "soft counterpart" in B, $\beta_i$, and vice versa, and then pass this through another MLP $g$ to get a "comparison vector" for each input location.

$$\begin{align}
v_{A,i} &= g([a_i, \beta_i]), \; i=1,\ldots,m \\
v_{B,j} &= g([b_j, \alpha_j]), \; j=1,\ldots,n
\end{align}$$

class Compare(nn.Module):
    def __init__(self, num_inputs, num_hiddens, **kwargs):
        super(Compare, self).__init__(**kwargs)
        self.g = mlp(num_inputs, num_hiddens, flatten=False)

    def forward(self, A, B, beta, alpha):
        V_A = self.g(torch.cat([A, beta], dim=2))
        V_B = self.g(torch.cat([B, alpha], dim=2))
        return V_A, V_B
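Another toy shape check (random tensors, illustration only): each comparison vector lives in the 200-dimensional hidden space, since $g$ takes the $2E=200$-dimensional concatenation $[a_i, \beta_i]$.

# Toy shape check for Compare: inputs are (embedding, soft alignment) pairs
compare_toy = Compare(num_inputs=200, num_hiddens=200)
V_A_toy, V_B_toy = compare_toy(
    torch.randn(2, 4, 100), torch.randn(2, 6, 100),  # A, B
    torch.randn(2, 4, 100), torch.randn(2, 6, 100),  # beta, alpha
)
print(V_A_toy.shape, V_B_toy.shape)  # torch.Size([2, 4, 200]) torch.Size([2, 6, 200])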

Aggregation

We sum-pool the "comparison vectors" for each input sentence, and then pass the pair of poolings to yet another MLP $h$ to generate the final classification.

$$\begin{align}
v_A &= \sum_{i=1}^m v_{A,i} \\
v_B &= \sum_{j=1}^n v_{B,j} \\
\hat{y} &= h([v_A, v_B])
\end{align}$$
class Aggregate(nn.Module):
    def __init__(self, num_inputs, num_hiddens, num_outputs, **kwargs):
        super(Aggregate, self).__init__(**kwargs)
        self.h = mlp(num_inputs, num_hiddens, flatten=True)
        self.linear = nn.Linear(num_hiddens, num_outputs)

    def forward(self, V_A, V_B):
        # Sum up both sets of comparison vectors
        V_A = V_A.sum(dim=1)
        V_B = V_B.sum(dim=1)
        # Feed the concatenation of both summarization results into an MLP
        Y_hat = self.linear(self.h(torch.cat([V_A, V_B], dim=1)))
        return Y_hat
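Completing the toy walkthrough (again with random tensors, for illustration): sum-pooling removes the length dimension, so premises and hypotheses of different lengths become fixed-size vectors, and the aggregate step emits 3 logits per example.

# Toy shape check for Aggregate: (batch, m, 200) and (batch, n, 200) -> (batch, 3)
aggregate_toy = Aggregate(num_inputs=400, num_hiddens=200, num_outputs=3)
print(aggregate_toy(torch.randn(2, 4, 200), torch.randn(2, 6, 200)).shape)  # torch.Size([2, 3])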

Putting it all together

We use a pre-trained embedding of size $E=100$. The $f$ (attend) function maps from $E=100$ to $H=200$ hidden units. The $g$ (compare) function maps $2E=200$ to $H=200$. The $h$ (aggregate) function maps $2H=400$ to 3 outputs.

class DecomposableAttention(nn.Module):
    def __init__(
        self, vocab, embed_size, num_hiddens, num_inputs_attend=100, num_inputs_compare=200, num_inputs_agg=400, **kwargs
    ):
        super(DecomposableAttention, self).__init__(**kwargs)
        self.embedding = nn.Embedding(len(vocab), embed_size)
        self.attend = Attend(num_inputs_attend, num_hiddens)
        self.compare = Compare(num_inputs_compare, num_hiddens)
        # There are 3 possible outputs: entailment, contradiction, and neutral
        self.aggregate = Aggregate(num_inputs_agg, num_hiddens, num_outputs=3)

    def forward(self, X):
        premises, hypotheses = X
        A = self.embedding(premises)
        B = self.embedding(hypotheses)
        beta, alpha = self.attend(A, B)
        V_A, V_B = self.compare(A, B, beta, alpha)
        Y_hat = self.aggregate(V_A, V_B)
        return Y_hat
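As an end-to-end sanity check (not in the original notebook): an untrained model with random embeddings should already map a batch to logits of shape (batch_size, 3). Here we assume the batch X from the data section above is still in scope.

# Sanity check on the end-to-end shapes, reusing the batch X printed earlier
net_check = DecomposableAttention(vocab, embed_size=100, num_hiddens=200)
with torch.no_grad():
    logits = net_check((X[0], X[1]))
print(logits.shape)  # torch.Size([128, 3])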
class TokenEmbedding:
    """Token Embedding."""

    def __init__(self, embedding_name):
        self.idx_to_token, self.idx_to_vec = self._load_embedding(embedding_name)
        self.unknown_idx = 0
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def _load_embedding(self, embedding_name):
        idx_to_token, idx_to_vec = ["<unk>"], []
        data_dir = download_extract(embedding_name)
        # GloVe website: https://nlp.stanford.edu/projects/glove/
        # fastText website: https://fasttext.cc/
        with open(os.path.join(data_dir, "vec.txt"), "r") as f:
            for line in f:
                elems = line.rstrip().split(" ")
                token, elems = elems[0], [float(elem) for elem in elems[1:]]
                # Skip header information, such as the top row in fastText
                if len(elems) > 1:
                    idx_to_token.append(token)
                    idx_to_vec.append(elems)
        idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec
        return idx_to_token, torch.tensor(idx_to_vec)

    def __getitem__(self, tokens):
        indices = [self.token_to_idx.get(token, self.unknown_idx) for token in tokens]
        vecs = self.idx_to_vec[torch.tensor(indices)]
        return vecs

    def __len__(self):
        return len(self.idx_to_token)
def try_all_gpus():
    """Return all available GPUs, or [cpu(),] if no GPU exists."""
    devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device("cpu")]
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip"
DATA_HUB["glove.6b.100d"] = (DATA_URL, "cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a")

embed_size, num_hiddens, devices = 100, 200, try_all_gpus()
net = DecomposableAttention(vocab, embed_size, num_hiddens)

# get pre-trained GloVe embeddings of size 100
glove_embedding = TokenEmbedding("glove.6b.100d")
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds);
Downloading ../data/glove.6B.100d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip...

Training

class Animator:
    """For plotting data in animation."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        # Incrementally plot multiple lines
        if legend is None:
            legend = []
        display.set_matplotlib_formats("svg")
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [
                self.axes,
            ]
        # Use a lambda function to capture arguments
        self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)


class Timer:
    """Record multiple running times."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer and record the time in a list."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time."""
        return np.array(self.times).cumsum().tolist()


class Accumulator:
    """For accumulating sums over `n` variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Set the axes for matplotlib."""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()
def accuracy(y_hat, y):
    """Compute the number of correct predictions."""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, axis=1)
    cmp_ = y_hat.type(y.dtype) == y
    return float(cmp_.type(y.dtype).sum())


def evaluate_accuracy_gpu(net, data_iter, device=None):
    """Compute the accuracy for a model on a dataset using a GPU."""
    if isinstance(net, torch.nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = Accumulator(2)
    for X, y in data_iter:
        if isinstance(X, list):
            # Required for BERT Fine-tuning
            X = [x.to(device) for x in X]
        else:
            X = X.to(device)
        y = y.to(device)
        metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]
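A tiny illustration (toy logits, not from the notebook) of the accuracy helper: it returns the count of correct predictions, not a fraction, which is why evaluate_accuracy_gpu divides by the number of labels.

# accuracy() returns a count: here 2 of the 3 toy predictions match the labels
y_hat_toy = torch.tensor([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.2, 0.3, 0.5]])
y_toy = torch.tensor([1, 0, 1])
print(accuracy(y_hat_toy, y_toy))  # 2.0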
def train_batch(net, X, y, loss, trainer, devices):
    if isinstance(X, list):
        # Required for BERT Fine-tuning
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = accuracy(pred, y)
    return train_loss_sum, train_acc_sum


def train(net, train_iter, test_iter, loss, trainer, num_epochs, devices=try_all_gpus()):
    timer, num_batches = Timer(), len(train_iter)
    animator = Animator(
        xlabel="epoch", xlim=[1, num_epochs], ylim=[0, 1], legend=["train loss", "train acc", "test acc"]
    )
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # Store training_loss, training_accuracy, num_examples, num_features
        metric = Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[2], metric[1] / metric[3], None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f"loss {metric[0] / metric[2]:.3f}, train acc " f"{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}")
    print(f"{metric[2] * num_epochs / timer.sum():.1f} examples/sec on " f"{str(devices)}")
lr, num_epochs = 0.001, 4
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
train(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.501, train acc 0.803, test acc 0.815
14828.6 examples/sec on [device(type='cuda', index=0)]
[Figure: training curves of train loss, train acc, and test acc vs. epoch]

Testing

def try_gpu(i=0):
    """Return gpu(i) if exists, otherwise return cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f"cuda:{i}")
    return torch.device("cpu")
def predict_snli(net, vocab, premise, hypothesis):
    net.eval()
    premise = torch.tensor(vocab[premise], device=try_gpu())
    hypothesis = torch.tensor(vocab[hypothesis], device=try_gpu())
    label = torch.argmax(net([premise.reshape((1, -1)), hypothesis.reshape((1, -1))]), dim=1)
    return "entailment" if label == 0 else "contradiction" if label == 1 else "neutral"
predict_snli(net, vocab, ["he", "is", "good", "."], ["he", "is", "bad", "."])
'contradiction'
predict_snli(net, vocab, ["he", "is", "very", "naughty", "."], ["he", "is", "bad", "."])
'contradiction'
predict_snli(net, vocab, ["he", "is", "awful", "."], ["he", "is", "bad", "."])
'entailment'
predict_snli(net, vocab, ["he", "is", "handsome", "."], ["he", "is", "bad", "."])
'contradiction'

Examples from training set

predict_snli(
    net,
    vocab,
    ["a", "person", "on", "a", "horse", "jumps", "over", "a", "log", "."],
    ["a", "person", "is", "outdoors", "on", "a", "horse", "."],
)
'entailment'
predict_snli(
    net,
    vocab,
    ["a", "person", "on", "a", "horse", "jumps", "over", "a", "log", "."],
    ["a", "person", "is", "at", "a", "diner", "ordering", "an", "omelette", "."],
)
'contradiction'
predict_snli(
    net,
    vocab,
    ["a", "person", "on", "a", "horse", "jumps", "over", "a", "log", "."],
    ["a", "person", "is", "training", "a", "horse", "for", "a", "competition", "."],
)
'neutral'