Path: blob/master/RNN/season1_refactored/RNN_intro_2.py
629 views
# Lab 12 RNN
#
# Trains a single-layer vanilla RNN, one character at a time, to map the
# string "hihell" to "ihello" (i.e. predict the next character).
import torch
import torch.nn as nn

torch.manual_seed(777)  # reproducibility

#            0    1    2    3    4
idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach hihell -> ihello
x_data = [0, 1, 0, 2, 3, 3]  # hihell
one_hot_lookup = [
    [1, 0, 0, 0, 0],  # 0
    [0, 1, 0, 0, 0],  # 1
    [0, 0, 1, 0, 0],  # 2
    [0, 0, 0, 1, 0],  # 3
    [0, 0, 0, 0, 1],  # 4
]

y_data = [1, 0, 2, 3, 3, 4]  # ihello
x_one_hot = [one_hot_lookup[x] for x in x_data]

# As we have one batch of samples, we convert them to tensors only once.
inputs = torch.tensor(x_one_hot, dtype=torch.float32)
labels = torch.tensor(y_data, dtype=torch.long)

num_classes = 5
input_size = 5  # one-hot size
hidden_size = 5  # output from the RNN. 5 to directly predict one-hot
batch_size = 1  # one sentence
sequence_length = 1  # One by one
num_layers = 1  # one-layer rnn


class Model(nn.Module):
    """Vanilla RNN whose hidden activations are used directly as class scores.

    The layer sizes are read from the module-level constants ``input_size``,
    ``hidden_size``, ``num_layers``, ``batch_size``, ``sequence_length`` and
    ``num_classes``.
    """

    def __init__(self) -> None:
        super().__init__()
        # num_layers is passed explicitly so the RNN always agrees with the
        # shape of the hidden state produced by init_hidden().
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          batch_first=True)

    def forward(self, hidden: torch.Tensor, x: torch.Tensor):
        """Run one step of the RNN.

        Args:
            hidden: previous hidden state,
                shape (num_layers * num_directions, batch, hidden_size).
            x: input holding batch_size * sequence_length * input_size values;
                it is reshaped to (batch, seq_len, input_size).

        Returns:
            A ``(scores, hidden)`` tuple: ``scores`` is the RNN output
            flattened to (-1, num_classes), ``hidden`` is the new hidden state.
        """
        # Reshape input (batch first)
        x = x.view(batch_size, sequence_length, input_size)

        # Propagate input through RNN
        # Input: (batch, seq_len, input_size)
        # hidden: (num_layers * num_directions, batch, hidden_size)
        out, hidden = self.rnn(x, hidden)
        return out.view(-1, num_classes), hidden

    def init_hidden(self) -> torch.Tensor:
        """Return an all-zero initial hidden state.

        A vanilla RNN has only a hidden state (no cell state); its shape is
        (num_layers * num_directions, batch, hidden_size).
        """
        return torch.zeros(num_layers, batch_size, hidden_size)


# Instantiate RNN model
model = Model()
print(model)

# Set loss and optimizer function
# CrossEntropyLoss = LogSoftmax + NLLLoss
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

print(inputs.size(), labels.size())
# Train the model
for epoch in range(100):
    optimizer.zero_grad()
    loss = 0
    hidden = model.init_hidden()

    predicted_chars = []
    # Feed the sequence one character at a time, carrying the hidden state
    # forward and accumulating the per-step loss.
    for step_input, label in zip(inputs, labels):
        output, hidden = model(hidden, step_input)
        predicted_chars.append(idx2char[output.argmax(dim=1).item()])
        loss += criterion(output, label.reshape(-1))

    predicted = ''.join(predicted_chars)
    print(f'predicted string: {predicted}, '
          f'epoch: {epoch + 1}, loss: {loss.item():1.3f}')

    # Back-propagate through all time steps of the sequence at once.
    loss.backward()
    optimizer.step()

print("Learning finished!")