Path: blob/master/CNN/lab-10-X1-mnist_back_prop.py
618 views
# Lab 10 MNIST: two-layer sigmoid network trained with hand-written back-propagation
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

# parameters
learning_rate = 0.5
batch_size = 10
total_steps = 10000  # number of mini-batch updates to perform
eval_every = 1000    # report test accuracy every this many updates

# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

# dataset loader
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

# Network weights: 784 -> 30 -> 10.
# These are plain tensors rather than torch.nn.Parameter on purpose: the
# gradients are computed by hand below, so autograd tracking is not needed.
# With requires_grad=True every manual update `w = w - lr * d_w` would be
# recorded by autograd, and the recorded graph would keep growing across
# training steps.
w1 = torch.empty(784, 30, device=device)
b1 = torch.empty(30, device=device)
w2 = torch.empty(30, 10, device=device)
b2 = torch.empty(10, device=device)

torch.nn.init.normal_(w1)
torch.nn.init.normal_(b1)
torch.nn.init.normal_(w2)
torch.nn.init.normal_(b2)


def sigma(x):
    """Return the element-wise sigmoid of x, 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + torch.exp(-x))


def sigma_prime(x):
    """Return the element-wise derivative of the sigmoid at x."""
    s = sigma(x)  # evaluate the sigmoid once and reuse it
    return s * (1 - s)


def forward(x):
    """Run the two-layer network on x using the current global weights.

    Returns the tuple (l1, a1, l2, y_pred): the pre-activation and activation
    of the hidden layer, followed by the pre-activation and activation of the
    output layer. The intermediate values are needed by the backward pass.
    """
    l1 = torch.add(torch.matmul(x, w1), b1)
    a1 = sigma(l1)
    l2 = torch.add(torch.matmul(a1, w2), b2)
    y_pred = sigma(l2)
    return l1, a1, l2, y_pred


# First 1000 test samples. The raw dataset tensor holds 8-bit pixel values,
# while the training inputs go through ToTensor(), which rescales to [0, 1];
# divide by 255 so evaluation sees inputs on the same scale as training.
X_test = mnist_test.data[:1000].view(-1, 28 * 28).float().div(255).to(device)
Y_test = mnist_test.targets[:1000].to(device)

i = 0
while i < total_steps:
    for X, Y in data_loader:
        i += 1

        # forward
        X = X.view(-1, 28 * 28).to(device)
        Y = torch.zeros((Y.size(0), 10)).scatter_(1, Y.unsqueeze(1), 1).to(device)  # one-hot
        l1, a1, l2, y_pred = forward(X)

        diff = y_pred - Y

        # Back prop (chain rule)
        d_l2 = diff * sigma_prime(l2)
        d_b2 = d_l2
        d_w2 = torch.matmul(torch.transpose(a1, 0, 1), d_l2)

        d_a1 = torch.matmul(d_l2, torch.transpose(w2, 0, 1))
        d_l1 = d_a1 * sigma_prime(l1)
        d_b1 = d_l1
        d_w1 = torch.matmul(torch.transpose(X, 0, 1), d_l1)

        # Gradient-descent update.
        # NOTE(review): the weight gradients are summed over the batch (via
        # matmul) while the bias gradients are averaged (torch.mean). This
        # asymmetry is kept exactly as in the original script.
        w1 = w1 - learning_rate * d_w1
        b1 = b1 - learning_rate * torch.mean(d_b1, 0)
        w2 = w2 - learning_rate * d_w2
        b2 = b2 - learning_rate * torch.mean(d_b2, 0)

        if i % eval_every == 0:
            # Print the number of correctly classified test samples.
            test_pred = forward(X_test)[-1]
            acct_mat = torch.argmax(test_pred, 1) == Y_test
            acct_res = acct_mat.sum()
            print(acct_res.item())

        if i == total_steps:
            break