Path: blob/master/RNN/season1_refactored/RNN_intro_4.py
631 views
# Lab 12: character-level RNN that learns to map "hihell" -> "ihello".
import torch
import torch.nn as nn

torch.manual_seed(777)  # reproducibility

# Vocabulary: index -> character.
idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach hihell -> ihello (as indices into idx2char).
x_data = [[0, 1, 0, 2, 3, 3]]  # hihell
y_data = [1, 0, 2, 3, 3, 4]    # ihello

# As we have one batch of samples, we will change them to tensors only once.
inputs = torch.LongTensor(x_data)   # (batch=1, seq_len=6)
labels = torch.LongTensor(y_data)   # (seq_len=6,)

num_classes = 5       # output vocabulary size
input_size = 5        # input vocabulary size (rows of the embedding table)
embedding_size = 10   # embedding size
hidden_size = 5       # output from the RNN; 5 to directly predict one-hot
batch_size = 1        # one sentence
sequence_length = 6   # |ihello| == 6
num_layers = 1        # one-layer rnn


class Model(nn.Module):
    """Embedding -> single-layer RNN -> Linear, applied to every time step."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.RNN(input_size=embedding_size,
                          hidden_size=hidden_size,
                          batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return per-step logits of shape (batch * seq_len, num_classes).

        x: LongTensor of character indices, shape (batch, seq_len).
        """
        # Initial hidden state: (num_layers * num_directions, batch, hidden_size).
        h_0 = torch.zeros(num_layers, x.size(0), hidden_size)

        # (batch, seq_len) -> (batch, seq_len, embedding_size).  The original
        # re-view to (batch_size, sequence_length, -1) was a no-op that
        # hard-coded 1 x 6; dropping it lets any batch/seq length through.
        emb = self.embedding(x)

        # out: (batch, seq_len, hidden_size)
        out, _ = self.rnn(emb, h_0)

        # BUG FIX: flatten by hidden_size (the Linear layer's in_features),
        # not num_classes.  The original `out.view(-1, num_classes)` only
        # worked because hidden_size == num_classes == 5 here.
        return self.fc(out.reshape(-1, hidden_size))


# Instantiate RNN model
model = Model()
print(model)

# Set loss and optimizer function
# CrossEntropyLoss = LogSoftmax + NLLLoss, so the model emits raw logits.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Train the model: one full-batch forward/backward step per epoch.
for epoch in range(100):
    outputs = model(inputs)          # (seq_len, num_classes) logits
    optimizer.zero_grad()
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Greedy-decode the current prediction for progress display.
    _, idx = outputs.max(1)
    idx = idx.data.numpy()
    result_str = ''.join(idx2char[c] for c in idx.squeeze())
    print(f'epoch: {epoch + 1}, loss: {loss.item():1.3f}')
    print("Predicted string: ", result_str)

print("Learning finished!")