GitHub Repository: deeplearningzerotoall/PyTorch
Path: blob/master/RNN/season1_refactored/RNN_intro_4.py
# Lab 12 RNN
import torch
import torch.nn as nn

torch.manual_seed(777)  # reproducibility

idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach the model: "hihell" -> "ihello"
x_data = [[0, 1, 0, 2, 3, 3]]  # hihell
y_data = [1, 0, 2, 3, 3, 4]    # ihello

# We have a single batch of one sample, so we build the tensors once,
# outside the training loop
inputs = torch.LongTensor(x_data)  # shape: (batch, seq_len) == (1, 6)
labels = torch.LongTensor(y_data)  # shape: (seq_len,) == (6,)

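# Optional sanity check (not in the original lab code): decode the index
# lists back to strings with idx2char to confirm the data mapping.
print('input :', ''.join(idx2char[i] for i in x_data[0]))  # hihell
print('target:', ''.join(idx2char[i] for i in y_data))     # ihello
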
num_classes = 5      # vocabulary size
input_size = 5       # number of distinct input tokens
embedding_size = 10  # dimension of the embedding vectors
hidden_size = 5      # output size of the RNN; 5 so each step directly scores one class
batch_size = 1       # one sentence
sequence_length = 6  # |ihello| == 6
num_layers = 1       # one-layer RNN


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.RNN(input_size=embedding_size,
                          hidden_size=hidden_size,
                          batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initialize the hidden state (a plain nn.RNN has no cell state):
        # (num_layers * num_directions, batch, hidden_size)
        h_0 = torch.zeros(num_layers, x.size(0), hidden_size)

        # Embed the token indices:
        # (batch, seq_len) -> (batch, seq_len, embedding_size)
        emb = self.embedding(x)
        emb = emb.view(batch_size, sequence_length, -1)

        # Propagate the embeddings through the RNN
        # input: (batch, seq_len, embedding_size)
        # h_0:   (num_layers * num_directions, batch, hidden_size)
        out, _ = self.rnn(emb, h_0)

        # Flatten to (batch * seq_len, hidden_size) before the linear layer.
        # The original viewed with num_classes here, which only happened to
        # work because hidden_size == num_classes in this lab.
        return self.fc(out.view(-1, hidden_size))


# Instantiate the RNN model
model = Model()
print(model)

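# Optional shape check (not in the original lab code): one forward pass on the
# untrained model should produce (sequence_length, num_classes) logits.
with torch.no_grad():
    assert model(inputs).shape == (sequence_length, num_classes)
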
# Set up the loss and the optimizer
# CrossEntropyLoss = LogSoftmax + NLLLoss
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
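
# Brief demonstration (not in the original lab code) of the comment above:
# CrossEntropyLoss(logits, target) equals NLLLoss(log_softmax(logits), target).
# _logits and _target are throwaway illustration values, not lab data.
_logits = torch.randn(3, num_classes)
_target = torch.tensor([0, 2, 4])
assert torch.allclose(criterion(_logits, _target),
                      nn.NLLLoss()(torch.log_softmax(_logits, dim=1), _target))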

# Train the model
for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Decode the predicted class index at each step back into a character
    # (.data is deprecated; use .detach() instead)
    _, idx = outputs.max(1)
    result_str = ''.join(idx2char[c] for c in idx.detach().numpy())
    print(f'epoch: {epoch + 1}, loss: {loss.item():1.3f}')
    print("Predicted string: ", result_str)

print("Learning finished!")
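
# Final check after training (not in the original lab code): the trained model
# should now reproduce the target string "ihello" exactly.
with torch.no_grad():
    _, idx = model(inputs).max(1)
    print("Final prediction:", ''.join(idx2char[c] for c in idx.tolist()))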