Path: blob/master/CNN/lab-10-8-mnist_nn_selu(wip).py
# Lab 10 MNIST and softmax
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import random

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
random.seed(777)
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

# parameters
learning_rate = 0.001
training_epochs = 15
batch_size = 100
keep_prob = 0.7  # reserved for the dropout variant of this lab; unused in this SELU (wip) version

# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

# dataset loader
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

# nn layers
linear1 = torch.nn.Linear(784, 512, bias=True)
linear2 = torch.nn.Linear(512, 512, bias=True)
linear3 = torch.nn.Linear(512, 512, bias=True)
linear4 = torch.nn.Linear(512, 512, bias=True)
linear5 = torch.nn.Linear(512, 10, bias=True)
selu = torch.nn.SELU()

# xavier initialization
torch.nn.init.xavier_uniform_(linear1.weight)
torch.nn.init.xavier_uniform_(linear2.weight)
torch.nn.init.xavier_uniform_(linear3.weight)
torch.nn.init.xavier_uniform_(linear4.weight)
torch.nn.init.xavier_uniform_(linear5.weight)

# model
model = torch.nn.Sequential(linear1, selu,
                            linear2, selu,
                            linear3, selu,
                            linear4, selu,
                            linear5).to(device)

# define cost/loss & optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)  # Softmax is internally computed.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_batch = len(data_loader)
model.train()  # set the model to train mode (no dropout layers in this SELU version, so this has no effect here)
for epoch in range(training_epochs):
    avg_cost = 0

    for X, Y in data_loader:
        # reshape input image into [batch_size, 784]
        # label is not one-hot encoded
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()

        avg_cost += cost.item() / total_batch  # .item() avoids keeping the autograd graph alive

    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

print('Learning finished')

# Test model and check accuracy
with torch.no_grad():
    model.eval()  # set the model to evaluation mode

    # Test the model using test sets
    # (.data / .targets replace the deprecated .test_data / .test_labels;
    #  divide by 255 to match the [0, 1] scaling that ToTensor() applied during training)
    X_test = mnist_test.data.view(-1, 28 * 28).float().to(device) / 255.
    Y_test = mnist_test.targets.to(device)

    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())

    # Get one and predict
    r = random.randint(0, len(mnist_test) - 1)
    X_single_data = mnist_test.data[r:r + 1].view(-1, 28 * 28).float().to(device) / 255.
    Y_single_data = mnist_test.targets[r:r + 1].to(device)

    print('Label: ', Y_single_data.item())
    single_prediction = model(X_single_data)
    print('Prediction: ', torch.argmax(single_prediction, 1).item())
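
# ----------------------------------------------------------------------
# Sketch (not part of the original lab): wiring in the unused keep_prob.
# This (wip) file defines keep_prob = 0.7 but never applies dropout. If a
# dropout variant is the intent, SELU is commonly paired with AlphaDropout,
# which is designed to preserve SELU's self-normalizing statistics. A
# minimal sketch, reusing the layers defined above purely for illustration
# (model_ad and alpha_dropout are names introduced here, not from the lab):
alpha_dropout = torch.nn.AlphaDropout(p=1 - keep_prob)  # drop probability = 0.3
model_ad = torch.nn.Sequential(linear1, selu, alpha_dropout,
                               linear2, selu, alpha_dropout,
                               linear3, selu, alpha_dropout,
                               linear4, selu, alpha_dropout,
                               linear5).to(device)
# With AlphaDropout present, model_ad.train() / model_ad.eval() actually
# toggle dropout on and off, matching the train/eval comments above.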