# Path: blob/master/RNN/season1_refactored/RNN_intro_1.py
# 631 views
"""Shape walkthrough for ``torch.nn.RNN`` on one-hot encoded characters.

The script feeds one-hot encodings of the characters 'h', 'e', 'l', 'o'
through a single-layer RNN (input_size=4, hidden_size=2) and prints the
tensor sizes for:

1. one time step at a time (batch_first=True),
2. a whole sequence in a single call (batch_first=True),
3. a batch of three sequences (batch_first=True),
4. the same batch in (seq_len, batch, input_size) layout (batch_first=False).
"""
import torch
import torch.nn as nn

# Random seed to make results deterministic
torch.manual_seed(0)

# One-hot encoding for each distinct char in 'hello'
h = [1, 0, 0, 0]
e = [0, 1, 0, 0]
l = [0, 0, 1, 0]  # noqa: E741 -- named after the character it encodes
o = [0, 0, 0, 1]

# Single-layer RNN: input_size=4 (one-hot width) -> hidden_size=2.
# For a single-layer, unidirectional RNN the per-step output size equals
# hidden_size.
cell = nn.RNN(input_size=4, hidden_size=2, batch_first=True)

# Initial hidden state: (num_layers * num_directions, batch, hidden_size).
# This layout is the same whether batch_first=True or False; num_directions
# is 2 when the RNN is bidirectional, otherwise 1.
h0 = torch.randn(1, 1, 2)
hidden = h0

# 'hello' as a (seq_len=5, input_size=4) float tensor. An explicit dtype is
# needed because torch.tensor would otherwise infer int64 from the int lists.
inputs = torch.tensor([h, e, l, l, o], dtype=torch.float32)

# Propagate the sequence one time step at a time, carrying the hidden state.
step_outputs = []
for char_vec in inputs:
    # Input: (batch, seq_len, input_size) = (1, 1, 4) when batch_first=True
    step_input = char_vec.view(1, 1, -1)
    out, hidden = cell(step_input, hidden)
    step_outputs.append(out)
    print("one input size", step_input.size(), "out size", out.size())

# We can do the whole sequence at once.
# Input: (batch, seq_len, input_size) = (1, 5, 4) when batch_first=True
inputs = inputs.unsqueeze(0)
# Start from the same initial hidden state as the step-by-step loop so both
# passes compute the same thing; reusing the post-loop hidden state would
# continue the sequence instead of repeating it.
seq_out, hidden = cell(inputs, h0)
print("sequence input size", inputs.size(), "out size", seq_out.size())
print("matches step-by-step:",
      torch.allclose(seq_out, torch.cat(step_outputs, dim=1), atol=1e-6))


# Initial hidden state for a batch of 3 sequences:
# (num_layers * num_directions, batch, hidden_size), regardless of batch_first.
hidden = torch.randn(1, 3, 2)

# One batch holding 3 sequences of length 5: 'hello', 'eolll', 'lleel'.
# Shape (batch=3, seq_len=5, input_size=4), i.e. a rank-3 tensor.
inputs = torch.tensor([[h, e, l, l, o],
                       [e, o, l, l, l],
                       [l, l, e, e, l]], dtype=torch.float32)

# Propagate input through RNN
# Input: (batch, seq_len, input_size) when batch_first=True -> B x S x I
out, hidden = cell(inputs, hidden)
print("batch input size", inputs.size(), "out size", out.size())


# A new RNN with the default batch_first=False. It is freshly (randomly)
# initialized, so its outputs are not comparable to the previous cell's.
cell = nn.RNN(input_size=4, hidden_size=2)

# Swap dims 0 and 1: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
inputs = inputs.transpose(dim0=0, dim1=1)
# Propagate input through RNN
# Input: (seq_len, batch, input_size) when batch_first=False (default) -> S x B x I
# The (1, 3, 2) hidden state from the previous call is a valid initial state
# here because the hidden-state layout does not depend on batch_first.
out, hidden = cell(inputs, hidden)
print("batch input size", inputs.size(), "out size", out.size())