Path: blob/main/intermediate_source/seq2seq_translation_tutorial.py
1686 views
# -*- coding: utf-8 -*-1"""2NLP From Scratch: Translation with a Sequence to Sequence Network and Attention3*******************************************************************************4**Author**: `Sean Robertson <https://github.com/spro>`_56This tutorials is part of a three-part series:78* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__9* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__10* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__1112This is the third and final tutorial on doing **NLP From Scratch**, where we13write our own classes and functions to preprocess the data to do our NLP14modeling tasks.1516In this project we will be teaching a neural network to translate from17French to English.1819.. code-block:: sh2021[KEY: > input, = target, < output]2223> il est en train de peindre un tableau .24= he is painting a picture .25< he is painting a picture .2627> pourquoi ne pas essayer ce vin delicieux ?28= why not try that delicious wine ?29< why not try that delicious wine ?3031> elle n est pas poete mais romanciere .32= she is not a poet but a novelist .33< she not not a poet but a novelist .3435> vous etes trop maigre .36= you re too skinny .37< you re all alone .3839... to varying degrees of success.4041This is made possible by the simple but powerful idea of the `sequence42to sequence network <https://arxiv.org/abs/1409.3215>`__, in which two43recurrent neural networks work together to transform one sequence to44another. An encoder network condenses an input sequence into a vector,45and a decoder network unfolds that vector into a new sequence.4647.. figure:: /_static/img/seq-seq-images/seq2seq.png48:alt:4950To improve upon this model we'll use an `attention51mechanism <https://arxiv.org/abs/1409.0473>`__, which lets the decoder52learn to focus over a specific range of the input sequence.5354**Recommended Reading:**5556I assume you have at least installed PyTorch, know Python, and57understand Tensors:5859- https://pytorch.org/ For installation instructions60- :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general61- :doc:`/beginner/pytorch_with_examples` for a wide and deep overview62- :doc:`/beginner/former_torchies_tutorial` if you are former Lua Torch user636465It would also be useful to know about Sequence to Sequence networks and66how they work:6768- `Learning Phrase Representations using RNN Encoder-Decoder for69Statistical Machine Translation <https://arxiv.org/abs/1406.1078>`__70- `Sequence to Sequence Learning with Neural71Networks <https://arxiv.org/abs/1409.3215>`__72- `Neural Machine Translation by Jointly Learning to Align and73Translate <https://arxiv.org/abs/1409.0473>`__74- `A Neural Conversational Model <https://arxiv.org/abs/1506.05869>`__7576You will also find the previous tutorials on77:doc:`/intermediate/char_rnn_classification_tutorial`78and :doc:`/intermediate/char_rnn_generation_tutorial`79helpful as those concepts are very similar to the Encoder and Decoder80models, respectively.8182**Requirements**83"""84from __future__ import unicode_literals, print_function, division85from io import open86import unicodedata87import re88import random8990import torch91import torch.nn as nn92from torch import optim93import torch.nn.functional as F9495import numpy as np96from torch.utils.data import TensorDataset, DataLoader, RandomSampler9798device = torch.device("cuda" if torch.cuda.is_available() else "cpu")99100######################################################################101# Loading data files102# ==================103#104# The data for this project is a set of many thousands of English to105# French translation pairs.106#107# `This question on Open Data Stack108# Exchange <https://opendata.stackexchange.com/questions/3888/dataset-of-sentences-translated-into-many-languages>`__109# pointed me to the open translation site https://tatoeba.org/ which has110# downloads available at https://tatoeba.org/eng/downloads - and better111# yet, someone did the extra work of splitting language pairs into112# individual text files here: https://www.manythings.org/anki/113#114# The English to French pairs are too big to include in the repository, so115# download to ``data/eng-fra.txt`` before continuing. The file is a tab116# separated list of translation pairs:117#118# .. code-block:: sh119#120# I am cold. J'ai froid.121#122# .. note::123# Download the data from124# `here <https://download.pytorch.org/tutorial/data.zip>`_125# and extract it to the current directory.126127######################################################################128# Similar to the character encoding used in the character-level RNN129# tutorials, we will be representing each word in a language as a one-hot130# vector, or giant vector of zeros except for a single one (at the index131# of the word). Compared to the dozens of characters that might exist in a132# language, there are many many more words, so the encoding vector is much133# larger. We will however cheat a bit and trim the data to only use a few134# thousand words per language.135#136# .. figure:: /_static/img/seq-seq-images/word-encoding.png137# :alt:138#139#140141142######################################################################143# We'll need a unique index per word to use as the inputs and targets of144# the networks later. To keep track of all this we will use a helper class145# called ``Lang`` which has word → index (``word2index``) and index → word146# (``index2word``) dictionaries, as well as a count of each word147# ``word2count`` which will be used to replace rare words later.148#149150SOS_token = 0151EOS_token = 1152153class Lang:154def __init__(self, name):155self.name = name156self.word2index = {}157self.word2count = {}158self.index2word = {0: "SOS", 1: "EOS"}159self.n_words = 2 # Count SOS and EOS160161def addSentence(self, sentence):162for word in sentence.split(' '):163self.addWord(word)164165def addWord(self, word):166if word not in self.word2index:167self.word2index[word] = self.n_words168self.word2count[word] = 1169self.index2word[self.n_words] = word170self.n_words += 1171else:172self.word2count[word] += 1173174175######################################################################176# The files are all in Unicode, to simplify we will turn Unicode177# characters to ASCII, make everything lowercase, and trim most178# punctuation.179#180181# Turn a Unicode string to plain ASCII, thanks to182# https://stackoverflow.com/a/518232/2809427183def unicodeToAscii(s):184return ''.join(185c for c in unicodedata.normalize('NFD', s)186if unicodedata.category(c) != 'Mn'187)188189# Lowercase, trim, and remove non-letter characters190def normalizeString(s):191s = unicodeToAscii(s.lower().strip())192s = re.sub(r"([.!?])", r" \1", s)193s = re.sub(r"[^a-zA-Z!?]+", r" ", s)194return s.strip()195196197######################################################################198# To read the data file we will split the file into lines, and then split199# lines into pairs. The files are all English → Other Language, so if we200# want to translate from Other Language → English I added the ``reverse``201# flag to reverse the pairs.202#203204def readLangs(lang1, lang2, reverse=False):205print("Reading lines...")206207# Read the file and split into lines208lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\209read().strip().split('\n')210211# Split every line into pairs and normalize212pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]213214# Reverse pairs, make Lang instances215if reverse:216pairs = [list(reversed(p)) for p in pairs]217input_lang = Lang(lang2)218output_lang = Lang(lang1)219else:220input_lang = Lang(lang1)221output_lang = Lang(lang2)222223return input_lang, output_lang, pairs224225226######################################################################227# Since there are a *lot* of example sentences and we want to train228# something quickly, we'll trim the data set to only relatively short and229# simple sentences. Here the maximum length is 10 words (that includes230# ending punctuation) and we're filtering to sentences that translate to231# the form "I am" or "He is" etc. (accounting for apostrophes replaced232# earlier).233#234235MAX_LENGTH = 10236237eng_prefixes = (238"i am ", "i m ",239"he is", "he s ",240"she is", "she s ",241"you are", "you re ",242"we are", "we re ",243"they are", "they re "244)245246def filterPair(p):247return len(p[0].split(' ')) < MAX_LENGTH and \248len(p[1].split(' ')) < MAX_LENGTH and \249p[1].startswith(eng_prefixes)250251252def filterPairs(pairs):253return [pair for pair in pairs if filterPair(pair)]254255256######################################################################257# The full process for preparing the data is:258#259# - Read text file and split into lines, split lines into pairs260# - Normalize text, filter by length and content261# - Make word lists from sentences in pairs262#263264def prepareData(lang1, lang2, reverse=False):265input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)266print("Read %s sentence pairs" % len(pairs))267pairs = filterPairs(pairs)268print("Trimmed to %s sentence pairs" % len(pairs))269print("Counting words...")270for pair in pairs:271input_lang.addSentence(pair[0])272output_lang.addSentence(pair[1])273print("Counted words:")274print(input_lang.name, input_lang.n_words)275print(output_lang.name, output_lang.n_words)276return input_lang, output_lang, pairs277278input_lang, output_lang, pairs = prepareData('eng', 'fra', True)279print(random.choice(pairs))280281282######################################################################283# The Seq2Seq Model284# =================285#286# A Recurrent Neural Network, or RNN, is a network that operates on a287# sequence and uses its own output as input for subsequent steps.288#289# A `Sequence to Sequence network <https://arxiv.org/abs/1409.3215>`__, or290# seq2seq network, or `Encoder Decoder291# network <https://arxiv.org/pdf/1406.1078v3.pdf>`__, is a model292# consisting of two RNNs called the encoder and decoder. The encoder reads293# an input sequence and outputs a single vector, and the decoder reads294# that vector to produce an output sequence.295#296# .. figure:: /_static/img/seq-seq-images/seq2seq.png297# :alt:298#299# Unlike sequence prediction with a single RNN, where every input300# corresponds to an output, the seq2seq model frees us from sequence301# length and order, which makes it ideal for translation between two302# languages.303#304# Consider the sentence ``Je ne suis pas le chat noir`` → ``I am not the305# black cat``. Most of the words in the input sentence have a direct306# translation in the output sentence, but are in slightly different307# orders, e.g. ``chat noir`` and ``black cat``. Because of the ``ne/pas``308# construction there is also one more word in the input sentence. It would309# be difficult to produce a correct translation directly from the sequence310# of input words.311#312# With a seq2seq model the encoder creates a single vector which, in the313# ideal case, encodes the "meaning" of the input sequence into a single314# vector — a single point in some N dimensional space of sentences.315#316317318######################################################################319# The Encoder320# -----------321#322# The encoder of a seq2seq network is a RNN that outputs some value for323# every word from the input sentence. For every input word the encoder324# outputs a vector and a hidden state, and uses the hidden state for the325# next input word.326#327# .. figure:: /_static/img/seq-seq-images/encoder-network.png328# :alt:329#330#331332class EncoderRNN(nn.Module):333def __init__(self, input_size, hidden_size, dropout_p=0.1):334super(EncoderRNN, self).__init__()335self.hidden_size = hidden_size336337self.embedding = nn.Embedding(input_size, hidden_size)338self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)339self.dropout = nn.Dropout(dropout_p)340341def forward(self, input):342embedded = self.dropout(self.embedding(input))343output, hidden = self.gru(embedded)344return output, hidden345346######################################################################347# The Decoder348# -----------349#350# The decoder is another RNN that takes the encoder output vector(s) and351# outputs a sequence of words to create the translation.352#353354355######################################################################356# Simple Decoder357# ^^^^^^^^^^^^^^358#359# In the simplest seq2seq decoder we use only last output of the encoder.360# This last output is sometimes called the *context vector* as it encodes361# context from the entire sequence. This context vector is used as the362# initial hidden state of the decoder.363#364# At every step of decoding, the decoder is given an input token and365# hidden state. The initial input token is the start-of-string ``<SOS>``366# token, and the first hidden state is the context vector (the encoder's367# last hidden state).368#369# .. figure:: /_static/img/seq-seq-images/decoder-network.png370# :alt:371#372#373374class DecoderRNN(nn.Module):375def __init__(self, hidden_size, output_size):376super(DecoderRNN, self).__init__()377self.embedding = nn.Embedding(output_size, hidden_size)378self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)379self.out = nn.Linear(hidden_size, output_size)380381def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):382batch_size = encoder_outputs.size(0)383decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)384decoder_hidden = encoder_hidden385decoder_outputs = []386387for i in range(MAX_LENGTH):388decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)389decoder_outputs.append(decoder_output)390391if target_tensor is not None:392# Teacher forcing: Feed the target as the next input393decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing394else:395# Without teacher forcing: use its own predictions as the next input396_, topi = decoder_output.topk(1)397decoder_input = topi.squeeze(-1).detach() # detach from history as input398399decoder_outputs = torch.cat(decoder_outputs, dim=1)400decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)401return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop402403def forward_step(self, input, hidden):404output = self.embedding(input)405output = F.relu(output)406output, hidden = self.gru(output, hidden)407output = self.out(output)408return output, hidden409410######################################################################411# I encourage you to train and observe the results of this model, but to412# save space we'll be going straight for the gold and introducing the413# Attention Mechanism.414#415416417######################################################################418# Attention Decoder419# ^^^^^^^^^^^^^^^^^420#421# If only the context vector is passed between the encoder and decoder,422# that single vector carries the burden of encoding the entire sentence.423#424# Attention allows the decoder network to "focus" on a different part of425# the encoder's outputs for every step of the decoder's own outputs. First426# we calculate a set of *attention weights*. These will be multiplied by427# the encoder output vectors to create a weighted combination. The result428# (called ``attn_applied`` in the code) should contain information about429# that specific part of the input sequence, and thus help the decoder430# choose the right output words.431#432# .. figure:: https://i.imgur.com/1152PYf.png433# :alt:434#435# Calculating the attention weights is done with another feed-forward436# layer ``attn``, using the decoder's input and hidden state as inputs.437# Because there are sentences of all sizes in the training data, to438# actually create and train this layer we have to choose a maximum439# sentence length (input length, for encoder outputs) that it can apply440# to. Sentences of the maximum length will use all the attention weights,441# while shorter sentences will only use the first few.442#443# .. figure:: /_static/img/seq-seq-images/attention-decoder-network.png444# :alt:445#446#447# Bahdanau attention, also known as additive attention, is a commonly used448# attention mechanism in sequence-to-sequence models, particularly in neural449# machine translation tasks. It was introduced by Bahdanau et al. in their450# paper titled `Neural Machine Translation by Jointly Learning to Align and Translate <https://arxiv.org/pdf/1409.0473.pdf>`__.451# This attention mechanism employs a learned alignment model to compute attention452# scores between the encoder and decoder hidden states. It utilizes a feed-forward453# neural network to calculate alignment scores.454#455# However, there are alternative attention mechanisms available, such as Luong attention,456# which computes attention scores by taking the dot product between the decoder hidden457# state and the encoder hidden states. It does not involve the non-linear transformation458# used in Bahdanau attention.459#460# In this tutorial, we will be using Bahdanau attention. However, it would be a valuable461# exercise to explore modifying the attention mechanism to use Luong attention.462463class BahdanauAttention(nn.Module):464def __init__(self, hidden_size):465super(BahdanauAttention, self).__init__()466self.Wa = nn.Linear(hidden_size, hidden_size)467self.Ua = nn.Linear(hidden_size, hidden_size)468self.Va = nn.Linear(hidden_size, 1)469470def forward(self, query, keys):471scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))472scores = scores.squeeze(2).unsqueeze(1)473474weights = F.softmax(scores, dim=-1)475context = torch.bmm(weights, keys)476477return context, weights478479class AttnDecoderRNN(nn.Module):480def __init__(self, hidden_size, output_size, dropout_p=0.1):481super(AttnDecoderRNN, self).__init__()482self.embedding = nn.Embedding(output_size, hidden_size)483self.attention = BahdanauAttention(hidden_size)484self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)485self.out = nn.Linear(hidden_size, output_size)486self.dropout = nn.Dropout(dropout_p)487488def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):489batch_size = encoder_outputs.size(0)490decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)491decoder_hidden = encoder_hidden492decoder_outputs = []493attentions = []494495for i in range(MAX_LENGTH):496decoder_output, decoder_hidden, attn_weights = self.forward_step(497decoder_input, decoder_hidden, encoder_outputs498)499decoder_outputs.append(decoder_output)500attentions.append(attn_weights)501502if target_tensor is not None:503# Teacher forcing: Feed the target as the next input504decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing505else:506# Without teacher forcing: use its own predictions as the next input507_, topi = decoder_output.topk(1)508decoder_input = topi.squeeze(-1).detach() # detach from history as input509510decoder_outputs = torch.cat(decoder_outputs, dim=1)511decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)512attentions = torch.cat(attentions, dim=1)513514return decoder_outputs, decoder_hidden, attentions515516517def forward_step(self, input, hidden, encoder_outputs):518embedded = self.dropout(self.embedding(input))519520query = hidden.permute(1, 0, 2)521context, attn_weights = self.attention(query, encoder_outputs)522input_gru = torch.cat((embedded, context), dim=2)523524output, hidden = self.gru(input_gru, hidden)525output = self.out(output)526527return output, hidden, attn_weights528529530######################################################################531# .. note:: There are other forms of attention that work around the length532# limitation by using a relative position approach. Read about "local533# attention" in `Effective Approaches to Attention-based Neural Machine534# Translation <https://arxiv.org/abs/1508.04025>`__.535#536# Training537# ========538#539# Preparing Training Data540# -----------------------541#542# To train, for each pair we will need an input tensor (indexes of the543# words in the input sentence) and target tensor (indexes of the words in544# the target sentence). While creating these vectors we will append the545# EOS token to both sequences.546#547548def indexesFromSentence(lang, sentence):549return [lang.word2index[word] for word in sentence.split(' ')]550551def tensorFromSentence(lang, sentence):552indexes = indexesFromSentence(lang, sentence)553indexes.append(EOS_token)554return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)555556def tensorsFromPair(pair):557input_tensor = tensorFromSentence(input_lang, pair[0])558target_tensor = tensorFromSentence(output_lang, pair[1])559return (input_tensor, target_tensor)560561def get_dataloader(batch_size):562input_lang, output_lang, pairs = prepareData('eng', 'fra', True)563564n = len(pairs)565input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)566target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)567568for idx, (inp, tgt) in enumerate(pairs):569inp_ids = indexesFromSentence(input_lang, inp)570tgt_ids = indexesFromSentence(output_lang, tgt)571inp_ids.append(EOS_token)572tgt_ids.append(EOS_token)573input_ids[idx, :len(inp_ids)] = inp_ids574target_ids[idx, :len(tgt_ids)] = tgt_ids575576train_data = TensorDataset(torch.LongTensor(input_ids).to(device),577torch.LongTensor(target_ids).to(device))578579train_sampler = RandomSampler(train_data)580train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)581return input_lang, output_lang, train_dataloader582583584######################################################################585# Training the Model586# ------------------587#588# To train we run the input sentence through the encoder, and keep track589# of every output and the latest hidden state. Then the decoder is given590# the ``<SOS>`` token as its first input, and the last hidden state of the591# encoder as its first hidden state.592#593# "Teacher forcing" is the concept of using the real target outputs as594# each next input, instead of using the decoder's guess as the next input.595# Using teacher forcing causes it to converge faster but `when the trained596# network is exploited, it may exhibit597# instability <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.378.4095&rep=rep1&type=pdf>`__.598#599# You can observe outputs of teacher-forced networks that read with600# coherent grammar but wander far from the correct translation -601# intuitively it has learned to represent the output grammar and can "pick602# up" the meaning once the teacher tells it the first few words, but it603# has not properly learned how to create the sentence from the translation604# in the first place.605#606# Because of the freedom PyTorch's autograd gives us, we can randomly607# choose to use teacher forcing or not with a simple if statement. Turn608# ``teacher_forcing_ratio`` up to use more of it.609#610611def train_epoch(dataloader, encoder, decoder, encoder_optimizer,612decoder_optimizer, criterion):613614total_loss = 0615for data in dataloader:616input_tensor, target_tensor = data617618encoder_optimizer.zero_grad()619decoder_optimizer.zero_grad()620621encoder_outputs, encoder_hidden = encoder(input_tensor)622decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)623624loss = criterion(625decoder_outputs.view(-1, decoder_outputs.size(-1)),626target_tensor.view(-1)627)628loss.backward()629630encoder_optimizer.step()631decoder_optimizer.step()632633total_loss += loss.item()634635return total_loss / len(dataloader)636637638######################################################################639# This is a helper function to print time elapsed and estimated time640# remaining given the current time and progress %.641#642643import time644import math645646def asMinutes(s):647m = math.floor(s / 60)648s -= m * 60649return '%dm %ds' % (m, s)650651def timeSince(since, percent):652now = time.time()653s = now - since654es = s / (percent)655rs = es - s656return '%s (- %s)' % (asMinutes(s), asMinutes(rs))657658659######################################################################660# The whole training process looks like this:661#662# - Start a timer663# - Initialize optimizers and criterion664# - Create set of training pairs665# - Start empty losses array for plotting666#667# Then we call ``train`` many times and occasionally print the progress (%668# of examples, time so far, estimated time) and average loss.669#670671def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,672print_every=100, plot_every=100):673start = time.time()674plot_losses = []675print_loss_total = 0 # Reset every print_every676plot_loss_total = 0 # Reset every plot_every677678encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)679decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)680criterion = nn.NLLLoss()681682for epoch in range(1, n_epochs + 1):683loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)684print_loss_total += loss685plot_loss_total += loss686687if epoch % print_every == 0:688print_loss_avg = print_loss_total / print_every689print_loss_total = 0690print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),691epoch, epoch / n_epochs * 100, print_loss_avg))692693if epoch % plot_every == 0:694plot_loss_avg = plot_loss_total / plot_every695plot_losses.append(plot_loss_avg)696plot_loss_total = 0697698showPlot(plot_losses)699700######################################################################701# Plotting results702# ----------------703#704# Plotting is done with matplotlib, using the array of loss values705# ``plot_losses`` saved while training.706#707708import matplotlib.pyplot as plt709plt.switch_backend('agg')710import matplotlib.ticker as ticker711import numpy as np712713def showPlot(points):714plt.figure()715fig, ax = plt.subplots()716# this locator puts ticks at regular intervals717loc = ticker.MultipleLocator(base=0.2)718ax.yaxis.set_major_locator(loc)719plt.plot(points)720721722######################################################################723# Evaluation724# ==========725#726# Evaluation is mostly the same as training, but there are no targets so727# we simply feed the decoder's predictions back to itself for each step.728# Every time it predicts a word we add it to the output string, and if it729# predicts the EOS token we stop there. We also store the decoder's730# attention outputs for display later.731#732733def evaluate(encoder, decoder, sentence, input_lang, output_lang):734with torch.no_grad():735input_tensor = tensorFromSentence(input_lang, sentence)736737encoder_outputs, encoder_hidden = encoder(input_tensor)738decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)739740_, topi = decoder_outputs.topk(1)741decoded_ids = topi.squeeze()742743decoded_words = []744for idx in decoded_ids:745if idx.item() == EOS_token:746decoded_words.append('<EOS>')747break748decoded_words.append(output_lang.index2word[idx.item()])749return decoded_words, decoder_attn750751752######################################################################753# We can evaluate random sentences from the training set and print out the754# input, target, and output to make some subjective quality judgements:755#756757def evaluateRandomly(encoder, decoder, n=10):758for i in range(n):759pair = random.choice(pairs)760print('>', pair[0])761print('=', pair[1])762output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)763output_sentence = ' '.join(output_words)764print('<', output_sentence)765print('')766767768######################################################################769# Training and Evaluating770# =======================771#772# With all these helper functions in place (it looks like extra work, but773# it makes it easier to run multiple experiments) we can actually774# initialize a network and start training.775#776# Remember that the input sentences were heavily filtered. For this small777# dataset we can use relatively small networks of 256 hidden nodes and a778# single GRU layer. After about 40 minutes on a MacBook CPU we'll get some779# reasonable results.780#781# .. note::782# If you run this notebook you can train, interrupt the kernel,783# evaluate, and continue training later. Comment out the lines where the784# encoder and decoder are initialized and run ``trainIters`` again.785#786787hidden_size = 128788batch_size = 32789790input_lang, output_lang, train_dataloader = get_dataloader(batch_size)791792encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)793decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)794795train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)796797######################################################################798#799# Set dropout layers to ``eval`` mode800encoder.eval()801decoder.eval()802evaluateRandomly(encoder, decoder)803804805######################################################################806# Visualizing Attention807# ---------------------808#809# A useful property of the attention mechanism is its highly interpretable810# outputs. Because it is used to weight specific encoder outputs of the811# input sequence, we can imagine looking where the network is focused most812# at each time step.813#814# You could simply run ``plt.matshow(attentions)`` to see attention output815# displayed as a matrix. For a better viewing experience we will do the816# extra work of adding axes and labels:817#818819def showAttention(input_sentence, output_words, attentions):820fig = plt.figure()821ax = fig.add_subplot(111)822cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')823fig.colorbar(cax)824825# Set up axes826ax.set_xticklabels([''] + input_sentence.split(' ') +827['<EOS>'], rotation=90)828ax.set_yticklabels([''] + output_words)829830# Show label at every tick831ax.xaxis.set_major_locator(ticker.MultipleLocator(1))832ax.yaxis.set_major_locator(ticker.MultipleLocator(1))833834plt.show()835836837def evaluateAndShowAttention(input_sentence):838output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)839print('input =', input_sentence)840print('output =', ' '.join(output_words))841showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])842843844evaluateAndShowAttention('il n est pas aussi grand que son pere')845846evaluateAndShowAttention('je suis trop fatigue pour conduire')847848evaluateAndShowAttention('je suis desole si c est une question idiote')849850evaluateAndShowAttention('je suis reellement fiere de vous')851852853######################################################################854# Exercises855# =========856#857# - Try with a different dataset858#859# - Another language pair860# - Human → Machine (e.g. IOT commands)861# - Chat → Response862# - Question → Answer863#864# - Replace the embeddings with pretrained word embeddings such as ``word2vec`` or865# ``GloVe``866# - Try with more layers, more hidden units, and more sentences. Compare867# the training time and results.868# - If you use a translation file where pairs have two of the same phrase869# (``I am test \t I am test``), you can use this as an autoencoder. Try870# this:871#872# - Train as an autoencoder873# - Save only the Encoder network874# - Train a new Decoder for translation from there875#876877878