# Path: blob/master/CNN/lab-11-1-mnist_cnn.py
# 618 views
# Lab 11 MNIST and Convolutional Neural Network
#
# Trains a small two-stage convolutional network on MNIST with Adam and
# cross-entropy loss, printing the mean training loss per epoch, then reports
# classification accuracy on the MNIST test split.
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init

# Training is pinned to the CPU. To use a GPU when one is available, replace
# the assignment below with:
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

# parameters
learning_rate = 0.001
training_epochs = 15
batch_size = 100

# MNIST dataset (downloaded into MNIST_data/ on first use).
# ToTensor() converts each image to a float tensor scaled to [0, 1].
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

# dataset loader; drop_last=True discards a trailing partial batch so every
# batch contributes equally to the per-epoch average loss.
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)


# CNN Model (2 conv layers)
class CNN(torch.nn.Module):
    """Two conv/ReLU/max-pool stages followed by a linear classifier.

    The final linear layer expects 7 * 7 * 64 features, i.e. inputs shaped
    (N, 1, 28, 28). ``forward`` returns a (N, 10) tensor of unnormalized
    class scores (logits).
    """

    def __init__(self):
        super().__init__()
        # L1 ImgIn shape=(N, 1, 28, 28)
        #    Conv     -> (N, 32, 28, 28)
        #    Pool     -> (N, 32, 14, 14)
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # L2 ImgIn shape=(N, 32, 14, 14)
        #    Conv     -> (N, 64, 14, 14)
        #    Pool     -> (N, 64, 7, 7)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # Final FC 7x7x64 inputs -> 10 outputs
        self.fc = torch.nn.Linear(7 * 7 * 64, 10, bias=True)
        torch.nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)   # Flatten them for FC
        out = self.fc(out)
        return out


# instantiate CNN model
model = CNN().to(device)

# define cost/loss & optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)    # Softmax is internally computed.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# train my model
total_batch = len(data_loader)
print('Learning started. It takes sometime.')
for epoch in range(training_epochs):
    avg_cost = 0.0

    for X, Y in data_loader:
        # image is already size of (28x28), no reshape
        # label is not one-hot encoded
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()

        # .item() extracts a plain Python float so the running average does
        # not accumulate autograd history across iterations.
        avg_cost += cost.item() / total_batch

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

print('Learning Finished!')

# Test model and check accuracy
model.eval()
with torch.no_grad():
    # mnist_test.data holds the raw uint8 pixels (0-255) and bypasses the
    # ToTensor() transform, so scale to [0, 1] here to match what the model
    # saw during training.
    X_test = (mnist_test.data.view(len(mnist_test), 1, 28, 28)
              .float().div(255).to(device))
    Y_test = mnist_test.targets.to(device)

    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())