GitHub Repository: deeplearningzerotoall/PyTorch
Path: blob/master/RNN/3-charseq.ipynb
Kernel: Python 3
import torch
import torch.optim as optim
import numpy as np
# Random seed to make results deterministic and reproducible
torch.manual_seed(0)
<torch._C.Generator at 0x1118cdb50>
sample = " if you want you"
# make dictionary
char_set = list(set(sample))
char_dic = {c: i for i, c in enumerate(char_set)}
print(char_dic)
{'y': 0, 't': 1, 'o': 2, 'i': 3, 'f': 4, 'w': 5, 'u': 6, 'a': 7, 'n': 8, ' ': 9}
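The order of `list(set(sample))` can change between Python runs, because string hashing is randomized by default, so the indices printed above may differ on another machine. A minimal sketch, assuming a stable mapping is wanted, sorts the unique characters first. This is a variation for illustration, not what the notebook does, and it would change the index values shown in the outputs of this notebook.

```python
sample = " if you want you"

# Sorting the unique characters gives the same order on every run,
# unlike list(set(sample)), whose order depends on string hashing.
char_set = sorted(set(sample))
char_dic = {c: i for i, c in enumerate(char_set)}

print(char_set)
print(char_dic)
```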
# hyper parameters
dic_size = len(char_dic)
hidden_size = len(char_dic)
learning_rate = 0.1
# data setting
sample_idx = [char_dic[c] for c in sample]
x_data = [sample_idx[:-1]]
x_one_hot = [np.eye(dic_size)[x] for x in x_data]
y_data = [sample_idx[1:]]
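The data-setting cell builds inputs and targets by shifting the same index sequence by one position, then one-hot encodes the inputs by indexing an identity matrix. The sketch below walks through that step in isolation. It hard-codes the character mapping printed earlier in this notebook as an assumption, since the real mapping depends on set ordering.

```python
import numpy as np

# Mapping copied from the dictionary printed earlier in this notebook.
char_dic = {'y': 0, 't': 1, 'o': 2, 'i': 3, 'f': 4,
            'w': 5, 'u': 6, 'a': 7, 'n': 8, ' ': 9}
sample = " if you want you"
dic_size = len(char_dic)

sample_idx = [char_dic[c] for c in sample]
x_data = [sample_idx[:-1]]   # inputs: every character except the last
y_data = [sample_idx[1:]]    # targets: the same sequence shifted left by one

# Indexing the identity matrix with a list of indices picks one row per index,
# and row k of the identity matrix is the one-hot vector for class k.
x_one_hot = np.array([np.eye(dic_size)[x] for x in x_data])

print(x_one_hot.shape)   # (batch, sequence length, dic_size)
print(x_one_hot[0, 0])   # one-hot row for the first input character, ' '
print(y_data)
```

Each input character at step t is paired with the character at step t + 1 as its target, which is what makes this a next-character prediction task.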
# transform as torch tensor variable
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)
# declare RNN
rnn = torch.nn.RNN(dic_size, hidden_size, batch_first=True)
# loss & optimizer setting
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), learning_rate)
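`CrossEntropyLoss` takes a 2-D array of scores (one row per example) and a 1-D array of class indices, which is why the training loop flattens the batch and time dimensions with `view` before calling it. The sketch below reproduces that computation in plain NumPy, assuming the default mean reduction. The helper `cross_entropy` is hypothetical and exists only for illustration. It also evaluates two reference points: all-zero scores (a uniform guess over 10 characters) and scores of +1 for the correct class and -1 for the rest, which is the best case if the RNN output is used directly as scores and the default tanh nonlinearity keeps it within [-1, 1].

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean softmax cross-entropy over the rows of `logits` (hypothetical helper)."""
    # Subtract the row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class in each row.
    return -log_probs[np.arange(len(targets)), targets].mean()

dic_size = 10
targets = np.array([[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]])  # shape (1, 15)

# Stand-in for the RNN output: shape (batch, sequence length, dic_size).
outputs = np.zeros((1, 15, dic_size))

flat_outputs = outputs.reshape(-1, dic_size)  # (15, 10), like outputs.view(-1, dic_size)
flat_targets = targets.reshape(-1)            # (15,),    like Y.view(-1)

# Uniform guess: every class gets the same score.
uniform_loss = cross_entropy(flat_outputs, flat_targets)

# Best case for scores bounded to [-1, 1]: +1 for the target, -1 elsewhere.
best = -np.ones((15, dic_size))
best[np.arange(15), flat_targets] = 1.0
bounded_best_loss = cross_entropy(best, flat_targets)

print(flat_outputs.shape, flat_targets.shape)
print(uniform_loss, bounded_best_loss)
```

The uniform value is ln(10), about 2.303, and the bounded best case is ln(1 + 9e^-2), about 0.797. Under the stated assumption, a loss computed this way cannot fall below the second value, so a training loss that levels off somewhat above 0.8 would be expected rather than a sign of a bug.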
# start training
for i in range(50):
    optimizer.zero_grad()
    outputs, _status = rnn(X)
    loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()
    result = outputs.data.numpy().argmax(axis=2)
    result_str = ''.join([char_set[c] for c in np.squeeze(result)])
    print(i, "loss: ", loss.item(), "prediction: ", result, "true Y: ", y_data, "prediction str: ", result_str)
0 loss: 2.3307039737701416 prediction: [[2 4 2 9 2 3 2 9 9 4 4 4 9 2 3]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: ofo oio  fff oi
1 loss: 2.0587525367736816 prediction: [[2 2 9 5 2 2 2 5 6 8 9 2 2 2 2]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: oo wooowun oooo
2 loss: 1.8148916959762573 prediction: [[0 6 9 5 2 6 9 5 6 8 9 9 2 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yu wou wun  oou
3 loss: 1.6277488470077515 prediction: [[0 6 9 0 2 6 9 5 0 8 9 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yu you wyn  you
4 loss: 1.521714448928833 prediction: [[0 6 9 0 2 6 9 5 7 8 9 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yu you wan  you
5 loss: 1.438066840171814 prediction: [[0 4 9 0 2 6 9 0 7 8 9 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yan  you
6 loss: 1.3512088060379028 prediction: [[0 8 9 0 2 6 9 0 7 8 9 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yn you yan  you
7 loss: 1.2912108898162842 prediction: [[0 4 9 0 2 6 9 0 7 8 9 0 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yan yyou
8 loss: 1.2324522733688354 prediction: [[0 4 9 0 2 6 9 0 7 8 1 0 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yantyyou
9 loss: 1.1849303245544434 prediction: [[0 4 9 0 2 6 9 0 7 8 1 0 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yantyyou
10 loss: 1.1469171047210693 prediction: [[0 4 9 0 2 6 9 0 7 8 1 0 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yantyyou
11 loss: 1.1173083782196045 prediction: [[0 4 9 0 2 6 9 0 7 8 1 0 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yantyyou
12 loss: 1.0898127555847168 prediction: [[0 4 9 0 2 6 9 0 7 8 1 0 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yantyyou
13 loss: 1.0637227296829224 prediction: [[0 4 9 0 2 6 9 0 7 8 1 0 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: yf you yantyyou
14 loss: 1.0346168279647827 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
15 loss: 1.0096397399902344 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
16 loss: 0.9882010817527771 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
17 loss: 0.9665720462799072 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
18 loss: 0.9490536451339722 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
19 loss: 0.9347025752067566 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
20 loss: 0.9210413694381714 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
21 loss: 0.9097891449928284 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
22 loss: 0.9012717008590698 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
23 loss: 0.8941951394081116 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
24 loss: 0.8881015181541443 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
25 loss: 0.8832717537879944 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
26 loss: 0.8788976669311523 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
27 loss: 0.8751237988471985 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
28 loss: 0.8722966313362122 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
29 loss: 0.8697354793548584 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
30 loss: 0.8674185872077942 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
31 loss: 0.8656530380249023 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
32 loss: 0.863921046257019 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
33 loss: 0.8623018860816956 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
34 loss: 0.8610238432884216 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
35 loss: 0.8597375750541687 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
36 loss: 0.8585304021835327 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
37 loss: 0.8575419187545776 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
38 loss: 0.8564358353614807 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
39 loss: 0.8553726077079773 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
40 loss: 0.8543716669082642 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
41 loss: 0.8532596230506897 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
42 loss: 0.8521810173988342 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
43 loss: 0.8511130809783936 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
44 loss: 0.8499239087104797 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
45 loss: 0.8487468957901001 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
46 loss: 0.84747314453125 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
47 loss: 0.8461035490036011 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
48 loss: 0.8447437882423401 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
49 loss: 0.8433225154876709 prediction: [[3 4 9 0 2 6 9 0 7 8 1 9 0 2 6]] true Y: [[3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]] prediction str: if you yant you
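The final prediction string reads "if you yant you" while the target is "if you want you", so the logged output is close but not exact. A short sketch for checking this directly compares the epoch 49 indices with the targets. The index-to-character order is copied from the dictionary printed earlier in this notebook and would need to be regenerated for a different run.

```python
# Index-to-character order taken from the dictionary printed earlier.
char_set = ['y', 't', 'o', 'i', 'f', 'w', 'u', 'a', 'n', ' ']

prediction = [3, 4, 9, 0, 2, 6, 9, 0, 7, 8, 1, 9, 0, 2, 6]  # epoch 49
true_y = [3, 4, 9, 0, 2, 6, 9, 5, 7, 8, 1, 9, 0, 2, 6]

# Positions where the predicted index differs from the target,
# reported as (time step, predicted character, target character).
mismatches = [(t, char_set[p], char_set[y])
              for t, (p, y) in enumerate(zip(prediction, true_y)) if p != y]
accuracy = sum(p == y for p, y in zip(prediction, true_y)) / len(true_y)

decoded = "".join(char_set[i] for i in prediction)
print(decoded)
print(mismatches)
print(accuracy)
```

The single mismatch is at time step 7, where the input character is a space and the target is 'w'. A space is followed by 'y' at two other positions in the sample, which may be why the network settles on 'y' there; this is a guess from the data, not something the notebook verifies.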