# Source: blob/master/RNN/season1_refactored/RNN_intro_3.py (631 views)
# Lab 12 RNN: teach a single-layer vanilla RNN to map the character
# sequence "hihell" -> "ihello", emitting one class-score vector per step.
import torch
import torch.nn as nn

torch.manual_seed(777)  # reproducibility

# Index -> character vocabulary used to decode predictions.
idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach hihell -> ihello
x_data = [[0, 1, 0, 2, 3, 3]]    # "hihell" as character indices
x_one_hot = [[[1, 0, 0, 0, 0],   # h 0
              [0, 1, 0, 0, 0],   # i 1
              [1, 0, 0, 0, 0],   # h 0
              [0, 0, 1, 0, 0],   # e 2
              [0, 0, 0, 1, 0],   # l 3
              [0, 0, 0, 1, 0]]]  # l 3

y_data = [1, 0, 2, 3, 3, 4]  # "ihello" as character indices

# As we have one batch of samples, we build the tensors only once.
inputs = torch.Tensor(x_one_hot)   # (batch=1, seq_len=6, input_size=5)
labels = torch.LongTensor(y_data)  # (seq_len=6,) class targets

num_classes = 5
input_size = 5       # one-hot size
hidden_size = 5      # output size of the RNN; 5 to predict one-hot directly
batch_size = 1       # one sentence
sequence_length = 6  # |ihello| == 6
num_layers = 1       # one-layer rnn


class RNN(nn.Module):
    """Character-level RNN that returns one score vector per time step.

    Output of ``forward`` is flattened to ``(batch * seq_len, num_classes)``
    so it can be fed directly to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, num_classes, input_size, hidden_size, num_layers,
                 sequence_length=6):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = num_layers
        self.input_size = input_size
        self.hidden_size = hidden_size
        # FIX: take sequence_length as a (defaulted) parameter instead of
        # silently reading the module-level global of the same name.
        self.sequence_length = sequence_length

        # FIX: forward num_layers to nn.RNN. The original dropped it, so a
        # model built with num_layers > 1 stayed single-layer while forward()
        # allocated a multi-layer h_0 — a guaranteed shape mismatch.
        self.rnn = nn.RNN(input_size=self.input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)

    def forward(self, x):
        """Run the RNN over x of shape (batch, seq_len, input_size).

        Returns scores of shape (batch * seq_len, num_classes).
        """
        # Initial hidden state: (num_layers * num_directions, batch, hidden).
        h_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)

        # out: (batch, seq_len, hidden_size) because batch_first=True.
        out, _ = self.rnn(x, h_0)
        return out.view(-1, self.num_classes)


# Instantiate RNN model
rnn = RNN(num_classes, input_size, hidden_size, num_layers)
print(rnn)

# Set loss and optimizer function
# CrossEntropyLoss = LogSoftmax + NLLLoss
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.1)

# Train the model
for epoch in range(100):
    outputs = rnn(inputs)
    optimizer.zero_grad()
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Decode the argmax class at each step back into characters.
    _, idx = outputs.max(1)
    idx = idx.data.numpy()
    result_str = ''.join(idx2char[c] for c in idx.squeeze())
    print(f'epoch: {epoch + 1}, loss: {loss.item():1.3f}')
    print("Predicted string: ", result_str)

print("Learning finished!")