# -*- coding: utf-8 -*-

"""
Chatbot Tutorial
================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
"""


######################################################################
# In this tutorial, we explore a fun and interesting use-case of recurrent
# sequence-to-sequence models. We will train a simple chatbot using movie
# scripts from the `Cornell Movie-Dialogs
# Corpus <https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html>`__.
#
# Conversational models are a hot topic in artificial intelligence
# research. Chatbots can be found in a variety of settings, including
# customer service applications and online helpdesks. These bots are often
# powered by retrieval-based models, which output predefined responses to
# questions of certain forms. In a highly restricted domain like a
# company’s IT helpdesk, these models may be sufficient; however, they are
# not robust enough for more general use-cases. Teaching a machine to
# carry out a meaningful conversation with a human in multiple domains is
# a research question that is far from solved. Recently, the deep learning
# boom has allowed for powerful generative models like Google’s `Neural
# Conversational Model <https://arxiv.org/abs/1506.05869>`__, which marks
# a large step towards multi-domain generative conversational models. In
# this tutorial, we will implement this kind of model in PyTorch.
#
# .. figure:: /_static/img/chatbot/bot.png
#    :align: center
#    :alt: bot
#
# .. code-block:: python
#
#    > hello?
#    Bot: hello .
#    > where am I?
#    Bot: you re in a hospital .
#    > who are you?
#    Bot: i m a lawyer .
#    > how are you doing?
#    Bot: i m fine .
#    > are you my friend?
#    Bot: no .
#    > you're under arrest
#    Bot: i m trying to help you !
#    > i'm just kidding
#    Bot: i m sorry .
#    > where are you from?
#    Bot: san francisco .
#    > it's time for me to leave
#    Bot: i know .
#    > goodbye
#    Bot: goodbye .
#
# **Tutorial Highlights**
#
# - Handle loading and preprocessing of the `Cornell Movie-Dialogs
#   Corpus <https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html>`__
#   dataset
# - Implement a sequence-to-sequence model with `Luong attention
#   mechanism(s) <https://arxiv.org/abs/1508.04025>`__
# - Jointly train encoder and decoder models using mini-batches
# - Implement greedy-search decoding module
# - Interact with trained chatbot
#
# **Acknowledgments**
#
# This tutorial borrows code from the following sources:
#
# 1) Yuan-Kuei Wu’s pytorch-chatbot implementation:
#    https://github.com/ywk991112/pytorch-chatbot
#
# 2) Sean Robertson’s practical-pytorch seq2seq-translation example:
#    https://github.com/spro/practical-pytorch/tree/master/seq2seq-translation
#
# 3) FloydHub Cornell Movie Corpus preprocessing code:
#    https://github.com/floydhub/textutil-preprocess-cornell-movie-corpus
#


######################################################################
# Preparations
# ------------
#
# To get started, `download <https://zissou.infosci.cornell.edu/convokit/datasets/movie-corpus/movie-corpus.zip>`__ the Movie-Dialogs Corpus zip file
# and put it in a ``data/`` directory under the current directory; one way
# to do that from Python is sketched below.
#
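# The following is only a convenience sketch using the Python standard
# library (``urllib`` and ``zipfile``); it is not part of the original
# tutorial flow, and the local paths it uses are assumptions chosen to
# match the ``data/movie-corpus`` layout described above.
#
# .. code-block:: python
#
#    import os
#    import urllib.request
#    import zipfile
#
#    url = ("https://zissou.infosci.cornell.edu/convokit/datasets/"
#           "movie-corpus/movie-corpus.zip")
#    os.makedirs("data", exist_ok=True)
#    zip_path = os.path.join("data", "movie-corpus.zip")
#    if not os.path.exists(os.path.join("data", "movie-corpus")):
#        # Download the archive and unpack it into data/movie-corpus/
#        urllib.request.urlretrieve(url, zip_path)
#        with zipfile.ZipFile(zip_path) as zf:
#            zf.extractall("data")
#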
# After that, let’s import some necessities.
#

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import json


# If the current `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__ is available,
# we will use it. Otherwise, we use the CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")


######################################################################
# Load & Preprocess Data
# ----------------------
#
# The next step is to reformat our data file and load the data into
# structures that we can work with.
#
# The `Cornell Movie-Dialogs
# Corpus <https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html>`__
# is a rich dataset of movie character dialog:
#
# - 220,579 conversational exchanges between 10,292 pairs of movie
#   characters
# - 9,035 characters from 617 movies
# - 304,713 total utterances
#
# This dataset is large and diverse, and there is a great variation of
# language formality, time periods, sentiment, etc. Our hope is that this
# diversity makes our model robust to many forms of inputs and queries.
#
# First, we’ll take a look at some lines of our datafile to see the
# original format.
#

corpus_name = "movie-corpus"
corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

printLines(os.path.join(corpus, "utterances.jsonl"))


######################################################################
# Create formatted data file
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For convenience, we'll create a nicely formatted data file in which each line
# contains a tab-separated *query sentence* and a *response sentence* pair.
#
# The following functions facilitate the parsing of the raw
# ``utterances.jsonl`` data file.
#
# - ``loadLinesAndConversations`` splits each line of the file into a dictionary of
#   lines with fields ``lineID``, ``characterID``, and ``text``, and then groups them
#   into conversations with fields ``conversationID``, ``movieID``, and ``lines``.
# - ``extractSentencePairs`` extracts pairs of sentences from
#   conversations
#

# Splits each line of the file to create lines and conversations
def loadLinesAndConversations(fileName):
    lines = {}
    conversations = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            lineJson = json.loads(line)
            # Extract fields for line object
            lineObj = {}
            lineObj["lineID"] = lineJson["id"]
            lineObj["characterID"] = lineJson["speaker"]
            lineObj["text"] = lineJson["text"]
            lines[lineObj['lineID']] = lineObj

            # Extract fields for conversation object
            if lineJson["conversation_id"] not in conversations:
                convObj = {}
                convObj["conversationID"] = lineJson["conversation_id"]
                convObj["movieID"] = lineJson["meta"]["movie_id"]
                convObj["lines"] = [lineObj]
            else:
                convObj = conversations[lineJson["conversation_id"]]
                convObj["lines"].insert(0, lineObj)
            conversations[convObj["conversationID"]] = convObj

    return lines, conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations.values():
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs


######################################################################
# Now we’ll call these functions and create the file. We’ll call it
# ``formatted_movie_lines.txt``.
#

# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize lines dict and conversations dict
lines = {}
conversations = {}
# Load lines and conversations
print("\nProcessing corpus into lines and conversations...")
lines, conversations = loadLinesAndConversations(os.path.join(corpus, "utterances.jsonl"))

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


######################################################################
# Load and trim data
# ~~~~~~~~~~~~~~~~~~
#
# Our next order of business is to create a vocabulary and load
# query/response sentence pairs into memory.
#
# Note that we are dealing with sequences of **words**, which do not have
# an implicit mapping to a discrete numerical space. Thus, we must create
# one by mapping each unique word that we encounter in our dataset to an
# index value.
#
# For this we define a ``Voc`` class, which keeps a mapping from words to
# indexes, a reverse mapping of indexes to words, a count of each word and
# a total word count. The class provides methods for adding a word to the
# vocabulary (``addWord``), adding all words in a sentence
# (``addSentence``) and trimming infrequently seen words (``trim``). More
# on trimming later.
#

# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count default tokens

        for word in keep_words:
            self.addWord(word)


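######################################################################
# As a quick, self-contained illustration of our own (not part of the
# tutorial's data flow), this is how ``Voc`` assigns indexes: the first
# three ids are reserved for the default tokens, so new words start at 3.
#
# .. code-block:: python
#
#    demo_voc = Voc("demo")           # hypothetical throwaway vocabulary
#    demo_voc.addSentence("hello world hello")
#    demo_voc.word2index              # -> {'hello': 3, 'world': 4}
#    demo_voc.word2count              # -> {'hello': 2, 'world': 1}
#
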
######################################################################
# Now we can assemble our vocabulary and query/response sentence pairs.
# Before we are ready to use this data, we must perform some
# preprocessing.
#
# First, we must convert the Unicode strings to ASCII using
# ``unicodeToAscii``. Next, we should convert all letters to lowercase and
# trim all non-letter characters except for basic punctuation
# (``normalizeString``). Finally, to aid in training convergence, we will
# filter out sentences with length greater than the ``MAX_LENGTH``
# threshold (``filterPairs``).
#
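# As a concrete illustration (ours, not from the original text), here is
# what ``normalizeString`` does to one raw line once the function is
# defined below: punctuation gets spaced out, the apostrophe and extra
# whitespace are dropped, and everything is lowercased.
#
# .. code-block:: python
#
#    normalizeString("Aren't   you coming?!")
#    # -> 'aren t you coming ? !'
#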

MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# Returns True if both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using the ``filterPair`` condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs


# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)


######################################################################
# Another tactic that is beneficial to achieving faster convergence during
# training is trimming rarely used words out of our vocabulary. Decreasing
# the feature space will also soften the difficulty of the function that
# the model must learn to approximate. We will do this as a two-step
# process:
#
# 1) Trim words used under ``MIN_COUNT`` threshold using the ``voc.trim``
#    function.
#
# 2) Filter out pairs with trimmed words.
#

MIN_COUNT = 3  # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)


######################################################################
# Prepare Data for Models
# -----------------------
#
# Although we have put a great deal of effort into preparing and massaging our
# data into a nice vocabulary object and list of sentence pairs, our models
# will ultimately expect numerical torch tensors as inputs. One way to
# prepare the processed data for the models can be found in the `seq2seq
# translation
# tutorial <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__.
# In that tutorial, we use a batch size of 1, meaning that all we have to
# do is convert the words in our sentence pairs to their corresponding
# indexes from the vocabulary and feed this to the models.
#
# However, if you’re interested in speeding up training and/or would like
# to leverage GPU parallelization capabilities, you will need to train
# with mini-batches.
#
# Using mini-batches also means that we must be mindful of the variation
# of sentence length in our batches. To accommodate sentences of different
# sizes in the same batch, we will make our batched input tensor of shape
# *(max_length, batch_size)*, where sentences shorter than the
# *max_length* are zero padded after an *EOS_token*.
#
# If we simply convert our English sentences to tensors by converting
# words to their indexes (``indexesFromSentence``) and zero-pad, our
# tensor would have shape *(batch_size, max_length)* and indexing the
# first dimension would return a full sequence across all time-steps.
# However, we need to be able to index our batch along time, and across
# all sequences in the batch. Therefore, we transpose our input batch
# shape to *(max_length, batch_size)*, so that indexing across the first
# dimension returns a time step across all sentences in the batch. We
# handle this transpose implicitly in the ``zeroPadding`` function.
#
# .. figure:: /_static/img/chatbot/seq2seq_batches.png
#    :align: center
#    :alt: batches
#
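# To make the transpose concrete, here is a tiny illustration of our own
# (the numbers stand in for word indexes, ``0`` for *PAD_token*): feeding
# two index sequences of different lengths to ``zeroPadding`` (defined
# below) yields one tuple per *time step* rather than one per sentence.
#
# .. code-block:: python
#
#    zeroPadding([[7, 8, 2], [5, 2]])
#    # -> [(7, 5), (8, 2), (2, 0)]   # shape is (max_length, batch_size)
#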
# The ``inputVar`` function handles the process of converting sentences to
# tensor, ultimately creating a correctly shaped zero-padded tensor. It
# also returns a tensor of ``lengths`` for each of the sequences in the
# batch, which will be passed to our encoder later.
#
# The ``outputVar`` function performs a similar function to ``inputVar``,
# but instead of returning a ``lengths`` tensor, it returns a binary mask
# tensor and a maximum target sentence length. The binary mask tensor has
# the same shape as the output target tensor, but every element that is a
# *PAD_token* is 0 and all others are 1.
#
# ``batch2TrainData`` simply takes a bunch of pairs and returns the input
# and target tensors using the aforementioned functions.
#

def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)


######################################################################
# Define Models
# -------------
#
# Seq2Seq Model
# ~~~~~~~~~~~~~
#
# The brain of our chatbot is a sequence-to-sequence (seq2seq) model. The
# goal of a seq2seq model is to take a variable-length sequence as an
# input, and return a variable-length sequence as an output using a
# fixed-sized model.
#
# `Sutskever et al. <https://arxiv.org/abs/1409.3215>`__ discovered that
# by using two separate recurrent neural nets together, we can accomplish
# this task. One RNN acts as an **encoder**, which encodes a variable
# length input sequence to a fixed-length context vector. In theory, this
# context vector (the final hidden layer of the RNN) will contain semantic
# information about the query sentence that is input to the bot. The
# second RNN is a **decoder**, which takes an input word and the context
# vector, and returns a guess for the next word in the sequence and a
# hidden state to use in the next iteration.
#
# .. figure:: /_static/img/chatbot/seq2seq_ts.png
#    :align: center
#    :alt: model
#
# Image source:
# https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/
#


######################################################################
# Encoder
# ~~~~~~~
#
# The encoder RNN iterates through the input sentence one token
# (e.g. word) at a time, at each time step outputting an “output” vector
# and a “hidden state” vector. The hidden state vector is then passed to
# the next time step, while the output vector is recorded. The encoder
# transforms the context it saw at each point in the sequence into a set
# of points in a high-dimensional space, which the decoder will use to
# generate a meaningful output for the given task.
#
# At the heart of our encoder is a multi-layered Gated Recurrent Unit,
# invented by `Cho et al. <https://arxiv.org/pdf/1406.1078v3.pdf>`__ in
# 2014. We will use a bidirectional variant of the GRU, meaning that there
# are essentially two independent RNNs: one that is fed the input sequence
# in normal sequential order, and one that is fed the input sequence in
# reverse order. The outputs of each network are summed at each time step.
# Using a bidirectional GRU will give us the advantage of encoding both
# past and future contexts.
#
# Bidirectional RNN:
#
# .. figure:: /_static/img/chatbot/RNN-bidirectional.png
#    :width: 70%
#    :align: center
#    :alt: rnn_bidir
#
# Image source: https://colah.github.io/posts/2015-09-NN-Types-FP/
#
# Note that an ``embedding`` layer is used to encode our word indices in
# an arbitrarily sized feature space. For our models, this layer will map
# each word to a feature space of size *hidden_size*. When trained, these
# values should encode semantic similarity between similar meaning words.
#
# Finally, if passing a padded batch of sequences to an RNN module, we
# must pack and unpack padding around the RNN pass using
# ``nn.utils.rnn.pack_padded_sequence`` and
# ``nn.utils.rnn.pad_packed_sequence`` respectively.
#
# **Computation Graph:**
#
#    1) Convert word indexes to embeddings.
#    2) Pack padded batch of sequences for RNN module.
#    3) Forward pass through GRU.
#    4) Unpack padding.
#    5) Sum bidirectional GRU outputs.
#    6) Return output and final hidden state.
#
# **Inputs:**
#
# - ``input_seq``: batch of input sentences; shape=\ *(max_length,
#   batch_size)*
# - ``input_lengths``: list of sentence lengths corresponding to each
#   sentence in the batch; shape=\ *(batch_size)*
# - ``hidden``: hidden state; shape=\ *(n_layers x num_directions,
#   batch_size, hidden_size)*
#
# **Outputs:**
#
# - ``outputs``: output features from the last hidden layer of the GRU
#   (sum of bidirectional outputs); shape=\ *(max_length, batch_size,
#   hidden_size)*
# - ``hidden``: updated hidden state from GRU; shape=\ *(n_layers x
#   num_directions, batch_size, hidden_size)*
#
#

class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size parameters are both set to 'hidden_size'
        # because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden


######################################################################
# Decoder
# ~~~~~~~
#
# The decoder RNN generates the response sentence in a token-by-token
# fashion. It uses the encoder’s context vectors, and internal hidden
# states to generate the next word in the sequence. It continues
# generating words until it outputs an *EOS_token*, representing the end
# of the sentence. A common problem with a vanilla seq2seq decoder is that
# if we rely solely on the context vector to encode the entire input
# sequence’s meaning, it is likely that we will have information loss.
# This is especially the case when dealing with long input sequences,
# greatly limiting the capability of our decoder.
#
# To combat this, `Bahdanau et al. <https://arxiv.org/abs/1409.0473>`__
# created an “attention mechanism” that allows the decoder to pay
# attention to certain parts of the input sequence, rather than using the
# entire fixed context at every step.
#
# At a high level, attention is calculated using the decoder’s current
# hidden state and the encoder’s outputs. The output attention weights
# have the same shape as the input sequence, allowing us to multiply them
# by the encoder outputs, giving us a weighted sum which indicates the
# parts of encoder output to pay attention to. `Sean
# Robertson’s <https://github.com/spro>`__ figure describes this very
# well:
#
# .. figure:: /_static/img/chatbot/attn2.png
#    :align: center
#    :alt: attn2
#
# `Luong et al. <https://arxiv.org/abs/1508.04025>`__ improved upon
# Bahdanau et al.’s groundwork by creating “Global attention”. The key
# difference is that with “Global attention”, we consider all of the
# encoder’s hidden states, as opposed to Bahdanau et al.’s “Local
# attention”, which only considers the encoder’s hidden state from the
# current time step. Another difference is that with “Global attention”,
# we calculate attention weights, or energies, using the hidden state of
# the decoder from the current time step only. Bahdanau et al.’s attention
# calculation requires knowledge of the decoder’s state from the previous
# time step. Also, Luong et al. provides various methods to calculate the
# attention energies between the encoder output and decoder output which
# are called “score functions”:
#
# .. figure:: /_static/img/chatbot/scores.png
#    :width: 60%
#    :align: center
#    :alt: scores
#
# where :math:`h_t` = current target decoder state and :math:`\bar{h}_s` =
# all encoder states.
#
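# For reference, the three score functions shown in that figure (our own
# transcription from Luong et al., using the same notation) are:
#
# .. math::
#
#    \text{score}(h_t, \bar{h}_s) =
#    \begin{cases}
#    h_t^\top \bar{h}_s & \text{dot} \\
#    h_t^\top \mathbf{W}_a \bar{h}_s & \text{general} \\
#    v_a^\top \tanh\left(\mathbf{W}_a [h_t ; \bar{h}_s]\right) & \text{concat}
#    \end{cases}
#
# These correspond one-to-one to the ``dot_score``, ``general_score``, and
# ``concat_score`` methods of the ``Attn`` module defined below.
#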
# Overall, the Global attention mechanism can be summarized by the
# following figure. Note that we will implement the “Attention Layer” as a
# separate ``nn.Module`` called ``Attn``. The output of this module is a
# softmax normalized weights tensor of shape *(batch_size, 1,
# max_length)*.
#
# .. figure:: /_static/img/chatbot/global_attn.png
#    :align: center
#    :width: 60%
#    :alt: global_attn
#

# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)


######################################################################
# Now that we have defined our attention submodule, we can implement the
# actual decoder model. For the decoder, we will manually feed our batch
# one time step at a time. This means that our embedded word tensor and
# GRU output will both have shape *(1, batch_size, hidden_size)*.
#
# **Computation Graph:**
#
#    1) Get embedding of current input word.
#    2) Forward through unidirectional GRU.
#    3) Calculate attention weights from the current GRU output from (2).
#    4) Multiply attention weights to encoder outputs to get new "weighted sum" context vector.
#    5) Concatenate weighted context vector and GRU output using Luong eq. 5.
#    6) Predict next word using Luong eq. 6 (without softmax).
#    7) Return output and final hidden state.
#
# **Inputs:**
#
# - ``input_step``: one time step (one word) of input sequence batch;
#   shape=\ *(1, batch_size)*
# - ``last_hidden``: final hidden layer of GRU; shape=\ *(n_layers x
#   num_directions, batch_size, hidden_size)*
# - ``encoder_outputs``: encoder model’s output; shape=\ *(max_length,
#   batch_size, hidden_size)*
#
# **Outputs:**
#
# - ``output``: softmax normalized tensor giving probabilities of each
#   word being the correct next word in the decoded sequence;
#   shape=\ *(batch_size, voc.num_words)*
# - ``hidden``: final hidden state of GRU; shape=\ *(n_layers x
#   num_directions, batch_size, hidden_size)*
#

class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden


######################################################################
# Define Training Procedure
# -------------------------
#
# Masked loss
# ~~~~~~~~~~~
#
# Since we are dealing with batches of padded sequences, we cannot simply
# consider all elements of the tensor when calculating loss. We define
# ``maskNLLLoss`` to calculate our loss based on our decoder’s output
# tensor, the target tensor, and a binary mask tensor describing the
# padding of the target tensor. This loss function calculates the average
# negative log likelihood of the elements that correspond to a *1* in the
# mask tensor.
#
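# Written out (our notation, not from the original text): if :math:`m_i \in \{0, 1\}`
# marks the non-padded positions at this time step and :math:`p_i` is the
# softmax probability the decoder assigns to the correct target word at
# position :math:`i`, the per-step loss computed below is
#
# .. math::
#
#    \text{loss} = \frac{1}{\sum_i m_i}\sum_{i \,:\, m_i = 1} -\log p_i,
#
# that is, the mean negative log likelihood over unmasked elements only.
#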

def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()


######################################################################
# Single training iteration
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The ``train`` function contains the algorithm for a single training
# iteration (a single batch of inputs).
#
# We will use a couple of clever tricks to aid in convergence:
#
# - The first trick is using **teacher forcing**. This means that at some
#   probability, set by ``teacher_forcing_ratio``, we use the current
#   target word as the decoder’s next input rather than using the
#   decoder’s current guess. This technique acts as training wheels for
#   the decoder, aiding in more efficient training. However, teacher
#   forcing can lead to model instability during inference, as the
#   decoder may not have a sufficient chance to truly craft its own
#   output sequences during training. Thus, we must be mindful of how we
#   are setting the ``teacher_forcing_ratio``, and not be fooled by fast
#   convergence.
#
# - The second trick that we implement is **gradient clipping**. This is
#   a commonly used technique for countering the “exploding gradient”
#   problem. In essence, by clipping or thresholding gradients to a
#   maximum value, we prevent the gradients from growing exponentially
#   and either overflowing (NaN) or overshooting steep cliffs in the cost
#   function.
#
# .. figure:: /_static/img/chatbot/grad_clip.png
#    :align: center
#    :width: 60%
#    :alt: grad_clip
#
# Image source: Goodfellow et al. *Deep Learning*. 2016. https://www.deeplearningbook.org/
#
# **Sequence of Operations:**
#
#    1) Forward pass entire input batch through encoder.
#    2) Initialize decoder inputs as SOS_token, and hidden state as the encoder's final hidden state.
#    3) Forward input batch sequence through decoder one time step at a time.
#    4) If teacher forcing: set next decoder input as the current target; else: set next decoder input as current decoder output.
#    5) Calculate and accumulate loss.
#    6) Perform backpropagation.
#    7) Clip gradients.
#    8) Update encoder and decoder model parameters.
#
#
# .. Note ::
#
#   PyTorch’s RNN modules (``RNN``, ``LSTM``, ``GRU``) can be used like any
#   other non-recurrent layers by simply passing them the entire input
#   sequence (or batch of sequences). We use the ``GRU`` layer like this in
#   the ``encoder``. The reality is that under the hood, there is an
#   iterative process looping over each time step calculating hidden states.
#   Alternatively, you can run these modules one time-step at a time. In
#   this case, we manually loop over the sequences during the training
#   process like we must do for the ``decoder`` model. As long as you
#   maintain the correct conceptual model of these modules, implementing
#   sequential models can be very straightforward.
#
#


def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Lengths for RNN packing should always be on the CPU
    lengths = lengths.to("cpu")

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropagation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals


######################################################################
# Training iterations
# ~~~~~~~~~~~~~~~~~~~
#
# It is finally time to tie the full training procedure together with the
# data. The ``trainIters`` function is responsible for running
# ``n_iterations`` of training given the passed models, optimizers, data,
# etc. This function is quite self explanatory, as we have done the heavy
# lifting with the ``train`` function.
#
# One thing to note is that when we save our model, we save a tarball
# containing the encoder and decoder ``state_dicts`` (parameters), the
# optimizers’ ``state_dicts``, the loss, the iteration, etc. Saving the model
# in this way will give us the ultimate flexibility with the checkpoint.
# After loading a checkpoint, we will be able to use the model parameters
# to run inference, or we can continue training right where we left off.
#

def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                        for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))


######################################################################
# Define Evaluation
# -----------------
#
# After training a model, we want to be able to talk to the bot ourselves.
# First, we must define how we want the model to decode the encoded input.
#
# Greedy decoding
# ~~~~~~~~~~~~~~~
#
# Greedy decoding is the decoding method that we use during training when
# we are **NOT** using teacher forcing. In other words, for each time
# step, we simply choose the word from ``decoder_output`` with the highest
# softmax value. This decoding method is optimal on a single time-step
# level.
#
# To facilitate the greedy decoding operation, we define a
# ``GreedySearchDecoder`` class. When run, an object of this class takes
# an input sequence (``input_seq``) of shape *(input_seq length, 1)*, a
# scalar input length (``input_length``) tensor, and a ``max_length`` to
# bound the response sentence length. The input sentence is evaluated
# using the following computational graph:
#
# **Computation Graph:**
#
#    1) Forward input through encoder model.
#    2) Prepare encoder's final hidden layer to be first hidden input to the decoder.
#    3) Initialize decoder's first input as SOS_token.
#    4) Initialize tensors to append decoded words to.
#    5) Iteratively decode one word token at a time:
#        a) Forward pass through decoder.
#        b) Obtain most likely word token and its softmax score.
#        c) Record token and score.
#        d) Prepare current token to be next decoder input.
#    6) Return collections of word tokens and scores.
#

class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores


######################################################################
# Evaluate my text
# ~~~~~~~~~~~~~~~~
#
# Now that we have our decoding method defined, we can write functions for
# evaluating a string input sentence. The ``evaluate`` function manages
# the low-level process of handling the input sentence. We first format
# the sentence as an input batch of word indexes with *batch_size==1*. We
# do this by converting the words of the sentence to their corresponding
# indexes, and transposing the dimensions to prepare the tensor for our
# models. We also create a ``lengths`` tensor which contains the length of
# our input sentence. In this case, ``lengths`` is scalar because we are
# only evaluating one sentence at a time (batch_size==1). Next, we obtain
# the decoded response sentence tensor using our ``GreedySearchDecoder``
# object (``searcher``). Finally, we convert the response’s indexes to
# words and return the list of decoded words.
#
# ``evaluateInput`` acts as the user interface for our chatbot. When
# called, an input text field will spawn in which we can enter our query
# sentence. After typing our input sentence and pressing *Enter*, our text
# is normalized in the same way as our training data, and is ultimately
# fed to the ``evaluate`` function to obtain a decoded output sentence. We
# loop this process, so we can keep chatting with our bot until we enter
# either “q” or “quit”.
#
# Finally, if a sentence is entered that contains a word that is not in
# the vocabulary, we handle this gracefully by printing an error message
# and prompting the user to enter another sentence.
#

def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to("cpu")
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")


######################################################################
# Run Model
# ---------
#
# Finally, it is time to run our model!
#
# Regardless of whether we want to train or test the chatbot model, we
# must initialize the individual encoder and decoder models. In the
# following block, we set our desired configurations, choose to start from
# scratch or set a checkpoint to load from, and build and initialize the
# models. Feel free to play with different model configurations to
# optimize performance.
#

# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#``attn_model = 'general'``
#``attn_model = 'concat'``
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000

#############################################################
# Sample code to load from a checkpoint:
#
# .. code-block:: python
#
#    loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                                '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                                '{}_checkpoint.tar'.format(checkpoint_iter))

# Load model if a ``loadFilename`` is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')


######################################################################
# Run Training
# ~~~~~~~~~~~~
#
# Run the following block if you want to train the model.
#
# First we set training parameters, then we initialize our optimizers, and
# finally we call the ``trainIters`` function to run our training
# iterations.
#

# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# If you have an accelerator, move any loaded optimizer state tensors to it
for state in encoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

for state in decoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)


######################################################################
# Run Evaluation
# ~~~~~~~~~~~~~~
#
# To chat with your model, run the following block.
#

# Set dropout layers to ``eval`` mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
# evaluateInput(encoder, decoder, searcher, voc)


######################################################################
# Conclusion
# ----------
#
# That’s all for this one, folks. Congratulations, you now know the
# fundamentals of building a generative chatbot model! If you’re
# interested, you can try tailoring the chatbot’s behavior by tweaking the
# model and training parameters and customizing the data that you train
# the model on.
#
# Check out the other tutorials for more cool deep learning applications
# in PyTorch!
#