Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
deeplearningzerotoall
GitHub Repository: deeplearningzerotoall/PyTorch
Path: blob/master/RNN/2-hihello.ipynb
618 views
Kernel: Python 3
import torch
import torch.optim as optim
import numpy as np
# Seed the global random number generator so that results are deterministic
# and reproducible from run to run.
torch.manual_seed(0)
<torch._C.Generator at 0x111e5fb50>
# Character dictionary: a character's position in this list is the integer
# index used for it in the input/target data and when decoding predictions.
char_set = ['h', 'i', 'e', 'l', 'o']
# Hyper-parameters.
input_size = len(char_set)   # width of each one-hot input vector
hidden_size = len(char_set)  # RNN output width; used directly as per-character scores
learning_rate = 0.1
# Training data: a single sequence (batch of one).
# x_data indexes into the character dictionary and spells "hihell";
# y_data is the same text shifted one step ahead, "ihello", so at every
# position the target is the next character.
x_data = [[0, 1, 0, 2, 3, 3]]
# One-hot encoding of x_data (one row per time step).
x_one_hot = [[[1, 0, 0, 0, 0],
              [0, 1, 0, 0, 0],
              [1, 0, 0, 0, 0],
              [0, 0, 1, 0, 0],
              [0, 0, 0, 1, 0],
              [0, 0, 0, 1, 0]]]
y_data = [[1, 0, 2, 3, 3, 4]]
# Convert the Python lists to tensors: float32 one-hot inputs and int64
# class-index targets.
X = torch.tensor(x_one_hot, dtype=torch.float32)
Y = torch.tensor(y_data, dtype=torch.int64)
# Declare the RNN.
# batch_first=True guarantees the order of the output is
# (batch, sequence, feature).
rnn = torch.nn.RNN(input_size, hidden_size, batch_first=True)
# Loss and optimizer: cross-entropy over the per-step character scores,
# with Adam updating the RNN's parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), lr=learning_rate)
# Training: 100 update steps on the single training sequence, printing the
# loss and the decoded prediction after every step.
for i in range(100):
    optimizer.zero_grad()
    outputs, _status = rnn(X)

    # Flatten (batch, sequence, feature) -> (batch * sequence, feature) and
    # the targets to (batch * sequence,) for the loss. The feature width of
    # the RNN output is hidden_size, so flatten by that rather than by
    # input_size (the two only happen to be equal here).
    loss = criterion(outputs.view(-1, hidden_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    # Highest-scoring character index at each time step; detach() drops the
    # autograd graph before converting to NumPy.
    result = outputs.detach().numpy().argmax(axis=2)
    result_str = ''.join([char_set[c] for c in np.squeeze(result)])
    print(i, "loss: ", loss.item(), "prediction: ", result, "true Y: ", y_data, "prediction str: ", result_str)
0 loss: 1.7802648544311523 prediction: [[1 1 1 1 1 1]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: iiiiii 1 loss: 1.4931954145431519 prediction: [[1 4 1 1 4 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ioiioo 2 loss: 1.3337129354476929 prediction: [[1 3 2 3 1 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilelio 3 loss: 1.215295433998108 prediction: [[2 3 2 3 3 3]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: elelll 4 loss: 1.1131411790847778 prediction: [[2 3 2 3 3 3]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: elelll 5 loss: 1.0241888761520386 prediction: [[2 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: elello 6 loss: 0.9573155045509338 prediction: [[2 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: elello 7 loss: 0.9102011322975159 prediction: [[2 0 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ehello 8 loss: 0.8731772899627686 prediction: [[1 0 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ihello 9 loss: 0.8399624824523926 prediction: [[1 0 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ihello 10 loss: 0.8088951706886292 prediction: [[1 0 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ihello 11 loss: 0.7812867760658264 prediction: [[1 0 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ihello 12 loss: 0.7585349082946777 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 13 loss: 0.7401294112205505 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 14 loss: 0.7243587970733643 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 15 loss: 0.709122359752655 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 16 loss: 0.6929269433021545 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 17 loss: 0.6821210980415344 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 18 loss: 0.6735268235206604 prediction: [[1 3 2 3 3 4]] true Y: 
[[1, 0, 2, 3, 3, 4]] prediction str: ilello 19 loss: 0.6595444679260254 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 20 loss: 0.6534828543663025 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 21 loss: 0.6465457081794739 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 22 loss: 0.6398184895515442 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 23 loss: 0.6381523013114929 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 24 loss: 0.6326718926429749 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 25 loss: 0.6256727576255798 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 26 loss: 0.6215079426765442 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 27 loss: 0.616705060005188 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 28 loss: 0.6099358201026917 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 29 loss: 0.6030194163322449 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 30 loss: 0.5992398262023926 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 31 loss: 0.5964334607124329 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 32 loss: 0.5916643738746643 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 33 loss: 0.5881562232971191 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 34 loss: 0.5854337811470032 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 35 loss: 0.5813184976577759 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 36 loss: 0.5761863589286804 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 37 loss: 
0.5734922289848328 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 38 loss: 0.5727553963661194 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 39 loss: 0.5682081580162048 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 40 loss: 0.5656263828277588 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 41 loss: 0.5647333264350891 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 42 loss: 0.5629464983940125 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 43 loss: 0.5603764057159424 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 44 loss: 0.5588172078132629 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 45 loss: 0.5584632754325867 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 46 loss: 0.5565395355224609 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 47 loss: 0.5548029541969299 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 48 loss: 0.5542733669281006 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 49 loss: 0.5534438490867615 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 50 loss: 0.5520094037055969 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 51 loss: 0.5510937571525574 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 52 loss: 0.5506715178489685 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 53 loss: 0.5493640303611755 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 54 loss: 0.548336923122406 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 55 loss: 0.5478048920631409 prediction: [[1 3 2 3 3 4]] true Y: 
[[1, 0, 2, 3, 3, 4]] prediction str: ilello 56 loss: 0.5469381213188171 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 57 loss: 0.545922577381134 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 58 loss: 0.5454025268554688 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 59 loss: 0.544852077960968 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 60 loss: 0.5439804196357727 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 61 loss: 0.5434582233428955 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 62 loss: 0.5429832935333252 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 63 loss: 0.5422741174697876 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 64 loss: 0.5417040586471558 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 65 loss: 0.5413308143615723 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 66 loss: 0.5407487750053406 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 67 loss: 0.5402576923370361 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 68 loss: 0.5399190187454224 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 69 loss: 0.5394622683525085 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 70 loss: 0.5390090942382812 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 71 loss: 0.5387027263641357 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 72 loss: 0.538316011428833 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 73 loss: 0.5379175543785095 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 74 loss: 
0.5376288890838623 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 75 loss: 0.5372945666313171 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 76 loss: 0.5369362235069275 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 77 loss: 0.5366637110710144 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 78 loss: 0.5363660454750061 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 79 loss: 0.53604656457901 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 80 loss: 0.5357930064201355 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 81 loss: 0.5355223417282104 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 82 loss: 0.5352355241775513 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 83 loss: 0.5349991917610168 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 84 loss: 0.5347511172294617 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 85 loss: 0.5344937443733215 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 86 loss: 0.5342753529548645 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 87 loss: 0.534046471118927 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 88 loss: 0.5338148474693298 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 89 loss: 0.5336135029792786 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 90 loss: 0.5334023833274841 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 91 loss: 0.533194899559021 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 92 loss: 0.5330093502998352 prediction: [[1 3 2 3 3 4]] true Y: [[1, 
0, 2, 3, 3, 4]] prediction str: ilello 93 loss: 0.53281569480896 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 94 loss: 0.5326292514801025 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 95 loss: 0.5324582457542419 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 96 loss: 0.5322802066802979 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 97 loss: 0.5321123003959656 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 98 loss: 0.5319531559944153 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello 99 loss: 0.5317898392677307 prediction: [[1 3 2 3 3 4]] true Y: [[1, 0, 2, 3, 3, 4]] prediction str: ilello