# -*- coding: utf-8 -*-
"""
NLP From Scratch: Translation with a Sequence to Sequence Network and Attention
*******************************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

This is the third and final tutorial on doing **NLP From Scratch**, where we
write our own classes and functions to preprocess the data for our NLP
modeling tasks.

In this project we will be teaching a neural network to translate from
French to English.

.. code-block:: sh

    [KEY: > input, = target, < output]

    > il est en train de peindre un tableau .
    = he is painting a picture .
    < he is painting a picture .

    > pourquoi ne pas essayer ce vin delicieux ?
    = why not try that delicious wine ?
    < why not try that delicious wine ?

    > elle n est pas poete mais romanciere .
    = she is not a poet but a novelist .
    < she not not a poet but a novelist .

    > vous etes trop maigre .
    = you re too skinny .
    < you re all alone .

... to varying degrees of success.

This is made possible by the simple but powerful idea of the `sequence
to sequence network <https://arxiv.org/abs/1409.3215>`__, in which two
recurrent neural networks work together to transform one sequence to
another. An encoder network condenses an input sequence into a vector,
and a decoder network unfolds that vector into a new sequence.

.. figure:: /_static/img/seq-seq-images/seq2seq.png
   :alt:

To improve upon this model we'll use an `attention
mechanism <https://arxiv.org/abs/1409.0473>`__, which lets the decoder
learn to focus over a specific range of the input sequence.

**Recommended Reading:**

I assume you have at least installed PyTorch, know Python, and
understand Tensors:

- https://pytorch.org/ For installation instructions
- :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
- :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
- :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about Sequence to Sequence networks and
how they work:

- `Learning Phrase Representations using RNN Encoder-Decoder for
  Statistical Machine Translation <https://arxiv.org/abs/1406.1078>`__
- `Sequence to Sequence Learning with Neural
  Networks <https://arxiv.org/abs/1409.3215>`__
- `Neural Machine Translation by Jointly Learning to Align and
  Translate <https://arxiv.org/abs/1409.0473>`__
- `A Neural Conversational Model <https://arxiv.org/abs/1506.05869>`__

You will also find the previous tutorials on
:doc:`/intermediate/char_rnn_classification_tutorial`
and :doc:`/intermediate/char_rnn_generation_tutorial`
helpful as those concepts are very similar to the Encoder and Decoder
models, respectively.

**Requirements**
"""
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
# Loading data files
# ==================
#
# The data for this project is a set of many thousands of English to
# French translation pairs.
#
# `This question on Open Data Stack
# Exchange <https://opendata.stackexchange.com/questions/3888/dataset-of-sentences-translated-into-many-languages>`__
# pointed me to the open translation site https://tatoeba.org/ which has
# downloads available at https://tatoeba.org/eng/downloads - and better
# yet, someone did the extra work of splitting language pairs into
# individual text files here: https://www.manythings.org/anki/
#
# The English to French pairs are too big to include in the repository, so
# download to ``data/eng-fra.txt`` before continuing. The file is a tab
# separated list of translation pairs:
#
# .. code-block:: sh
#
#    I am cold.    J'ai froid.
#
# .. note::
#    Download the data from
#    `here <https://download.pytorch.org/tutorial/data.zip>`_
#    and extract it to the current directory.

######################################################################
# Similar to the character encoding used in the character-level RNN
# tutorials, we will be representing each word in a language as a one-hot
# vector, a giant vector of zeros except for a single one (at the index
# of the word). Compared to the dozens of characters that might exist in a
# language, there are many, many more words, so the encoding vector is much
# larger. We will, however, cheat a bit and trim the data to only use a few
# thousand words per language.
#
# .. figure:: /_static/img/seq-seq-images/word-encoding.png
#    :alt:
#
#


######################################################################
# We'll need a unique index per word to use as the inputs and targets of
# the networks later. To keep track of all this we will use a helper class
# called ``Lang`` which has word → index (``word2index``) and index → word
# (``index2word``) dictionaries, as well as a count of each word
# (``word2count``) which will be used to replace rare words later.
#

SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
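
######################################################################
# As a quick illustration, we can index a tiny sentence with ``Lang`` and build
# the one-hot vectors described above with ``F.one_hot`` (the ``demo_`` objects
# here are only for demonstration and are not used later):
#

demo_lang = Lang('demo')
demo_lang.addSentence('the cat is black')
print(demo_lang.word2index)  # {'the': 2, 'cat': 3, 'is': 4, 'black': 5}
demo_indexes = torch.tensor([demo_lang.word2index[w] for w in 'the black cat'.split(' ')])
print(F.one_hot(demo_indexes, num_classes=demo_lang.n_words))  # one row per word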


######################################################################
# The files are all in Unicode. To simplify, we will turn Unicode
# characters to ASCII, make everything lowercase, and trim most
# punctuation.
#

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return s.strip()
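
######################################################################
# For example (an illustrative call, with the result shown as a comment):
#

print(normalizeString("Elle n'est pas poète !"))  # expected: "elle n est pas poete !"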


######################################################################
# To read the data file we will split the file into lines, and then split
# lines into pairs. The files are all English → Other Language, so if we
# want to translate from Other Language → English, I added the ``reverse``
# flag to reverse the pairs.
#

def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


######################################################################
# Since there are a *lot* of example sentences and we want to train
# something quickly, we'll trim the data set to only relatively short and
# simple sentences. Here the maximum length is 10 words (that includes
# ending punctuation) and we're filtering to sentences that translate to
# the form "I am" or "He is" etc. (accounting for apostrophes replaced
# earlier).
#

MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


######################################################################
# The full process for preparing the data is:
#
# - Read text file and split into lines, split lines into pairs
# - Normalize text, filter by length and content
# - Make word lists from sentences in pairs
#

def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))


######################################################################
# The Seq2Seq Model
# =================
#
# A Recurrent Neural Network, or RNN, is a network that operates on a
# sequence and uses its own output as input for subsequent steps.
#
# A `Sequence to Sequence network <https://arxiv.org/abs/1409.3215>`__, or
# seq2seq network, or `Encoder Decoder
# network <https://arxiv.org/pdf/1406.1078v3.pdf>`__, is a model
# consisting of two RNNs called the encoder and decoder. The encoder reads
# an input sequence and outputs a single vector, and the decoder reads
# that vector to produce an output sequence.
#
# .. figure:: /_static/img/seq-seq-images/seq2seq.png
#    :alt:
#
# Unlike sequence prediction with a single RNN, where every input
# corresponds to an output, the seq2seq model frees us from sequence
# length and order, which makes it ideal for translation between two
# languages.
#
# Consider the sentence ``Je ne suis pas le chat noir`` → ``I am not the
# black cat``. Most of the words in the input sentence have a direct
# translation in the output sentence, but are in slightly different
# orders, e.g. ``chat noir`` and ``black cat``. Because of the ``ne/pas``
# construction there is also one more word in the input sentence. It would
# be difficult to produce a correct translation directly from the sequence
# of input words.
#
# With a seq2seq model the encoder creates a single vector which, in the
# ideal case, encodes the "meaning" of the input sequence: a single point
# in some N-dimensional space of sentences.
#


######################################################################
# The Encoder
# -----------
#
# The encoder of a seq2seq network is an RNN that outputs some value for
# every word from the input sentence. For every input word the encoder
# outputs a vector and a hidden state, and uses the hidden state for the
# next input word.
#
# .. figure:: /_static/img/seq-seq-images/encoder-network.png
#    :alt:
#
#

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden
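
######################################################################
# As a quick sanity check (the ``demo_`` objects below are only illustrative
# and use a deliberately tiny hidden size), the encoder returns one output
# vector per input position plus a final hidden state:
#

demo_encoder = EncoderRNN(input_lang.n_words, hidden_size=16).to(device)
demo_batch = torch.randint(0, input_lang.n_words, (4, MAX_LENGTH), device=device)
demo_enc_outputs, demo_enc_hidden = demo_encoder(demo_batch)
print(demo_enc_outputs.shape)  # torch.Size([4, 10, 16]): (batch, seq_len, hidden)
print(demo_enc_hidden.shape)   # torch.Size([1, 4, 16]): (num_layers, batch, hidden)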

######################################################################
# The Decoder
# -----------
#
# The decoder is another RNN that takes the encoder output vector(s) and
# outputs a sequence of words to create the translation.
#


######################################################################
# Simple Decoder
# ^^^^^^^^^^^^^^
#
# In the simplest seq2seq decoder we use only the last output of the encoder.
# This last output is sometimes called the *context vector* as it encodes
# context from the entire sequence. This context vector is used as the
# initial hidden state of the decoder.
#
# At every step of decoding, the decoder is given an input token and
# hidden state. The initial input token is the start-of-string ``<SOS>``
# token, and the first hidden state is the context vector (the encoder's
# last hidden state).
#
# .. figure:: /_static/img/seq-seq-images/decoder-network.png
#    :alt:
#
#

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None  # We return `None` for consistency in the training loop

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden
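
######################################################################
# A minimal shape check for the decoder (illustrative only; the dummy encoder
# tensors below just stand in for real encoder outputs):
#

demo_decoder = DecoderRNN(hidden_size=16, output_size=output_lang.n_words).to(device)
dummy_enc_outputs = torch.zeros(4, MAX_LENGTH, 16, device=device)
dummy_enc_hidden = torch.zeros(1, 4, 16, device=device)
demo_dec_outputs, _, _ = demo_decoder(dummy_enc_outputs, dummy_enc_hidden)
print(demo_dec_outputs.shape)  # torch.Size([4, 10, output_lang.n_words]): log-probabilities per step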

######################################################################
# I encourage you to train and observe the results of this model, but to
# save space we'll be going straight for the gold and introducing the
# Attention Mechanism.
#


######################################################################
# Attention Decoder
# ^^^^^^^^^^^^^^^^^
#
# If only the context vector is passed between the encoder and decoder,
# that single vector carries the burden of encoding the entire sentence.
#
# Attention allows the decoder network to "focus" on a different part of
# the encoder's outputs for every step of the decoder's own outputs. First
# we calculate a set of *attention weights*. These will be multiplied by
# the encoder output vectors to create a weighted combination. The result
# (called ``context`` in the code) should contain information about
# that specific part of the input sequence, and thus help the decoder
# choose the right output words.
#
# .. figure:: https://i.imgur.com/1152PYf.png
#    :alt:
#
# Calculating the attention weights is done with another feed-forward
# network (the ``BahdanauAttention`` module below), using the decoder's
# hidden state and the encoder's outputs as inputs. Because there are
# sentences of all sizes in the training data, the input sequences are
# padded to a maximum sentence length (``MAX_LENGTH``), and the attention
# weights are computed over that fixed number of encoder outputs.
#
# .. figure:: /_static/img/seq-seq-images/attention-decoder-network.png
#    :alt:
#
#
# Bahdanau attention, also known as additive attention, is a commonly used
# attention mechanism in sequence-to-sequence models, particularly in neural
# machine translation tasks. It was introduced by Bahdanau et al. in their
# paper titled `Neural Machine Translation by Jointly Learning to Align and Translate <https://arxiv.org/pdf/1409.0473.pdf>`__.
# This attention mechanism employs a learned alignment model to compute attention
# scores between the encoder and decoder hidden states. It utilizes a feed-forward
# neural network to calculate alignment scores.
#
# However, there are alternative attention mechanisms available, such as Luong attention,
# which computes attention scores by taking the dot product between the decoder hidden
# state and the encoder hidden states. It does not involve the non-linear transformation
# used in Bahdanau attention.
#
# In this tutorial, we will be using Bahdanau attention. However, it would be a valuable
# exercise to explore modifying the attention mechanism to use Luong attention (a minimal
# sketch of the dot-product scoring appears after the ``AttnDecoderRNN`` class below).

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        embedded = self.dropout(self.embedding(input))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights
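
######################################################################
# As mentioned above, swapping in Luong-style (dot-product) attention is a useful
# exercise. A minimal sketch of what that scoring could look like (the class name
# is illustrative and the module is not used elsewhere in this tutorial):
#

class LuongDotAttention(nn.Module):
    def forward(self, query, keys):
        # query: (batch, 1, hidden), keys: (batch, seq_len, hidden)
        scores = torch.bmm(query, keys.transpose(1, 2))  # (batch, 1, seq_len)
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)               # (batch, 1, hidden)
        return context, weights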


######################################################################
# .. note:: There are other forms of attention that work around the length
#    limitation by using a relative position approach. Read about "local
#    attention" in `Effective Approaches to Attention-based Neural Machine
#    Translation <https://arxiv.org/abs/1508.04025>`__.
#
# Training
# ========
#
# Preparing Training Data
# -----------------------
#
# To train, for each pair we will need an input tensor (indexes of the
# words in the input sentence) and target tensor (indexes of the words in
# the target sentence). While creating these vectors we will append the
# EOS token to both sequences.
#

def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)
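
######################################################################
# For instance (an illustrative check), one training pair becomes a pair of
# index tensors with the EOS token appended:
#

demo_pair = random.choice(pairs)
demo_input, demo_target = tensorsFromPair(demo_pair)
print(demo_pair)
print(demo_input.shape, demo_target.shape)  # each is (1, sentence_length + 1) because of EOS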

def get_dataloader(batch_size):
    input_lang, output_lang, pairs = prepareData('eng', 'fra', True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return input_lang, output_lang, train_dataloader


######################################################################
# Training the Model
# ------------------
#
# To train we run the input sentence through the encoder, and keep track
# of every output and the latest hidden state. Then the decoder is given
# the ``<SOS>`` token as its first input, and the last hidden state of the
# encoder as its first hidden state.
#
# "Teacher forcing" is the concept of using the real target outputs as
# each next input, instead of using the decoder's guess as the next input.
# Using teacher forcing causes it to converge faster but `when the trained
# network is exploited, it may exhibit
# instability <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.378.4095&rep=rep1&type=pdf>`__.
#
# You can observe outputs of teacher-forced networks that read with
# coherent grammar but wander far from the correct translation;
# intuitively, the network has learned to represent the output grammar and
# can "pick up" the meaning once the teacher tells it the first few words,
# but it has not properly learned how to create the sentence from the
# translation in the first place.
#
# Because of the freedom PyTorch's autograd gives us, we can randomly
# choose to use teacher forcing or not with a simple if statement. In the
# implementation below, the decoder applies teacher forcing whenever a
# ``target_tensor`` is passed; a ``teacher_forcing_ratio`` variant that
# chooses randomly per batch is sketched right after this cell.
#
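
######################################################################
# A possible variant (not used by the training loop below): flip a coin per
# batch and pass the target only when teacher forcing is selected. The
# ``teacher_forcing_ratio`` name and helper are illustrative.
#

teacher_forcing_ratio = 0.5

def maybe_teacher_forced(decoder, encoder_outputs, encoder_hidden, target_tensor):
    use_teacher_forcing = random.random() < teacher_forcing_ratio
    return decoder(encoder_outputs, encoder_hidden,
                   target_tensor if use_teacher_forcing else None)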

def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
                decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


######################################################################
# This is a helper function to print time elapsed and estimated time
# remaining given the current time and progress %.
#

import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


######################################################################
# The whole training process looks like this:
#
# - Start a timer
# - Initialize optimizers and criterion
# - Create set of training pairs
# - Start empty losses array for plotting
#
# Then we call ``train_epoch`` once per epoch and occasionally print the
# progress (% of epochs, time so far, estimated time remaining) and the
# average loss.
#

def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
          print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                         epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

######################################################################
# Plotting results
# ----------------
#
# Plotting is done with matplotlib, using the array of loss values
# ``plot_losses`` saved while training.
#

import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
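
######################################################################
# Note that the ``agg`` backend selected above does not open a window, so when
# running this file as a plain script you may want to save the figure instead,
# for example with ``plt.savefig('loss_curve.png')`` (the filename here is just
# an example).
#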


######################################################################
# Evaluation
# ==========
#
# Evaluation is mostly the same as training, but there are no targets so
# we simply feed the decoder's predictions back to itself for each step.
# Every time it predicts a word we add it to the output string, and if it
# predicts the EOS token we stop there. We also store the decoder's
# attention outputs for display later.
#

def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn


######################################################################
# We can evaluate random sentences from the training set and print out the
# input, target, and output to make some subjective quality judgements:
#

def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')


######################################################################
# Training and Evaluating
# =======================
#
# With all these helper functions in place (it looks like extra work, but
# it makes it easier to run multiple experiments) we can actually
# initialize a network and start training.
#
# Remember that the input sentences were heavily filtered. For this small
# dataset we can use relatively small networks of 128 hidden nodes and a
# single GRU layer. After about 40 minutes on a MacBook CPU we'll get some
# reasonable results.
#
# .. note::
#    If you run this notebook you can train, interrupt the kernel,
#    evaluate, and continue training later. Comment out the lines where the
#    encoder and decoder are initialized and run ``train`` again.
#

hidden_size = 128
batch_size = 32

input_lang, output_lang, train_dataloader = get_dataloader(batch_size)

encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)

train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)

######################################################################
#
# Set dropout layers to ``eval`` mode
encoder.eval()
decoder.eval()
evaluateRandomly(encoder, decoder)


######################################################################
# Visualizing Attention
# ---------------------
#
# A useful property of the attention mechanism is its highly interpretable
# outputs. Because it is used to weight specific encoder outputs of the
# input sequence, we can imagine looking at where the network is focused
# most at each time step.
#
# You could simply run ``plt.matshow(attentions)`` to see attention output
# displayed as a matrix. For a better viewing experience we will do the
# extra work of adding axes and labels:
#

def showAttention(input_sentence, output_words, attentions):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])


evaluateAndShowAttention('il n est pas aussi grand que son pere')

evaluateAndShowAttention('je suis trop fatigue pour conduire')

evaluateAndShowAttention('je suis desole si c est une question idiote')

evaluateAndShowAttention('je suis reellement fiere de vous')


######################################################################
# Exercises
# =========
#
# - Try with a different dataset
#
#   - Another language pair
#   - Human → Machine (e.g. IOT commands)
#   - Chat → Response
#   - Question → Answer
#
# - Replace the embeddings with pretrained word embeddings such as ``word2vec`` or
#   ``GloVe`` (see the loading sketch at the end of this file)
# - Try with more layers, more hidden units, and more sentences. Compare
#   the training time and results.
# - If you use a translation file where pairs have two of the same phrase
#   (``I am test \t I am test``), you can use this as an autoencoder. Try
#   this:
#
#   - Train as an autoencoder
#   - Save only the Encoder network
#   - Train a new Decoder for translation from there
#
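
######################################################################
# A sketch for the pretrained-embedding exercise above. The vectors here are
# random placeholders; in practice they would be filled from ``word2vec`` or
# ``GloVe`` files, keyed by ``output_lang.word2index``.
#

pretrained_vectors = torch.randn(output_lang.n_words, hidden_size)
pretrained_embedding = nn.Embedding.from_pretrained(pretrained_vectors, freeze=True)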