GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/15/transformers_torch.ipynb
Kernel: Python 3


Transformers

We show how to implement transformers, based on sec. 10.7 of http://d2l.ai/chapter_attention-mechanisms/transformer.html
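
For reference, the scaled dot-product attention implemented by `DotProductAttention` below computes

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V,$$

where $d$ is the dimension of the queries and keys. Multi-head attention applies this in parallel to several learned linear projections of $Q$, $K$, and $V$ and concatenates the results.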

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
from IPython import display

try:
    import torch
except ModuleNotFoundError:
    %pip install -qq torch
    import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data

import collections
import re
import random
import os
import requests
import zipfile
import tarfile  # needed by `download_extract` for tar archives
import hashlib
import time

np.random.seed(seed=1)
torch.manual_seed(1)

!mkdir figures  # for saving plots

Layers

class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
ffn = PositionWiseFFN(4, 4, 8)  # input dim 4, hidden dim 4, output dim 8
ffn.eval()
Y = ffn(torch.ones((2, 3, 4)))
print(Y.shape)
torch.Size([2, 3, 8])
class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
add_norm = AddNorm([3, 4], 0.5)  # normalized_shape is input.size()[1:]
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
torch.Size([2, 3, 4])

Abstract base class

class Encoder(nn.Module):
    """The base encoder interface for the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
class Decoder(nn.Module):
    """The base decoder interface for the encoder-decoder architecture."""

    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
class AttentionDecoder(Decoder):
    """The base attention-based decoder interface."""

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError

Encoder

class DotProductAttention(nn.Module):
    """Scaled dot product attention."""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of `keys` before the batch matrix multiply
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
        # After transposing, shape of output `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of `output_concat`: (`batch_size`, no. of queries, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


def transpose_qkv(X, num_heads):
    # Shape of input `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
    # Shape of output `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Shape of output `X`:
    # (`batch_size`, `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    X = X.permute(0, 2, 1, 3)

    # Shape of `output`:
    # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse the operation of `transpose_qkv`."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough `P`
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens
        )
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, : X.shape[1], :].to(X.device)
        return self.dropout(X)


class EncoderBlock(nn.Module):
    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        dropout,
        use_bias=False,
        **kwargs
    ):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape
torch.Size([2, 100, 24])
class TransformerEncoder(Encoder):
    def __init__(
        self,
        vocab_size,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        use_bias=False,
        **kwargs
    ):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(
                    key_size,
                    query_size,
                    value_size,
                    num_hiddens,
                    norm_shape,
                    ffn_num_input,
                    ffn_num_hiddens,
                    num_heads,
                    dropout,
                    use_bias,
                ),
            )

    def forward(self, X, valid_lens, *args):
        # Since positional encoding values are between -1 and 1, the embedding
        # values are multiplied by the square root of the embedding dimension
        # to rescale before they are summed up
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X

The shape of the transformer encoder output is (batch size, number of time steps, num_hiddens).

encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape
torch.Size([2, 100, 24])

Decoder

class DecoderBlock(nn.Module):
    # The `i`-th block in the decoder
    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        dropout,
        i,
        **kwargs
    ):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all the tokens of any output sequence are processed
        # at the same time, so `state[2][self.i]` is `None` as initialized.
        # When decoding any output sequence token by token during prediction,
        # `state[2][self.i]` contains representations of the decoded output at
        # the `i`-th block up to the current time step
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape of `dec_valid_lens`: (`batch_size`, `num_steps`), where
            # every row is [1, 2, ..., `num_steps`]
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention. Shape of `enc_outputs`:
        # (`batch_size`, `num_steps`, `num_hiddens`)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state)[0].shape
torch.Size([2, 100, 24])
class TransformerDecoder(AttentionDecoder):
    def __init__(
        self,
        vocab_size,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        **kwargs
    ):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                DecoderBlock(
                    key_size,
                    query_size,
                    value_size,
                    num_hiddens,
                    norm_shape,
                    ffn_num_input,
                    ffn_num_hiddens,
                    num_heads,
                    dropout,
                    i,
                ),
            )
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights

Full model

# Required functions for downloading data


def download(name, cache_dir=os.path.join("..", "data")):
    """Download a file inserted into DATA_HUB, return the local filename."""
    assert name in DATA_HUB, f"{name} does not exist in {DATA_HUB}."
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Hit cache
    print(f"Downloading {fname} from {url}...")
    r = requests.get(url, stream=True, verify=True)
    with open(fname, "wb") as f:
        f.write(r.content)
    return fname


def download_extract(name, folder=None):
    """Download and extract a zip/tar file."""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        fp = zipfile.ZipFile(fname, "r")
    elif ext in (".tar", ".gz"):
        fp = tarfile.open(fname, "r")
    else:
        assert False, "Only zip/tar files can be extracted."
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
def read_data_nmt():
    """Load the English-French dataset."""
    data_dir = download_extract("fra-eng")
    with open(os.path.join(data_dir, "fra.txt"), "r") as f:
        return f.read()


def preprocess_nmt(text):
    """Preprocess the English-French dataset."""

    def no_space(char, prev_char):
        return char in set(",.!?") and prev_char != " "

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace("\u202f", " ").replace("\xa0", " ").lower()
    # Insert space between words and punctuation marks
    out = [" " + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text)]
    return "".join(out)


def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset."""
    source, target = [], []
    for i, line in enumerate(text.split("\n")):
        if num_examples and i > num_examples:
            break
        parts = line.split("\t")
        if len(parts) == 2:
            source.append(parts[0].split(" "))
            target.append(parts[1].split(" "))
    return source, target
class Vocab:
    """Vocabulary for text."""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        # Sort according to frequencies
        counter = count_corpus(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # The index for the unknown token is 0
        self.unk, uniq_tokens = 0, ["<unk>"] + reserved_tokens
        uniq_tokens += [token for token, freq in self.token_freqs if freq >= min_freq and token not in uniq_tokens]
        self.idx_to_token, self.token_to_idx = [], dict()
        for token in uniq_tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]


def count_corpus(tokens):
    """Count token frequencies."""
    # Here `tokens` is a 1D list or 2D list
    if len(tokens) == 0 or isinstance(tokens[0], list):
        # Flatten a list of token lists into a list of tokens
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)
reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)


def build_array_nmt(lines, vocab, num_steps):
    """Transform text sequences of machine translation into minibatches."""
    lines = [vocab[l] for l in lines]
    lines = [l + [vocab["<eos>"]] for l in lines]
    array = torch.tensor([truncate_pad(l, num_steps, vocab["<pad>"]) for l in lines])
    valid_len = reduce_sum(astype(array != vocab["<pad>"], torch.int32), 1)
    return array, valid_len
def load_array(data_arrays, batch_size, is_train=True):
    """Construct a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))


def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Return the iterator and the vocabularies of the translation dataset."""
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = Vocab(source, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    tgt_vocab = Vocab(target, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab

Data

We use an English-French dataset. See this colab for details.

DATA_HUB = dict()
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/"
DATA_HUB["fra-eng"] = (DATA_URL + "fra-eng.zip", "94646ad1522d915e7b0f9296181140edcf86a4f5")

batch_size, num_steps = 64, 10
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)
Downloading ../data/fra-eng.zip from http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip...
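
As a quick sanity check (a minimal snippet, assuming the data-loading cell above has been run), we can inspect the shapes of one minibatch:

X, X_valid_len, Y, Y_valid_len = next(iter(train_iter))
print(X.shape, X_valid_len.shape)  # expected: torch.Size([64, 10]) torch.Size([64])
print(Y.shape, Y_valid_len.shape)  # expected: torch.Size([64, 10]) torch.Size([64])
print(len(src_vocab), len(tgt_vocab))  # vocabulary sizes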
num_hiddens, num_layers, dropout = 32, 2, 0.1
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

encoder = TransformerEncoder(
    len(src_vocab),
    key_size,
    query_size,
    value_size,
    num_hiddens,
    norm_shape,
    ffn_num_input,
    ffn_num_hiddens,
    num_heads,
    num_layers,
    dropout,
)
decoder = TransformerDecoder(
    len(tgt_vocab),
    key_size,
    query_size,
    value_size,
    num_hiddens,
    norm_shape,
    ffn_num_input,
    ffn_num_hiddens,
    num_heads,
    num_layers,
    dropout,
)
net = EncoderDecoder(encoder, decoder)

Training

class Animator:
    """For plotting data in animation."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        # Incrementally plot multiple lines
        if legend is None:
            legend = []
        display.set_matplotlib_formats("svg")
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [
                self.axes,
            ]
        # Use a lambda function to capture arguments
        self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)


class Timer:
    """Record multiple running times."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer and record the time in a list."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time."""
        return np.array(self.times).cumsum().tolist()


class Accumulator:
    """For accumulating sums over `n` variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Set the axes for matplotlib."""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()


def grad_clipping(net, theta):
    """Clip the gradient."""
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad**2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm


def try_gpu(i=0):
    """Return gpu(i) if exists, otherwise return cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f"cuda:{i}")
    return torch.device("cpu")
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """The softmax cross-entropy loss with masks."""

    # `pred` shape: (`batch_size`, `num_steps`, `vocab_size`)
    # `label` shape: (`batch_size`, `num_steps`)
    # `valid_len` shape: (`batch_size`,)
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = "none"
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence."""

    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = Animator(xlabel="epoch", ylabel="loss", xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = Timer()
        metric = Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()  # Reset gradients for each minibatch
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab["<bos>"]] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f"loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} " f"tokens/sec on {str(device)}")
lr, num_epochs, device = 0.005, 200, try_gpu()
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
loss 0.032, 7682.5 tokens/sec on cuda:0
[Figure: training loss vs. epoch, plotted by the Animator]

Evaluation

def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
    """Predict for sequence to sequence."""
    # Set `net` to eval mode for inference
    net.eval()
    src_tokens = src_vocab[src_sentence.lower().split(" ")] + [src_vocab["<eos>"]]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab["<pad>"])
    # Add the batch axis
    enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    # Add the batch axis
    dec_X = torch.unsqueeze(torch.tensor([tgt_vocab["<bos>"]], dtype=torch.long, device=device), dim=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        # We use the token with the highest prediction likelihood as the input
        # of the decoder at the next time step
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        # Save attention weights (to be covered later)
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Once the end-of-sequence token is predicted, the generation of the
        # output sequence is complete
        if pred == tgt_vocab["<eos>"]:
            break
        output_seq.append(pred)
    return " ".join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
def bleu(pred_seq, label_seq, k):
    """Compute the BLEU."""
    pred_tokens, label_tokens = pred_seq.split(" "), label_seq.split(" ")
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs["".join(label_tokens[i : i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs["".join(pred_tokens[i : i + n])] > 0:
                num_matches += 1
                label_subs["".join(pred_tokens[i : i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score
engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f"{eng} => {translation}, ", f"bleu {bleu(translation, fra, k=2):.3f}")
go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est calme ., bleu 1.000
i'm home . => je suis chez moi ., bleu 1.000

Visualization of attention heatmaps

We visualize the attention heatmaps for the last (English, French) pair, where the English input has 3 tokens (4 including the appended <eos>) and the French output has 5 tokens.
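
To confirm these token counts, here is a small check that reuses `engs` and `translation` from the evaluation cell above:

print(len(engs[-1].split(" ")))  # 3 English tokens (4 once <eos> is appended)
print(len(translation.split(" ")))  # 5 French tokens (<eos> is predicted but not kept)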

The shape of the encoder self-attention weights is (number of encoder layers, number of attention heads, num_steps or number of queries, num_steps or number of key-value pairs).

enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads, -1, num_steps))
enc_attention_weights.shape
torch.Size([2, 4, 10, 10])

Encoder self-attention for each of the 2 encoder blocks. The source sequence has 4 valid tokens (including <eos>), so the attention weights on key positions beyond the fourth are zero.
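
We can verify this numerically with a small check that reuses `enc_attention_weights` from above:

# All attention mass should fall on the first 4 key positions; the rest should be (numerically) zero
print(enc_attention_weights[:, :, :, 4:].abs().max().item())
# Each query's attention weights should still sum to 1
print(enc_attention_weights.sum(dim=-1).min().item(), enc_attention_weights.sum(dim=-1).max().item())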

def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap="Reds"):
    display.set_matplotlib_formats("svg")
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6)
show_heatmaps(
    enc_attention_weights.cpu(),
    xlabel="Key positions",
    ylabel="Query positions",
    titles=["Head %d" % i for i in range(1, 5)],
    figsize=(7, 3.5),
)
[Figure: encoder self-attention heatmaps, 2 rows (layers) by 4 columns (heads)]

Next we visualize decoder attention heatmaps.

dec_attention_weights_2d = [
    head[0].tolist() for step in dec_attention_weight_seq for attn in step for blk in attn for head in blk
]
dec_attention_weights_filled = torch.tensor(pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))
dec_self_attention_weights, dec_inter_attention_weights = dec_attention_weights.permute(1, 2, 3, 0, 4)
dec_self_attention_weights.shape, dec_inter_attention_weights.shape
(torch.Size([2, 4, 6, 10]), torch.Size([2, 4, 6, 10]))

Decoder self-attention.

# Plus one to include the beginning-of-sequence token
show_heatmaps(
    dec_self_attention_weights[:, :, :, : len(translation.split()) + 1],
    xlabel="Key positions",
    ylabel="Query positions",
    titles=["Head %d" % i for i in range(1, 5)],
    figsize=(7, 3.5),
)
[Figure: decoder self-attention heatmaps, 2 rows (layers) by 4 columns (heads)]

Encoder-decoder (cross) attention in the decoder.

show_heatmaps(
    dec_inter_attention_weights,
    xlabel="Key positions",
    ylabel="Query positions",
    titles=["Head %d" % i for i in range(1, 5)],
    figsize=(7, 3.5),
)
[Figure: encoder-decoder attention heatmaps, 2 rows (layers) by 4 columns (heads)]