Path: blob/master/RNN/7-packed_sequence.py
618 views
"""Demonstrate padded vs. packed variable-length sequences with a PyTorch RNN.

The script builds a character vocabulary from a few sample strings, converts
each string to a tensor of character indices, and then walks through the four
helpers in ``torch.nn.utils.rnn``:

* ``pad_sequence``          list of tensors   -> padded tensor
* ``pack_sequence``         list of tensors   -> PackedSequence
* ``pad_packed_sequence``   PackedSequence    -> padded tensor + lengths
* ``pack_padded_sequence``  padded tensor + lengths -> PackedSequence

Both representations are fed through the same ``torch.nn.RNN`` and the
resulting shapes are printed.
"""
import torch
import numpy as np
from torch.nn.utils.rnn import (
    pad_sequence,
    pack_sequence,
    pack_padded_sequence,
    pad_packed_sequence,
)

# Random word from random word generator
data = ['hello world',
        'midnight',
        'calculation',
        'path',
        'short circuit']

# Make dictionary
# '<pad>' is placed first so that it gets index 0, which is also the default
# padding value used by pad_sequence below. The characters are sorted because
# plain set iteration order for strings changes between interpreter runs,
# which would make char2idx (and every printed tensor) non-reproducible.
char_set = ['<pad>'] + sorted(set(char for seq in data for char in seq))
char2idx = {char: idx for idx, char in enumerate(char_set)}  # Construct character to index dictionary
print('char_set:', char_set)
print('char_set length:', len(char_set))

# Convert character to index and
# make list of 1-D int64 tensors (one per string, each of a different length)
X = [torch.tensor([char2idx[char] for char in seq], dtype=torch.long) for seq in data]

# Check converted result
for sequence in X:
    print(sequence)

# Make length list (will be used later in 'pack_padded_sequence' function)
lengths = [len(seq) for seq in X]
print('lengths:', lengths)

# Make a Tensor of shape (Batch x Maximum_Sequence_Length)
# pad_sequence returns a new tensor; X itself is left untouched.
padded_sequence = pad_sequence(X, batch_first=True)
print(padded_sequence)
print(padded_sequence.shape)

# Sort by descending lengths
# pack_sequence is called with its default enforce_sorted=True, which expects
# the sequences ordered from longest to shortest.
sorted_idx = sorted(range(len(lengths)), key=lengths.__getitem__, reverse=True)
sorted_X = [X[idx] for idx in sorted_idx]

# Check converted result
for sequence in sorted_X:
    print(sequence)

packed_sequence = pack_sequence(sorted_X)
print(packed_sequence)

# one-hot embedding using the padded tensor
# Indexing the identity matrix with an index tensor yields one one-hot row per
# index. Padded slots hold index 0, so they become the one-hot row for '<pad>'
# (not an all-zero vector).
eye = torch.eye(len(char_set))  # Identity matrix of shape (len(char_set), len(char_set))
embedded_tensor = eye[padded_sequence]
print(embedded_tensor.shape)  # shape: (Batch_size, max_sequence_length, number_of_input_tokens)

# one-hot embedding using PackedSequence
embedded_packed_seq = pack_sequence([eye[X[idx]] for idx in sorted_idx])
print(embedded_packed_seq.data.shape)  # shape: (sum of all lengths, number_of_input_tokens)

# declare RNN
rnn = torch.nn.RNN(input_size=len(char_set), hidden_size=30, batch_first=True)

# Try out the padded tensor
rnn_output, hidden = rnn(embedded_tensor)
print(rnn_output.shape)  # shape: (batch_size, max_seq_length, hidden_size)
print(hidden.shape)  # shape: (num_layers * num_directions, batch_size, hidden_size)

# Try out PackedSequence
# With a PackedSequence input the output is a PackedSequence as well, so the
# per-step outputs live in its flat `.data` field. `hidden` is a plain tensor
# whose batch rows follow the order of sorted_X.
rnn_output, hidden = rnn(embedded_packed_seq)
print(rnn_output.data.shape)
print(hidden.shape)

# Try out pad_packed_sequence
# Unpacking fills the padded slots with zeros, unlike eye[padded_sequence]
# above, which fills them with the '<pad>' one-hot row.
unpacked_sequence, seq_lengths = pad_packed_sequence(embedded_packed_seq, batch_first=True)
print(unpacked_sequence.shape)
print(seq_lengths)

# Construct embedded_padded_sequence
embedded_padded_sequence = eye[pad_sequence(sorted_X, batch_first=True)]
print(embedded_padded_sequence.shape)

# Try out pack_padded_sequence
# The lengths must be given in the same (descending) order as the rows of
# embedded_padded_sequence, hence the sort.
sorted_lengths = sorted(lengths, reverse=True)
new_packed_sequence = pack_padded_sequence(embedded_padded_sequence, sorted_lengths, batch_first=True)
print(new_packed_sequence.data.shape)
print(new_packed_sequence.batch_sizes)