GitHub Repository: deeplearningzerotoall/PyTorch
Path: blob/master/CNN/lab-10-X1-mnist_back_prop.py
# Lab 10 MNIST and softmax
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

# parameters
learning_rate = 0.5
batch_size = 10

# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

# dataset loader
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)
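# drop_last=True keeps every batch at exactly batch_size samples, which the
# hard-coded one-hot target in the training loop, torch.zeros((batch_size, 10)),
# relies on.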

# weights and biases of a 784-30-10 network, kept as plain tensors: every
# gradient is derived by hand below, so autograd tracking is not needed
# (wrapping these in nn.Parameter would make each update grow the autograd
# graph across iterations)
w1 = torch.empty(784, 30, device=device)
b1 = torch.empty(30, device=device)
w2 = torch.empty(30, 10, device=device)
b2 = torch.empty(10, device=device)

torch.nn.init.normal_(w1)
torch.nn.init.normal_(b1)
torch.nn.init.normal_(w2)
torch.nn.init.normal_(b2)
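# Aside (not part of the original script): the unit-variance normal init above
# is what the lab uses; for sigmoid layers a scaled scheme such as
# torch.nn.init.xavier_normal_(w1) is a common alternative.
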
def sigma(x):
    # sigmoid function
    return 1.0 / (1.0 + torch.exp(-x))
    # return torch.div(torch.tensor(1), torch.add(torch.tensor(1.0), torch.exp(-x)))


def sigma_prime(x):
    # derivative of the sigmoid function
    return sigma(x) * (1 - sigma(x))
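
# A quick sanity check (not part of the original lab): sigma_prime should agree
# with the derivative autograd computes for torch.sigmoid; the helper name
# `check_sigma_prime` is hypothetical, added only for illustration.
def check_sigma_prime():
    x = torch.randn(5, requires_grad=True)
    torch.sigmoid(x).sum().backward()  # x.grad = sigmoid(x) * (1 - sigmoid(x))
    assert torch.allclose(x.grad, sigma_prime(x.detach()), atol=1e-6)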

# scale the test images to [0, 1] so they match the ToTensor() scaling of the
# training batches (the raw .data buffer holds 0-255 byte values)
X_test = mnist_test.data.view(-1, 28 * 28).float().to(device)[:1000] / 255.
Y_test = mnist_test.targets.to(device)[:1000]

i = 0
while i < 10000:
    for X, Y in data_loader:
        i += 1

        # forward
        X = X.view(-1, 28 * 28).to(device)
        Y = torch.zeros((batch_size, 10)).scatter_(1, Y.unsqueeze(1), 1).to(device)  # one-hot
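        # (the scatter_ above writes a 1 at each row's label index, turning the
        # (batch,) integer labels into a (batch, 10) one-hot matrix)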
        l1 = torch.add(torch.matmul(X, w1), b1)
        a1 = sigma(l1)
        l2 = torch.add(torch.matmul(a1, w2), b2)
        y_pred = sigma(l2)

        diff = y_pred - Y

        # Back prop (chain rule)
        d_l2 = diff * sigma_prime(l2)
        d_b2 = d_l2
        d_w2 = torch.matmul(torch.transpose(a1, 0, 1), d_l2)

        d_a1 = torch.matmul(d_l2, torch.transpose(w2, 0, 1))
        d_l1 = d_a1 * sigma_prime(l1)
        d_b1 = d_l1
        d_w1 = torch.matmul(torch.transpose(X, 0, 1), d_l1)
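        # shape check: X is (batch, 784) and d_l1 is (batch, 30), so X.T @ d_l1
        # sums each weight gradient over the batch, while the bias gradients
        # are averaged over the batch in the update below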

        w1 = w1 - learning_rate * d_w1
        b1 = b1 - learning_rate * torch.mean(d_b1, 0)
        w2 = w2 - learning_rate * d_w2
        b2 = b2 - learning_rate * torch.mean(d_b2, 0)

        if i % 1000 == 0:
            # evaluate on the held-out 1000 test images
            l1 = torch.add(torch.matmul(X_test, w1), b1)
            a1 = sigma(l1)
            l2 = torch.add(torch.matmul(a1, w2), b2)
            y_pred = sigma(l2)
            acct_mat = torch.argmax(y_pred, 1) == Y_test
            acct_res = acct_mat.sum()
            print(acct_res.item())  # number of correct predictions out of 1000

        if i == 10000:
            break
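
# Optional cross-check (a sketch, not part of the original lab): replay one
# batch with autograd and compare against the hand-derived gradients. The loss
# implied by the manual d_l2 above is squared error, 0.5 * ||y_pred - Y||^2;
# with it, autograd's w.grad matches d_w1/d_w2 exactly, and b.grad equals the
# batch *sum* of d_b1/d_b2 (the update above averages instead). The helper
# name `autograd_check` is hypothetical.
def autograd_check(X, Y_onehot):
    # fresh leaf copies of the current weights so .grad gets populated
    w1a = w1.detach().clone().requires_grad_(True)
    b1a = b1.detach().clone().requires_grad_(True)
    w2a = w2.detach().clone().requires_grad_(True)
    b2a = b2.detach().clone().requires_grad_(True)
    y_pred = sigma(sigma(X @ w1a + b1a) @ w2a + b2a)
    loss = 0.5 * ((y_pred - Y_onehot) ** 2).sum()
    loss.backward()
    return w1a.grad, b1a.grad, w2a.grad, b2a.grad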