GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/15/nmt_torch.ipynb
Kernel: Python 3


Neural machine translation using encoder-decoder RNN

We show how to implement NMT using an encoder-decoder.

Based on sec 9.7 of http://d2l.ai/chapter_recurrent-modern/seq2seq.html

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
from IPython import display

try:
    import torch
except ModuleNotFoundError:
    %pip install -qq torch
    import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data

import collections
import re
import random
import os
import requests
import zipfile
import tarfile  # needed by download_extract below
import hashlib
import time

np.random.seed(seed=1)
torch.manual_seed(1)

!mkdir -p figures  # for saving plots (-p avoids an error if the directory already exists)

Required functions for text preprocessing

For more details on these functions, see this colab.

# Required functions for downloading data


def download(name, cache_dir=os.path.join("..", "data")):
    """Download a file inserted into DATA_HUB, return the local filename."""
    assert name in DATA_HUB, f"{name} does not exist in {DATA_HUB}."
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Hit cache
    print(f"Downloading {fname} from {url}...")
    r = requests.get(url, stream=True, verify=True)
    with open(fname, "wb") as f:
        f.write(r.content)
    return fname


def download_extract(name, folder=None):
    """Download and extract a zip/tar file."""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        fp = zipfile.ZipFile(fname, "r")
    elif ext in (".tar", ".gz"):
        fp = tarfile.open(fname, "r")
    else:
        assert False, "Only zip/tar files can be extracted."
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
def read_data_nmt():
    """Load the English-French dataset."""
    data_dir = download_extract("fra-eng")
    with open(os.path.join(data_dir, "fra.txt"), "r") as f:
        return f.read()


def preprocess_nmt(text):
    """Preprocess the English-French dataset."""

    def no_space(char, prev_char):
        return char in set(",.!?") and prev_char != " "

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace("\u202f", " ").replace("\xa0", " ").lower()
    # Insert space between words and punctuation marks
    out = [" " + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text)]
    return "".join(out)


def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset."""
    source, target = [], []
    for i, line in enumerate(text.split("\n")):
        if num_examples and i > num_examples:
            break
        parts = line.split("\t")
        if len(parts) == 2:
            source.append(parts[0].split(" "))
            target.append(parts[1].split(" "))
    return source, target
class Vocab:
    """Vocabulary for text."""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        # Sort according to frequencies
        counter = count_corpus(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # The index for the unknown token is 0
        self.unk, uniq_tokens = 0, ["<unk>"] + reserved_tokens
        uniq_tokens += [token for token, freq in self.token_freqs if freq >= min_freq and token not in uniq_tokens]
        self.idx_to_token, self.token_to_idx = [], dict()
        for token in uniq_tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]


def count_corpus(tokens):
    """Count token frequencies."""
    # Here `tokens` is a 1D list or 2D list
    if len(tokens) == 0 or isinstance(tokens[0], list):
        # Flatten a list of token lists into a list of tokens
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)
reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)


def build_array_nmt(lines, vocab, num_steps):
    """Transform text sequences of machine translation into minibatches."""
    lines = [vocab[l] for l in lines]
    lines = [l + [vocab["<eos>"]] for l in lines]
    array = torch.tensor([truncate_pad(l, num_steps, vocab["<pad>"]) for l in lines])
    valid_len = reduce_sum(astype(array != vocab["<pad>"], torch.int32), 1)
    return array, valid_len
def load_array(data_arrays, batch_size, is_train=True):
    """Construct a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))


def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Return the iterator and the vocabularies of the translation dataset."""
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = Vocab(source, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    tgt_vocab = Vocab(target, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab

Data

We use an English-French dataset. See this colab for details.

DATA_HUB = dict()
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/"
DATA_HUB["fra-eng"] = (DATA_URL + "fra-eng.zip", "94646ad1522d915e7b0f9296181140edcf86a4f5")

batch_size, num_steps = 64, 10
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)
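
As a quick sanity check (not in the original notebook), we can pull one minibatch from train_iter and look at its shapes; the variable names below mirror those used later in the training loop.

# Inspect one minibatch: source/target token ids plus their valid (non-padding) lengths
for X, X_valid_len, Y, Y_valid_len in train_iter:
    print("X:", X.shape, X.dtype)            # (batch_size, num_steps) source token ids
    print("X_valid_len:", X_valid_len[:4])   # number of non-<pad> tokens per source sentence
    print("Y:", Y.shape, Y.dtype)            # (batch_size, num_steps) target token ids
    print("Y_valid_len:", Y_valid_len[:4])   # number of non-<pad> tokens per target sentence
    break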

Encoder-decoder

Abstract base class

class Encoder(nn.Module):
    """The base encoder interface for the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
class Decoder(nn.Module):
    """The base decoder interface for the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

Encoder

We use a 2-layer GRU for the encoder and take the context to be its final hidden state. The input to the GRU at each step is the word embedding of the corresponding token.

class Seq2SeqEncoder(Encoder):
    """The RNN encoder for sequence to sequence learning."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        # The output `X` shape: (`batch_size`, `num_steps`, `embed_size`)
        X = self.embedding(X)
        # In RNN models, the first axis corresponds to time steps
        X = X.permute(1, 0, 2)
        # When state is not mentioned, it defaults to zeros
        output, state = self.rnn(X)
        # `output` shape: (`num_steps`, `batch_size`, `num_hiddens`)
        # `state` shape: (`num_layers`, `batch_size`, `num_hiddens`)
        return output, state
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
batch_size = 4
num_steps = 7
X = torch.zeros((batch_size, num_steps), dtype=torch.long)
output, state = encoder(X)
print(output.shape)
print(state.shape)
torch.Size([7, 4, 16])
torch.Size([2, 4, 16])

Decoder

We use another GRU as the decoder. Its initial state is the final state of the encoder, so the decoder must use the same number of hidden units (and layers). In addition, we concatenate the context (i.e., the encoder's final hidden state) with the input embedding at every step of the decoder.

class Seq2SeqDecoder(Decoder):
    """The RNN decoder for sequence to sequence learning."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]

    def forward(self, X, state):
        # The output `X` shape: (`num_steps`, `batch_size`, `embed_size`)
        X = self.embedding(X).permute(1, 0, 2)
        # Broadcast `context` so it has the same `num_steps` as `X`
        context = state[-1].repeat(X.shape[0], 1, 1)
        X_and_context = torch.cat((X, context), 2)
        output, state = self.rnn(X_and_context, state)
        output = self.dense(output).permute(1, 0, 2)
        # `output` shape: (`batch_size`, `num_steps`, `vocab_size`)
        # `state` shape: (`num_layers`, `batch_size`, `num_hiddens`)
        return output, state
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
print(output.shape)  # (batch size, number of time steps, vocabulary size)
print(state.shape)  # (num layers, batch size, num hiddens)
torch.Size([4, 7, 10])
torch.Size([2, 4, 16])

Loss function

We use cross-entropy loss, but we must mask out target tokens that are just padding. The helper below replaces all entries beyond the valid length of each sequence with a given value (0 by default).

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


X = torch.tensor([[1, 2, 3], [4, 5, 6]])
sequence_mask(X, torch.tensor([1, 2]))
tensor([[1, 0, 0],
        [4, 5, 0]])
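
The same helper also works when `X` carries a trailing feature dimension, and the fill value need not be zero; for example (a small extra check, not in the original notebook):

# Mask a (batch, num_steps, features) tensor, filling masked time steps with -1
X = torch.ones(2, 3, 4)
sequence_mask(X, torch.tensor([1, 2]), value=-1)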

We now use this to create a weight mask of 0s and 1s, where 0 corresponds to invalid locations. When we compute the cross entropy loss, we multiply by this weight mask, thus ignoring invalid locations.

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """The softmax cross-entropy loss with masks."""

    # `pred` shape: (`batch_size`, `num_steps`, `vocab_size`)
    # `label` shape: (`batch_size`, `num_steps`)
    # `valid_len` shape: (`batch_size`,)
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = "none"
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss

As an example, let us create a prediction tensor of all ones of size (3, 4, 10) and a target label tensor of all ones of size (3, 4). We specify the valid lengths as (4, 2, 0). Thus the first loss should be twice the second, and the third loss should be 0.

loss = MaskedSoftmaxCELoss()
loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long), torch.tensor([4, 2, 0]))
tensor([2.3026, 1.1513, 0.0000])
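
These values are easy to check by hand: all-ones logits give a uniform prediction over the 10 classes, so the per-token cross entropy is ln 10 ≈ 2.3026, and the masked loss averages over all 4 time steps, so only the valid tokens contribute. A quick check (not part of the original notebook):

# Per-token loss for a uniform prediction over 10 classes is ln(10)
print(math.log(10))          # ~2.3026: all 4 tokens valid
print(math.log(10) * 2 / 4)  # ~1.1513: 2 of 4 tokens valid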

Training

We use teacher forcing, where the input to the decoder at each step is the ground-truth target token from the previous step, starting with `<bos>` (beginning of sequence), as shown below.
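
Concretely, the decoder input is the target sequence shifted right by one position, with `<bos>` prepended. The snippet below is a toy illustration mirroring the corresponding lines inside `train_seq2seq`; the target ids are made up.

# Toy example of building the teacher-forcing decoder input from targets Y
Y = torch.tensor([[5, 6, 7, 8], [9, 10, 11, 12]])                     # fake target token ids
bos = torch.tensor([tgt_vocab["<bos>"]] * Y.shape[0]).reshape(-1, 1)  # a column of <bos> ids
dec_input = torch.cat([bos, Y[:, :-1]], 1)                            # drop the last target token
print(dec_input)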

class Animator:
    """For plotting data in animation."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        # Incrementally plot multiple lines
        if legend is None:
            legend = []
        display.set_matplotlib_formats("svg")
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [
                self.axes,
            ]
        # Use a lambda function to capture arguments
        self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)


class Timer:
    """Record multiple running times."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer and record the time in a list."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time."""
        return np.array(self.times).cumsum().tolist()


class Accumulator:
    """For accumulating sums over `n` variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Set the axes for matplotlib."""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()


def grad_clipping(net, theta):
    """Clip the gradient."""
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad**2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm


def try_gpu(i=0):
    """Return gpu(i) if exists, otherwise return cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f"cuda:{i}")
    return torch.device("cpu")
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence."""

    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = Animator(xlabel="epoch", ylabel="loss", xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = Timer()
        metric = Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()  # Reset gradients accumulated from the previous step
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab["<bos>"]] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f"loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} " f"tokens/sec on {str(device)}")
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
lr, num_epochs, device = 0.005, 300, try_gpu()

encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
loss 0.019, 7621.4 tokens/sec on cpu
(Figure: training loss versus epoch, plotted by the Animator during training.)

Prediction

We use greedy decoding, where the input to the decoder at each step is the most likely target token from the previous step, starting with `<bos>` (beginning of sequence), as shown below. We keep decoding until the model generates `<eos>` (end of sequence).

def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
    """Predict for sequence to sequence."""
    # Set `net` to eval mode for inference
    net.eval()
    src_tokens = src_vocab[src_sentence.lower().split(" ")] + [src_vocab["<eos>"]]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab["<pad>"])
    # Add the batch axis
    enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    # Add the batch axis
    dec_X = torch.unsqueeze(torch.tensor([tgt_vocab["<bos>"]], dtype=torch.long, device=device), dim=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        # We use the token with the highest prediction likelihood as the input
        # of the decoder at the next time step
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        # Save attention weights (to be covered later)
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Once the end-of-sequence token is predicted, the generation of the
        # output sequence is complete
        if pred == tgt_vocab["<eos>"]:
            break
        output_seq.append(pred)
    return " ".join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq

Evaluation

In the MT community, the standard evaluation metric is known as BLEU (Bilingual Evaluation Understudy), which measures how many n-grams in the predicted translation also occur in the reference (label) translation.

For example, suppose the prediction is A,B,B,C,D and the target is A,B,C,D,E,F. There are five 1-grams in the prediction, of which 4 find a match in the target (the second "B" is a "false positive"), so the precision for 1-grams is $p_1 = 4/5$. Similarly, there are four 2-grams, of which 3 find a match (the bigram "BB" does not occur in the target), so $p_2 = 3/4$. We continue in this way up to $p_k$, where $k$ is the maximum n-gram length. (Since we are using words, not characters, we typically keep $k$ small, to avoid sparse counts.)

The BLEU score is then defined by
$$\exp\left(\min\left(0,\; 1-\frac{L_y}{L_p}\right)\right) \prod_{n=1}^k p_n^{0.5^n},$$
where $L_y$ is the length of the target label sequence and $L_p$ is the length of the prediction.

Since predicting shorter sequences tends to give higher $p_n$ values, short sequences are penalized by the exponential factor. For example, suppose $k=2$ and the label sequence is A,B,C,D,E,F. If the predicted sequence is A,B,B,C,D, we have $p_1=4/5$ and $p_2=3/4$, and the penalty factor is $\exp(1-6/5) \approx 0.818$. If the predicted sequence is A,B, we have $p_1=p_2=1$, but the penalty factor is $\exp(1-6/2) \approx 0.135$.
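
We can reproduce these numbers directly (a quick check of the formula, not part of the original notebook):

# Worked example: prediction A,B,B,C,D vs. label A,B,C,D,E,F with k=2
p1, p2 = 4 / 5, 3 / 4
penalty = math.exp(min(0, 1 - 6 / 5))  # brevity penalty, ~0.818
print(penalty)
print(penalty * p1**0.5 * p2**0.25)    # BLEU with k=2 (exponents 0.5^1 and 0.5^2)
print(math.exp(min(0, 1 - 6 / 2)))     # penalty for the short prediction A,B, ~0.135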

def bleu(pred_seq, label_seq, k):
    """Compute the BLEU."""
    pred_tokens, label_tokens = pred_seq.split(" "), label_seq.split(" ")
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))  # brevity penalty
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            # Join with a space so n-gram keys are unambiguous
            label_subs[" ".join(label_tokens[i : i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[" ".join(pred_tokens[i : i + n])] > 0:
                num_matches += 1
                label_subs[" ".join(pred_tokens[i : i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score
# prediction
engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
data = []
for eng, fra in zip(engs, fras):
    translation, attention_weight_seq = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    score = bleu(translation, fra, k=2)
    data.append((eng, fra, translation, score))
df = pd.DataFrame.from_records(data, columns=["English", "Truth", "Prediction", "Bleu"])
display.display(df)