GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/15/entailment_attention_mlp_jax.ipynb
Kernel: Python 3 (ipykernel)

Open In Colab

Textual entailment classifier using an MLP plus attention

In textual entailment, the input is two sentences, a premise P and a hypothesis H, and the output is a label specifying whether P entails H, P contradicts H, or neither. (This task is also called "natural language inference".) We use attention to softly align the hypothesis to the premise and vice versa, compare each word with its aligned counterpart to estimate similarity, and pass the aggregated comparisons to an MLP classifier.

Based on sec 15.5 of http://d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-attention.html

import jax
import jax.numpy as jnp  # JAX NumPy

try:
    from flax import linen as nn  # The Linen API
except ModuleNotFoundError:
    %pip install -qq flax
    from flax import linen as nn  # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

try:
    import torch
except ModuleNotFoundError:
    %pip install -qq torch
    import torch
from torch.utils import data  # For data

import numpy as np  # Ordinary NumPy

try:
    import optax  # Optimizers
except ModuleNotFoundError:
    %pip install -qq optax
    import optax  # Optimizers

import collections
import re
import os
import requests
import zipfile
import tarfile
import hashlib
import time
import functools
from typing import Any, Callable, Sequence, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import math
from IPython import display

rng = jax.random.PRNGKey(0)

!mkdir figures  # for saving plots

ModuleDef = Any

Data

We use the SNLI (Stanford Natural Language Inference) dataset, described in sec 15.4 of http://d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-and-dataset.html.

# Required functions for downloading data
def download(name, cache_dir=os.path.join("..", "data")):
    """Download a file inserted into DATA_HUB, return the local filename."""
    assert name in DATA_HUB, f"{name} does not exist in {DATA_HUB}."
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Hit cache
    print(f"Downloading {fname} from {url}...")
    r = requests.get(url, stream=True, verify=True)
    with open(fname, "wb") as f:
        f.write(r.content)
    return fname


def download_extract(name, folder=None):
    """Download and extract a zip/tar file."""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        fp = zipfile.ZipFile(fname, "r")
    elif ext in (".tar", ".gz"):
        fp = tarfile.open(fname, "r")
    else:
        assert False, "Only zip/tar files can be extracted."
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
DATA_HUB = dict()
DATA_HUB["SNLI"] = (
    "https://nlp.stanford.edu/projects/snli/snli_1.0.zip",
    "9fcde07509c7e87ec61c640c1b2753d9041758e4",
)
data_dir = download_extract("SNLI")
Downloading ../data/snli_1.0.zip from https://nlp.stanford.edu/projects/snli/snli_1.0.zip...
def read_snli(data_dir, is_train):
    """Read the SNLI dataset into premises, hypotheses, and labels."""

    def extract_text(s):
        # Remove information that will not be used by us
        s = re.sub("\\(", "", s)
        s = re.sub("\\)", "", s)
        # Substitute two or more consecutive whitespace with space
        s = re.sub("\\s{2,}", " ", s)
        return s.strip()

    label_set = {"entailment": 0, "contradiction": 1, "neutral": 2}
    file_name = os.path.join(data_dir, "snli_1.0_train.txt" if is_train else "snli_1.0_test.txt")
    with open(file_name, "r") as f:
        rows = [row.split("\t") for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels

Show the first 3 training examples and their labels (“0”, “1”, and “2” correspond to “entailment”, “contradiction”, and “neutral”, respectively).

train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print("premise:", x0)
    print("hypothesis:", x1)
    print("label:", y)
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is training his horse for a competition .
label: 2
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is at a diner , ordering an omelette .
label: 1
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is outdoors , on a horse .
label: 0
test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
    print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]
def tokenize(lines, token="word"):
    """Split text lines into word or character tokens."""
    if token == "word":
        return [line.split() for line in lines]
    elif token == "char":
        return [list(line) for line in lines]
    else:
        print("ERROR: unknown token type: " + token)


class Vocab:
    """Vocabulary for text."""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        # Sort according to frequencies
        counter = count_corpus(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # The index for the unknown token is 0
        self.unk, uniq_tokens = 0, ["<unk>"] + reserved_tokens
        uniq_tokens += [token for token, freq in self.token_freqs if freq >= min_freq and token not in uniq_tokens]
        self.idx_to_token, self.token_to_idx = [], dict()
        for token in uniq_tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]


def count_corpus(tokens):
    """Count token frequencies."""
    # Here `tokens` is a 1D list or 2D list
    if len(tokens) == 0 or isinstance(tokens[0], list):
        # Flatten a list of token lists into a list of tokens
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)
class SNLIDataset(torch.utils.data.Dataset):
    """A customized dataset to load the SNLI dataset."""

    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = tokenize(dataset[0])
        all_hypothesis_tokens = tokenize(dataset[1])
        if vocab is None:
            self.vocab = Vocab(all_premise_tokens + all_hypothesis_tokens, min_freq=5, reserved_tokens=["<pad>"])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print("read " + str(len(self.premises)) + " examples")

    def _pad(self, lines):
        return torch.tensor([truncate_pad(self.vocab[line], self.num_steps, self.vocab["<pad>"]) for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)
def load_data_snli(batch_size, num_steps=50):
    """Download the SNLI dataset and return data iterators and vocabulary."""
    num_workers = 2
    data_dir = download_extract("SNLI")
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter, train_set.vocab


def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))  # Pad
train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples
18678
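To make the preprocessing concrete, here is a small sketch (not part of the original pipeline) that encodes a single sentence with the `vocab` and `truncate_pad` defined above, the same way `SNLIDataset._pad` does:

# Minimal sketch: encode one of the hypotheses shown above into a fixed-length id sequence.
tokens = tokenize(["a person is outdoors , on a horse ."])[0]
ids = vocab[tokens]                             # map tokens to integer ids (unknown words map to vocab.unk)
padded = truncate_pad(ids, 50, vocab["<pad>"])  # truncate or pad to num_steps=50
print(len(tokens), len(padded))                 # 9 50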

Model

The model is described in the book. Below we just give the code.

Attending

We define attention weights $e_{ij} = f(a_i)^T f(b_j)$, where $a_i \in R^E$ is the embedding of the $i$'th token of the premise, $b_j \in R^E$ is the embedding of the $j$'th token of the hypothesis, and $f: R^E \rightarrow R^H$ is an MLP that maps from the embedding space to another hidden space.
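Before defining the Flax modules, here is a tiny shape sketch (made-up sizes, with $f$ replaced by the identity purely for illustration) of the attention and soft-alignment computation that `Attend` implements below:

# Toy attention example (illustration only; f is taken to be the identity here).
key_a, key_b = jax.random.split(jax.random.PRNGKey(1))
A_toy = jax.random.normal(key_a, (1, 3, 4))  # batch=1, m=3 premise tokens, E=4
B_toy = jax.random.normal(key_b, (1, 2, 4))  # batch=1, n=2 hypothesis tokens, E=4
e_toy = A_toy @ B_toy.transpose((0, 2, 1))   # (1, 3, 2) attention scores e_ij
beta_toy = nn.softmax(e_toy, axis=-1) @ B_toy                        # (1, 3, 4): B softly aligned to each word of A
alpha_toy = nn.softmax(e_toy.transpose((0, 2, 1)), axis=-1) @ A_toy  # (1, 2, 4): A softly aligned to each word of B
print(e_toy.shape, beta_toy.shape, alpha_toy.shape)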

class mlp(nn.Module):
    num_hiddens: int
    flatten: bool

    @nn.compact
    def __call__(self, X, train=True):
        X = nn.Dropout(rate=0.2, deterministic=not train)(X)
        X = nn.Dense(self.num_hiddens)(X)
        X = nn.relu(X)
        if self.flatten:
            X = X.reshape((X.shape[0], -1))  # flatten
        X = nn.Dropout(rate=0.2, deterministic=not train)(X)
        X = nn.Dense(self.num_hiddens)(X)
        X = nn.relu(X)
        if self.flatten:
            X = X.reshape((X.shape[0], -1))  # flatten
        return X
class Attend(nn.Module):
    num_hiddens: int

    def setup(self):
        self.f = mlp(self.num_hiddens, False)

    @nn.compact
    def __call__(self, A, B, train=True):
        # Shape of `A`/`B`: (`batch_size`, no. of words in sequence A/B, `embed_size`)
        # Shape of `f_A`/`f_B`: (`batch_size`, no. of words in sequence A/B, `num_hiddens`)
        f_A = self.f(A, train)
        f_B = self.f(B, train)
        # Shape of `e`: (`batch_size`, no. of words in sequence A, no. of words in sequence B)
        e = f_A @ (f_B.transpose((0, 2, 1)))
        # Shape of `beta`: (`batch_size`, no. of words in sequence A, `embed_size`), where
        # sequence B is softly aligned with each word (axis 1 of `beta`) in sequence A
        beta = nn.softmax(e, axis=-1) @ B
        # Shape of `alpha`: (`batch_size`, no. of words in sequence B, `embed_size`), where
        # sequence A is softly aligned with each word (axis 1 of `alpha`) in sequence B
        alpha = nn.softmax(e.transpose((0, 2, 1)), axis=-1) @ A
        return beta, alpha

Comparing

We concatenate word $i$ in A, $a_i$, with its "soft counterpart" in B, $\beta_i$, and vice versa, and then pass this through another MLP $g$ to get a "comparison vector" for each input location:

$$\begin{align} v_{A,i} &= g([a_i, \beta_i]), \; i=1,\ldots,m \\ v_{B,j} &= g([b_j, \alpha_j]), \; j=1,\ldots,n \end{align}$$
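Continuing the toy example above (again skipping the MLP $g$, just to check shapes), the comparison step concatenates each embedding with its aligned counterpart along the feature axis, giving $2E$ features per word:

# Toy comparison step (illustration only; g is omitted).
VA_toy = jnp.concatenate((A_toy, beta_toy), axis=2)   # (1, 3, 8), i.e. [a_i, beta_i]
VB_toy = jnp.concatenate((B_toy, alpha_toy), axis=2)  # (1, 2, 8), i.e. [b_j, alpha_j]
print(VA_toy.shape, VB_toy.shape)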

class Compare(nn.Module):
    num_hiddens: int

    def setup(self):
        self.g = mlp(self.num_hiddens, False)

    @nn.compact
    def __call__(self, A, B, beta, alpha, train=True):
        V_A = self.g(jnp.concatenate((A, beta), axis=2), train)
        V_B = self.g(jnp.concatenate((B, alpha), axis=2), train)
        return V_A, V_B

Aggregating

We sum-pool the "comparison vectors" for each input sentence, and then pass the pair of pooled vectors to yet another MLP $h$ to generate the final classification:

$$\begin{align} v_A &= \sum_{i=1}^m v_{A,i} \\ v_B &= \sum_{j=1}^n v_{B,j} \\ \hat{y} &= h([v_A, v_B]) \end{align}$$
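In the same toy setting, the aggregation step amounts to the following (shapes only; the real `Aggregate` module below also applies the MLP $h$ and a final dense layer):

# Toy aggregation: sum-pool over the token axis, then concatenate the two pooled vectors.
vA_toy = VA_toy.sum(axis=1)  # (1, 8)
vB_toy = VB_toy.sum(axis=1)  # (1, 8)
print(jnp.concatenate((vA_toy, vB_toy), axis=1).shape)  # (1, 16): input to the classifier MLP h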
class Aggregate(nn.Module):
    num_hiddens: int
    num_outputs: int

    @nn.compact
    def __call__(self, V_A, V_B, train=True):
        # Sum up both sets of comparison vectors
        V_A = V_A.sum(axis=1)
        V_B = V_B.sum(axis=1)
        # Feed the concatenation of both summarization results into an MLP
        Y_hat = nn.Dense(self.num_outputs)(mlp(self.num_hiddens, True)(jnp.concatenate((V_A, V_B), axis=1), train))
        return Y_hat

Putting it all together

We use a pre-trained embedding of size $E=100$. The $f$ (attend) function maps from $E=100$ to $H=200$ hidden units. The $g$ (compare) function maps from $2E=200$ to $H=200$. The $h$ (aggregate) function maps from $2H=400$ to 3 outputs.

class DecomposableAttention(nn.Module):
    vocab: Any
    embed_size: int
    num_hiddens: int
    embed_init: Callable

    def setup(self):
        self.embedding = nn.Embed(len(self.vocab), self.embed_size, embedding_init=self.embed_init)
        self.attend = Attend(self.num_hiddens)
        self.compare = Compare(self.num_hiddens)
        # There are 3 possible outputs: entailment, contradiction, and neutral
        self.aggregate = Aggregate(self.num_hiddens, 3)

    def __call__(self, X, train=True):
        premises, hypotheses = X
        A = self.embedding(premises)
        B = self.embedding(hypotheses)
        beta, alpha = self.attend(A, B, train)
        V_A, V_B = self.compare(A, B, beta, alpha, train)
        Y_hat = self.aggregate(V_A, V_B, train)
        return Y_hat
class TokenEmbedding:
    """Token Embedding."""

    def __init__(self, embedding_name):
        self.idx_to_token, self.idx_to_vec = self._load_embedding(embedding_name)
        self.unknown_idx = 0
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def _load_embedding(self, embedding_name):
        idx_to_token, idx_to_vec = ["<unk>"], []
        data_dir = download_extract(embedding_name)
        # GloVe website: https://nlp.stanford.edu/projects/glove/
        # fastText website: https://fasttext.cc/
        with open(os.path.join(data_dir, "vec.txt"), "r") as f:
            for line in f:
                elems = line.rstrip().split(" ")
                token, elems = elems[0], [float(elem) for elem in elems[1:]]
                # Skip header information, such as the top row in fastText
                if len(elems) > 1:
                    idx_to_token.append(token)
                    idx_to_vec.append(elems)
        idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec
        return idx_to_token, jnp.array(idx_to_vec)

    def __getitem__(self, tokens):
        indices = [self.token_to_idx.get(token, self.unknown_idx) for token in tokens]
        vecs = self.idx_to_vec[jnp.array(indices)]
        return vecs

    def __len__(self):
        return len(self.idx_to_token)
def embedding_init(rng, shape, dtype):
    # Return the pre-trained GloVe embeddings of size 100.
    # `embeds` is built in the next cell via:
    #   glove_embedding = TokenEmbedding("glove.6b.100d")
    #   embeds = glove_embedding[vocab.idx_to_token]
    return embeds
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip"
DATA_HUB["glove.6b.100d"] = (DATA_URL, "cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a")

glove_embedding = TokenEmbedding("glove.6b.100d")
embeds = glove_embedding[vocab.idx_to_token]

embed_size, num_hiddens = 100, 200
AttentionNetwork = functools.partial(DecomposableAttention, vocab, embed_size, num_hiddens, embedding_init)
Downloading ../data/glove.6B.100d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip...
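As a quick sanity check on the dimensions listed above, we can initialize the untrained network on a dummy batch of token ids and confirm it produces 3 logits per example. This cell is just an illustration; the training code below does the real initialization.

# Sanity check (illustration only): forward a dummy batch of 4 padded (premise, hypothesis) pairs.
check_model = AttentionNetwork()
dummy_X = jnp.ones((2, 4, 50), dtype=jnp.int32)  # (premise/hypothesis, batch=4, num_steps=50)
check_vars = check_model.init(jax.random.PRNGKey(0), dummy_X, False)  # train=False: dropout disabled
dummy_logits = check_model.apply({"params": check_vars["params"]}, dummy_X, False)
print(dummy_logits.shape)  # (4, 3): one logit per class (entailment, contradiction, neutral)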

Training

class Animator:
    """For plotting data in animation."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        # Incrementally plot multiple lines
        if legend is None:
            legend = []
        display.set_matplotlib_formats("svg")
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes]
        # Use a lambda function to capture arguments
        self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)


class Timer:
    """Record multiple running times."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer and record the time in a list."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time."""
        return np.array(self.times).cumsum().tolist()


class Accumulator:
    """For accumulating sums over `n` variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Set the axes for matplotlib."""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()
def cross_entropy_loss(logits, labels) -> float:
    one_hot = jax.nn.one_hot(labels, num_classes=3)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot)
    return loss.sum()


def compute_metrics(logits, labels):
    """Computes metrics and returns them."""
    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.sum(jnp.argmax(logits, -1) == labels)
    metrics = {
        "loss": loss,
        "accuracy": accuracy,
    }
    return metrics
def get_initial_params(model, rng):
    X = jnp.ones((2, 128, 50), dtype=jnp.int32)
    variables = model.init(rng, X, False)  # use the rng passed in by the caller
    return variables["params"]


def get_train_state(rng, lr) -> train_state.TrainState:
    """Returns a train state."""
    model = AttentionNetwork()
    params = get_initial_params(model, rng)
    tx = optax.adam(lr)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
def evaluate_accuracy(state: train_state.TrainState, data_iter):
    """Compute the accuracy for a model on a dataset using a GPU."""
    # No. of correct predictions, no. of predictions
    metric = Accumulator(2)
    for X, y in data_iter:
        X = [jnp.array(x) for x in X]
        y = jnp.array(y)
        logits = state.apply_fn({"params": state.params}, X, False)
        accuracy = jnp.sum(jnp.argmax(logits, -1) == y)
        metric.add(accuracy, y.size)
    return metric[0] / metric[1]
@jax.jit
def train_step(state: train_state.TrainState, dropout_rng, features, labels):
    """Trains one step."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, features, True, rngs={"dropout": dropout_rng})
        loss = cross_entropy_loss(logits, labels)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels)
    return state, metrics
def train(train_iter, test_iter, num_epochs, lr):
    key = jax.random.PRNGKey(42)
    state = get_train_state(key, lr)
    timer, num_batches = Timer(), len(train_iter)
    animator = Animator(
        xlabel="epoch", xlim=[1, num_epochs], ylim=[0, 1], legend=["train loss", "train acc", "test acc"]
    )
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples, no. of predictions
        metric = Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            features = [jnp.array(x) for x in features]
            labels = jnp.array(labels)
            timer.start()
            state, metrics = train_step(state, key, features, labels)
            l = metrics["loss"]
            acc = metrics["accuracy"]
            metric.add(l, acc, labels.shape[0], labels.size)
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[2], metric[1] / metric[3], None))
        # Calculate test accuracy
        test_acc = evaluate_accuracy(state, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    device = jax.default_backend()
    print(f"loss {metric[0] / metric[2]:.3f}, train acc " f"{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}")
    print(f"{metric[2] * num_epochs / timer.sum():.1f} examples/sec on " f"{str(device)}")
    return state
lr, num_epochs = 0.001, 4
state = train(train_iter, test_iter, num_epochs, lr)
loss 0.500, train acc 0.802, test acc 0.811
6486.3 examples/sec on gpu
[Figure: training loss, training accuracy, and test accuracy plotted against epoch.]

Testing

def predict_snli(state, vocab, premise, hypothesis):
    model = AttentionNetwork()
    premise = jnp.array(vocab[premise])
    hypothesis = jnp.array(vocab[hypothesis])
    features = [premise.reshape((1, -1)), hypothesis.reshape((1, -1))]
    logits = state.apply_fn({"params": state.params}, features, False)
    label = jnp.argmax(logits)
    return "entailment" if label == 0 else "contradiction" if label == 1 else "neutral"
predict_snli(state, vocab, ["he", "is", "good", "."], ["he", "is", "bad", "."])
'contradiction'
predict_snli(state, vocab, ["he", "is", "very", "naughty", "."], ["he", "is", "bad", "."])
'entailment'
predict_snli(state, vocab, ["he", "is", "awful", "."], ["he", "is", "bad", "."])
'entailment'
predict_snli(state, vocab, ["he", "is", "handsome", "."], ["he", "is", "bad", "."])
'contradiction'

Examples from training set

predict_snli(
    state,
    vocab,
    ["a", "person", "on", "a", "horse", "jumps", "over", "a", "log", "."],
    ["a", "person", "is", "outdoors", "on", "a", "horse", "."],
)
'entailment'
predict_snli(
    state,
    vocab,
    ["a", "person", "on", "a", "horse", "jumps", "over", "a", "log", "."],
    ["a", "person", "is", "at", "a", "diner", "ordering", "an", "omelette", "."],
)
'contradiction'
predict_snli(
    state,
    vocab,
    ["a", "person", "on", "a", "horse", "jumps", "over", "a", "log", "."],
    ["a", "person", "is", "training", "a", "horse", "for", "a", "competition", "."],
)
'neutral'