Path: blob/master/notebooks/book1/15/nmt_attention_torch.ipynb
1192 views
The JAX implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/nmt_attention_jax.ipynb
Neural machine translation using RNN with an attention-based decoder.
We show how to implement NMT using an RNN with an attention-based decoder. The code is based on sec 10.4 of http://d2l.ai/chapter_attention-mechanisms/bahdanau-attention.html
import numpy as np import matplotlib.pyplot as plt import pandas as pd import math from IPython import display try: import torch except ModuleNotFoundError: %pip install -qq torch import torch from torch import nn from torch.nn import functional as F from torch.utils import data import collections import re import random import os import requests import zipfile import hashlib import time np.random.seed(seed=1) torch.manual_seed(1) !mkdir figures # for saving plots
Decoder
We define an abstract class for decoders that stores attention weights, so we can visualize them.
class Encoder(nn.Module):
    """Abstract encoder half of the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, X, *args):
        """Encode `X`; concrete encoders must override this."""
        raise NotImplementedError
class Decoder(nn.Module):
    """Abstract decoder half of the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        """Build the initial decoder state from the encoder outputs."""
        raise NotImplementedError

    def forward(self, X, state):
        """Decode `X` given `state`; concrete decoders must override this."""
        raise NotImplementedError
class EncoderDecoder(nn.Module):
    """Glue module that chains an encoder and a decoder."""

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        """Encode `enc_X`, derive the decoder state, then decode `dec_X`.

        Extra positional arguments are forwarded unchanged to both the
        encoder call and `decoder.init_state`.
        """
        encoded = self.encoder(enc_X, *args)
        initial_state = self.decoder.init_state(encoded, *args)
        return self.decoder(dec_X, initial_state)
class AttentionDecoder(Decoder):
    """Decoder interface that additionally exposes its attention weights."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def attention_weights(self):
        """Attention weights of the last forward pass; subclasses override."""
        raise NotImplementedError
Now we define the RNN+attention decoder. The state of the decoder is initialized with i) the encoder final-layer hidden states at all the time steps (as keys and values of the attention); ii) the encoder all-layer hidden state at the final time step (to initialize the hidden state of the decoder); and iii) the encoder valid length (to exclude the padding tokens in attention pooling). At each decoding time step, the decoder final-layer hidden state at the previous time step is used as the query of the attention. As a result, both the attention output and the input embedding are concatenated as the input of the RNN decoder.
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """GRU decoder that attends over the encoder outputs at every time step."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        # Submodules are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        self.attention = AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Pack encoder outputs, encoder hidden state and valid lengths.

        The encoder outputs arrive as (`num_steps`, `batch_size`,
        `num_hiddens`) and are stored batch-first.
        """
        outputs, hidden_state = enc_outputs
        batch_first_outputs = outputs.permute(1, 0, 2)
        return (batch_first_outputs, hidden_state, enc_valid_lens)

    def forward(self, X, state):
        """Decode the token ids `X` one time step at a time.

        Returns the logits as (`batch_size`, `num_steps`, `vocab_size`)
        together with the updated state list.
        """
        # `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`)
        enc_outputs, hidden_state, enc_valid_lens = state
        # Time-major embeddings: (`num_steps`, `batch_size`, `embed_size`)
        embedded = self.embedding(X).permute(1, 0, 2)
        step_outputs = []
        self._attention_weights = []
        for step_input in embedded:
            # The last layer's hidden state is the query:
            # (`batch_size`, 1, `num_hiddens`)
            query = hidden_state[-1].unsqueeze(1)
            # `context`: (`batch_size`, 1, `num_hiddens`)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # Join context and embedding along the feature axis
            rnn_input = torch.cat((context, step_input.unsqueeze(1)), dim=-1)
            # The GRU expects (1, `batch_size`, `embed_size` + `num_hiddens`)
            out, hidden_state = self.rnn(rnn_input.permute(1, 0, 2), hidden_state)
            step_outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # (`num_steps`, `batch_size`, `vocab_size`) after the dense layer
        logits = self.dense(torch.cat(step_outputs, dim=0))
        return logits.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        """Per-step attention weights recorded by the last `forward` call."""
        return self._attention_weights


class AdditiveAttention(nn.Module):
    """Additive attention scored by a one-hidden-layer MLP."""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        """Return the attention-weighted combination of `values`.

        The weights are also stored on `self.attention_weights`.
        """
        projected_queries = self.W_q(queries)
        projected_keys = self.W_k(keys)
        # Broadcast-add (`batch_size`, no. of queries, 1, `num_hiddens`) and
        # (`batch_size`, 1, no. of key-value pairs, `num_hiddens`)
        features = torch.tanh(projected_queries.unsqueeze(2) + projected_keys.unsqueeze(1))
        # `w_v` has a single output, so drop that trailing axis:
        # (`batch_size`, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # `values`: (`batch_size`, no. of key-value pairs, value dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
def masked_softmax(X, valid_lens):
    """Perform softmax over the last axis, ignoring masked positions.

    Args:
        X: 3D tensor of scores.
        valid_lens: None (plain softmax), a 1D tensor with one length per
            batch element, or a 2D tensor with one length per row of `X`.

    Returns:
        A new tensor with the shape of `X`. `X` itself is left untouched.
    """
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # `reshape` may return a view of `X`, and `sequence_mask` writes in
    # place, so mask a copy; otherwise the caller's scores get overwritten.
    flat = X.reshape(-1, shape[-1]).clone()
    # On the last axis, replace masked elements with a very large negative
    # value, whose exponentiation outputs 0
    flat = sequence_mask(flat, valid_lens, value=-1e6)
    return nn.functional.softmax(flat.reshape(shape), dim=-1)


def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences.

    Entries of the 2D tensor `X` at column index >= `valid_len[row]` are
    overwritten with `value`. The write happens in place and `X` is
    returned.
    """
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
Example. Minibatch of 4 sequences of length 7.
class Seq2SeqEncoder(Encoder):
    """Embedding followed by a multi-layer GRU."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        # Token embedding, then the recurrent stack
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        """Encode the token ids `X`.

        Returns `(output, state)` with `output` of shape (`num_steps`,
        `batch_size`, `num_hiddens`) and `state` of shape (`num_layers`,
        `batch_size`, `num_hiddens`).
        """
        # Embed to (`batch_size`, `num_steps`, `embed_size`), then move the
        # time axis to the front as the GRU expects
        time_major = self.embedding(X).permute(1, 0, 2)
        # No initial state is passed, so the GRU starts from zeros
        output, state = self.rnn(time_major)
        return output, state
# Smoke test: push a dummy batch through a small encoder/decoder pair and
# inspect the resulting shapes.
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (`batch_size`, `num_steps`)
# No valid lengths are supplied (None), so attention is not masked here.
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
# Trailing expression: shown as the notebook cell's output.
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
Training
class Animator:
    """Redraw a set of line plots each time new points are added."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        legend = [] if legend is None else legend
        display.set_matplotlib_formats("svg")
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # A single subplot comes back bare; wrap it so indexing works
            self.axes = [self.axes]

        def _configure():
            # Closure over the constructor arguments; applied on every redraw
            return set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)

        self.config_axes = _configure
        self.X = None
        self.Y = None
        self.fmts = fmts

    def add(self, x, y):
        """Append one point per curve and redraw the first axes."""
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for curve, (x_val, y_val) in enumerate(zip(x, y)):
            if x_val is None or y_val is None:
                continue
            self.X[curve].append(x_val)
            self.Y[curve].append(y_val)
        first_axes = self.axes[0]
        first_axes.cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            first_axes.plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)


class Timer:
    """Stopwatch that keeps a list of recorded durations."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Remember the current wall-clock time as the start."""
        self.tik = time.time()

    def stop(self):
        """Record the time since `start` and return it."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return self.times[-1]

    def avg(self):
        """Return the mean of the recorded durations."""
        total = sum(self.times)
        return total / len(self.times)

    def sum(self):
        """Return the total of the recorded durations."""
        return sum(self.times)

    def cumsum(self):
        """Return the running total of the recorded durations as a list."""
        return np.array(self.times).cumsum().tolist()


class Accumulator:
    """Keep running float sums for `n` quantities."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add each argument to the running sum in the same position."""
        self.data = [total + float(value) for total, value in zip(self.data, args)]

    def reset(self):
        """Set every running sum back to zero."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Set the axes for matplotlib.

    Applies labels, scales and limits to `axes`, adds a legend when
    `legend` is non-empty, and turns the grid on.
    """
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()


def grad_clipping(net, theta):
    """Clip the gradient so that its global L2 norm is at most `theta`.

    Args:
        net: an `nn.Module`, or any object exposing a `params` iterable.
        theta: maximum allowed global gradient norm.

    Parameters whose `.grad` is None (e.g. not reached by the last backward
    pass) are skipped; if no parameter has a gradient this is a no-op.
    """
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    # A trainable parameter may still have no gradient; squaring None would
    # raise, so only parameters that actually have one are considered.
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return
    norm = torch.sqrt(sum(torch.sum((g**2)) for g in grads))
    if norm > theta:
        for g in grads:
            g[:] *= theta / norm


def try_gpu(i=0):
    """Return gpu(i) if exists, otherwise return cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f"cuda:{i}")
    return torch.device("cpu")
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy that ignores positions beyond `valid_len`."""

    # `pred` shape: (`batch_size`, `num_steps`, `vocab_size`)
    # `label` shape: (`batch_size`, `num_steps`)
    # `valid_len` shape: (`batch_size`,)
    def forward(self, pred, label, valid_len):
        """Return one loss value per sequence (mean over the step axis)."""
        # 1 for real tokens, 0 for positions past the valid length
        mask = sequence_mask(torch.ones_like(label), valid_len)
        # Per-token losses are needed so they can be masked below
        self.reduction = "none"
        # The parent loss expects the class axis in position 1
        per_token_loss = super().forward(pred.permute(0, 2, 1), label)
        return (per_token_loss * mask).mean(dim=1)
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence.

    Args:
        net: encoder-decoder model called as `net(X, dec_input, X_valid_len)`.
        data_iter: iterable of `(X, X_valid_len, Y, Y_valid_len)` batches.
        lr: learning rate for Adam.
        num_epochs: number of passes over `data_iter`.
        tgt_vocab: target vocabulary; only `tgt_vocab["<bos>"]` is read.
        device: device the model and batches are moved to.

    The loss per token is plotted every 10 epochs and printed at the end.
    """

    def xavier_init_weights(m):
        # Xavier-initialise Linear weights and every GRU weight matrix
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = Animator(xlabel="epoch", ylabel="loss", xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = Timer()
        metric = Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            # Gradients accumulate across `backward` calls, so they must be
            # cleared before each batch; otherwise every step would use the
            # sum of all previous batches' gradients.
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab["<bos>"]] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f"loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} " f"tokens/sec on {str(device)}")
Data downloading and preprocessing
# Required functions for downloading data
def download(name, cache_dir=os.path.join("..", "data")):
    """Download a file inserted into DATA_HUB, return the local filename.

    Args:
        name: key into `DATA_HUB`, which maps to `(url, sha1_hash)`.
        cache_dir: directory where the file is stored; created if missing.

    Returns:
        Path of the local file. An existing file whose SHA-1 matches the
        registered hash is reused without downloading.

    Raises:
        AssertionError: if `name` is not registered in `DATA_HUB`.
        requests.HTTPError: if the server answers with an error status.
    """
    assert name in DATA_HUB, f"{name} does not exist in {DATA_HUB}."
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            while True:
                chunk = f.read(1048576)
                if not chunk:
                    break
                sha1.update(chunk)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Hit cache
    print(f"Downloading {fname} from {url}...")
    r = requests.get(url, stream=True, verify=True)
    # Fail here rather than saving an HTTP error page as if it were the file
    r.raise_for_status()
    with open(fname, "wb") as f:
        # The request is streamed, so write it in chunks instead of loading
        # the whole body into memory first
        for chunk in r.iter_content(chunk_size=1048576):
            f.write(chunk)
    return fname


def download_extract(name, folder=None):
    """Download and extract a zip/tar file.

    Returns `base_dir/folder` when `folder` is given, otherwise the archive
    path with its last extension stripped.

    Raises:
        AssertionError: if the file extension is not .zip, .tar or .gz.
    """
    # Imported here because the tar branch needs it and it may not be
    # imported at file level.
    import tarfile

    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        opener = zipfile.ZipFile
    elif ext in (".tar", ".gz"):
        opener = tarfile.open
    else:
        # Raised explicitly so the check also holds when asserts are
        # disabled with -O
        raise AssertionError("Only zip/tar files can be extracted.")
    # The context manager closes the archive handle after extraction
    with opener(fname, "r") as fp:
        fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
def read_data_nmt():
    """Load the English-French dataset.

    Downloads and extracts the "fra-eng" archive if needed and returns the
    full contents of `fra.txt` as a string.
    """
    data_dir = download_extract("fra-eng")
    # The corpus contains accented characters; read it as UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(os.path.join(data_dir, "fra.txt"), "r", encoding="utf-8") as f:
        return f.read()


def preprocess_nmt(text):
    """Preprocess the English-French dataset.

    Replaces non-breaking spaces with plain spaces, lowercases the text and
    inserts a space before any of `,.!?` that does not already follow one.
    """
    punctuation = set(",.!?")

    def no_space(char, prev_char):
        return char in punctuation and prev_char != " "

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace("\u202f", " ").replace("\xa0", " ").lower()
    # Insert space between words and punctuation marks
    out = [" " + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text)]
    return "".join(out)


def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset.

    Args:
        text: tab-separated "source<TAB>target" pairs, one per line.
        num_examples: if truthy, only the first `num_examples` lines are
            considered; None (or 0) means all lines.

    Returns:
        `(source, target)`: two parallel lists of token lists. Lines that do
        not split into exactly two tab-separated parts are skipped.
    """
    source, target = [], []
    for i, line in enumerate(text.split("\n")):
        # `i` is zero-based, so stop once `num_examples` lines were read
        # (`>` would let one extra line through).
        if num_examples and i >= num_examples:
            break
        parts = line.split("\t")
        if len(parts) == 2:
            source.append(parts[0].split(" "))
            target.append(parts[1].split(" "))
    return source, target
class Vocab:
    """Two-way mapping between tokens and integer indices."""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else reserved_tokens
        # Most frequent tokens first
        frequencies = count_corpus(tokens)
        self.token_freqs = sorted(frequencies.items(), key=lambda pair: pair[1], reverse=True)
        # Index 0 is reserved for the unknown token
        self.unk = 0
        vocabulary = ["<unk>"] + reserved_tokens
        kept = [
            tok
            for tok, count in self.token_freqs
            if count >= min_freq and tok not in vocabulary
        ]
        vocabulary.extend(kept)
        self.idx_to_token = []
        self.token_to_idx = {}
        for tok in vocabulary:
            self.token_to_idx[tok] = len(self.idx_to_token)
            self.idx_to_token.append(tok)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token, or a list/tuple of tokens, to indices."""
        if isinstance(tokens, (list, tuple)):
            return [self.__getitem__(tok) for tok in tokens]
        return self.token_to_idx.get(tokens, self.unk)

    def to_tokens(self, indices):
        """Map an index, or a list/tuple of indices, back to tokens."""
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[idx] for idx in indices]
        return self.idx_to_token[indices]


def count_corpus(tokens):
    """Return a `Counter` of token frequencies.

    `tokens` is either a flat list of tokens or a list of token lists.
    """
    is_nested = len(tokens) == 0 or isinstance(tokens[0], list)
    if is_nested:
        # Collapse the list of lines into one flat list of tokens
        tokens = [tok for sentence in tokens for tok in sentence]
    return collections.Counter(tokens)
def reduce_sum(x, *args, **kwargs):
    """Forward to `x.sum`."""
    return x.sum(*args, **kwargs)


def astype(x, *args, **kwargs):
    """Forward to `x.type`."""
    return x.type(*args, **kwargs)


def build_array_nmt(lines, vocab, num_steps):
    """Turn tokenized sentences into a padded index tensor plus lengths.

    Each sentence is mapped to indices, terminated with "<eos>" and then
    truncated or padded to `num_steps`. Returns `(array, valid_len)` where
    `valid_len` counts the non-"<pad>" entries in each row.
    """
    indexed = [vocab[sentence] for sentence in lines]
    terminated = [ids + [vocab["<eos>"]] for ids in indexed]
    rows = [truncate_pad(ids, num_steps, vocab["<pad>"]) for ids in terminated]
    array = torch.tensor(rows)
    not_padding = astype(array != vocab["<pad>"], torch.int32)
    valid_len = reduce_sum(not_padding, 1)
    return array, valid_len
def load_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors in a `DataLoader`, shuffling only when `is_train`."""
    tensor_dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(tensor_dataset, batch_size, shuffle=is_train)


def truncate_pad(line, num_steps, padding_token):
    """Return `line` cut or right-padded to exactly `num_steps` items."""
    missing = num_steps - len(line)
    if missing < 0:
        return line[:num_steps]  # Too long: keep the first `num_steps`
    return line + [padding_token] * missing


def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Build the translation data iterator and both vocabularies.

    Returns `(data_iter, src_vocab, tgt_vocab)`.
    """
    raw_text = read_data_nmt()
    source, target = tokenize_nmt(preprocess_nmt(raw_text), num_examples)
    src_vocab = Vocab(source, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    tgt_vocab = Vocab(target, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_iter = load_array((src_array, src_valid_len, tgt_array, tgt_valid_len), batch_size)
    return data_iter, src_vocab, tgt_vocab
# Base URL of the dataset bucket
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/"
# Registry of downloadable datasets: name -> (url, sha1 of the file)
DATA_HUB = {
    "fra-eng": (DATA_URL + "fra-eng.zip", "94646ad1522d915e7b0f9296181140edcf86a4f5"),
}
Learning curve
# Hyperparameters for the model, the data pipeline and the optimiser.
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, try_gpu()
# Build the data iterator and both vocabularies.
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)
# Encoder and decoder share the embedding size, hidden size and depth.
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
# `net`, the vocabularies, `num_steps` and `device` are reused by the
# evaluation cells below.
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
Evaluation
def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
    """Greedily translate one sentence.

    Returns `(translation, attention_weight_seq)`. The translation is the
    space-joined target tokens generated before "<eos>" or before
    `num_steps` decoding steps are used up. The second item holds the
    decoder's attention weights per step when `save_attention_weights` is
    true, and is empty otherwise.
    """
    # Inference mode for the whole network
    net.eval()
    words = src_sentence.lower().split(" ")
    src_tokens = src_vocab[words] + [src_vocab["<eos>"]]
    # The valid length is taken before truncation/padding
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab["<pad>"])
    # Add the batch axis
    enc_X = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    # Decoding starts from a single "<bos>" token, again with a batch axis
    dec_X = torch.tensor([tgt_vocab["<bos>"]], dtype=torch.long, device=device).unsqueeze(0)
    output_seq = []
    attention_weight_seq = []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        # The most likely token becomes the next decoder input
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Stop as soon as the end-of-sequence token is produced
        if pred == tgt_vocab["<eos>"]:
            break
        output_seq.append(pred)
    translation = " ".join(tgt_vocab.to_tokens(output_seq))
    return translation, attention_weight_seq
def bleu(pred_seq, label_seq, k):
    """Compute the BLEU.

    Args:
        pred_seq: predicted sentence, tokens separated by single spaces.
        label_seq: reference sentence, tokens separated by single spaces.
        k: longest n-gram order to include.

    Returns:
        The brevity penalty multiplied by `p_n ** (0.5 ** n)` for
        n = 1..k, where `p_n` is the clipped n-gram precision. When the
        prediction is too short to contain any n-gram of some order n <= k,
        that precision is taken as 0, so the score is 0.
    """
    pred_tokens, label_tokens = pred_seq.split(" "), label_seq.split(" ")
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        # N-grams are keyed by token tuples: joining the tokens with "" would
        # make e.g. ("ab", "c") and ("a", "bc") look identical.
        for i in range(len_label - n + 1):
            label_subs[tuple(label_tokens[i : i + n])] += 1
        for i in range(len_pred - n + 1):
            ngram = tuple(pred_tokens[i : i + n])
            if label_subs[ngram] > 0:
                num_matches += 1
                label_subs[ngram] -= 1
        # A prediction shorter than n has no n-grams; clamp the denominator
        # so this yields precision 0 instead of dividing by zero.
        num_pred_ngrams = max(1, len_pred - n + 1)
        score *= math.pow(num_matches / num_pred_ngrams, math.pow(0.5, n))
    return score
# A few English sentences and their reference French translations.
engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
for eng, fra in zip(engs, fras):
    # Translate with the trained `net`, also collecting attention weights.
    translation, dec_attention_weight_seq = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    # Score against the reference using n-grams up to order 2.
    print(f"{eng} => {translation}, ", f"bleu {bleu(translation, fra, k=2):.3f}")
Visualize attention weights
We apply the model to an input of length 4. The output has length 6. Thus the heatmap is 6x4. Unfortunately the result is not very interpretable, perhaps because the model was not trained well (small data, small time).
eng = engs[-1] # length 3+1 fra = fras[-1] # length 5+1 translation, dec_attention_weight_seq = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True) attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap="Reds"):
    """Draw a grid of heatmaps, one per matrix in `matrices`.

    The first two axes of `matrices` give the grid's rows and columns. The
    bottom row gets `xlabel`, the first column gets `ylabel`, and `titles`
    (if given) is indexed by column. A single colorbar is shared by all axes.
    """
    display.set_matplotlib_formats("svg")
    num_rows, num_cols = matrices.shape[:2]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False)
    bottom_row = num_rows - 1
    for row, (axes_in_row, matrices_in_row) in enumerate(zip(axes, matrices)):
        for col, (ax, matrix) in enumerate(zip(axes_in_row, matrices_in_row)):
            pcm = ax.imshow(matrix.detach(), cmap=cmap)
            if row == bottom_row:
                ax.set_xlabel(xlabel)
            if col == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[col])
    # The colorbar is built from the last image drawn
    fig.colorbar(pcm, ax=axes, shrink=0.6)
# Plot only the key positions of the real source tokens.
# Plus one to include the end-of-sequence token
show_heatmaps(
    attention_weights[:, :, :, : len(engs[-1].split()) + 1].cpu(),
    xlabel="Key positions",
    ylabel="Query positions",
)