GitHub Repository: deeplearningzerotoall/PyTorch
Path: blob/master/RNN/6-seq2seq.py
# main reference
# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

import random
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

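# parallel corpus: each line holds an English sentence and its Korean translation,
# separated by a tab (preprocess splits on "\t")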
raw = ["I feel hungry. 나는 배가 고프다.",
13
"Pytorch is very easy. 파이토치는 매우 쉽다.",
14
"Pytorch is a framework for deep learning. 파이토치는 딥러닝을 위한 프레임워크이다.",
15
"Pytorch is very clear to use. 파이토치는 사용하기 매우 직관적이다."]
16
17
SOS_token = 0
18
EOS_token = 1
19
20
21
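# Vocab maps words to indices and back, reserving indices 0/1 for <SOS>/<EOS>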
class Vocab:
    def __init__(self):
        self.vocab2index = {"<SOS>": SOS_token, "<EOS>": EOS_token}
        self.index2vocab = {SOS_token: "<SOS>", EOS_token: "<EOS>"}
        self.vocab_count = {}
        self.n_vocab = len(self.vocab2index)

    def add_vocab(self, sentence):
        for word in sentence.split(" "):
            if word not in self.vocab2index:
                self.vocab2index[word] = self.n_vocab
                self.vocab_count[word] = 1
                self.index2vocab[self.n_vocab] = word
                self.n_vocab += 1
            else:
                self.vocab_count[word] += 1


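# keep only pairs whose source and target word counts are under the given limits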
def filter_pair(pair, source_max_length, target_max_length):
    return len(pair[0].split(" ")) < source_max_length and len(pair[1].split(" ")) < target_max_length


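# split each corpus line into a (source, target) pair, filter by length,
# and build separate source and target vocabularies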
def preprocess(corpus, source_max_length, target_max_length):
    print("reading corpus...")
    pairs = []
    for line in corpus:
        pairs.append([s for s in line.strip().lower().split("\t")])
    print("Read {} sentence pairs".format(len(pairs)))

    pairs = [pair for pair in pairs if filter_pair(pair, source_max_length, target_max_length)]
    print("Trimmed to {} sentence pairs".format(len(pairs)))

    source_vocab = Vocab()
    target_vocab = Vocab()

    print("Counting words...")
    for pair in pairs:
        source_vocab.add_vocab(pair[0])
        target_vocab.add_vocab(pair[1])
    print("source vocab size =", source_vocab.n_vocab)
    print("target vocab size =", target_vocab.n_vocab)

    return pairs, source_vocab, target_vocab


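# Encoder: embeds one source token at a time and updates a single-layer GRU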
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, x, hidden):
        x = self.embedding(x).view(1, 1, -1)
        x, hidden = self.gru(x, hidden)
        return x, hidden


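# Decoder: embeds the previous target token, updates a GRU, and projects the
# hidden state to log-probabilities over the target vocabulary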
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden):
        x = self.embedding(x).view(1, 1, -1)
        x, hidden = self.gru(x, hidden)
        x = self.softmax(self.out(x[0]))
        return x, hidden


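# convert a sentence into a (seq_len, 1) tensor of word indices ending with <EOS>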
def tensorize(vocab, sentence):
    indexes = [vocab.vocab2index[word] for word in sentence.split(" ")]
    indexes.append(vocab.vocab2index["<EOS>"])
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


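# train the seq2seq model with teacher forcing: the ground-truth target token is
# fed as the next decoder input instead of the decoder's own prediction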
def train(pairs, source_vocab, target_vocab, encoder, decoder, n_iter, print_every=1000, learning_rate=0.01):
    loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    training_batch = [random.choice(pairs) for _ in range(n_iter)]
    training_source = [tensorize(source_vocab, pair[0]) for pair in training_batch]
    training_target = [tensorize(target_vocab, pair[1]) for pair in training_batch]

    criterion = nn.NLLLoss()

    for i in range(1, n_iter + 1):
        source_tensor = training_source[i - 1]
        target_tensor = training_target[i - 1]

        encoder_hidden = torch.zeros([1, 1, encoder.hidden_size]).to(device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        source_length = source_tensor.size(0)
        target_length = target_tensor.size(0)

        loss = 0

        # encode the source sentence one token at a time
        for enc_input in range(source_length):
            _, encoder_hidden = encoder(source_tensor[enc_input], encoder_hidden)

        # the final encoder hidden state initializes the decoder
        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden

        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # teacher forcing

        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        loss_iter = loss.item() / target_length
        loss_total += loss_iter

        if i % print_every == 0:
            loss_avg = loss_total / print_every
            loss_total = 0
            print("[{} - {}%] loss = {:05.4f}".format(i, i / n_iter * 100, loss_avg))


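# greedy decoding over the given pairs: print source (>), reference (=), prediction (<)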
def evaluate(pairs, source_vocab, target_vocab, encoder, decoder, target_max_length):
    for pair in pairs:
        print(">", pair[0])
        print("=", pair[1])
        source_tensor = tensorize(source_vocab, pair[0])
        source_length = source_tensor.size()[0]
        encoder_hidden = torch.zeros([1, 1, encoder.hidden_size]).to(device)

        for ei in range(source_length):
            _, encoder_hidden = encoder(source_tensor[ei], encoder_hidden)

        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden
        decoded_words = []

        # at each step feed the most probable token back in as the next input
        for di in range(target_max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            _, top_index = decoder_output.data.topk(1)
            if top_index.item() == EOS_token:
                decoded_words.append("<EOS>")
                break
            else:
                decoded_words.append(target_vocab.index2vocab[top_index.item()])

            decoder_input = top_index.squeeze().detach()

        predict_sentence = " ".join(decoded_words)
        print("<", predict_sentence)
        print("")


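# build vocabularies and models, train for 5000 iterations, then decode the training pairs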
SOURCE_MAX_LENGTH = 10
TARGET_MAX_LENGTH = 12
load_pairs, load_source_vocab, load_target_vocab = preprocess(raw, SOURCE_MAX_LENGTH, TARGET_MAX_LENGTH)
print(random.choice(load_pairs))

enc_hidden_size = 16
dec_hidden_size = enc_hidden_size
enc = Encoder(load_source_vocab.n_vocab, enc_hidden_size).to(device)
dec = Decoder(dec_hidden_size, load_target_vocab.n_vocab).to(device)

train(load_pairs, load_source_vocab, load_target_vocab, enc, dec, 5000, print_every=1000)
evaluate(load_pairs, load_source_vocab, load_target_vocab, enc, dec, TARGET_MAX_LENGTH)