Path: blob/master/CNN/lab-11-3-mnist_cnn_class.py
618 views
# Lab 11 MNIST and Deep learning CNN
import torch
import torch.nn.init

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

# hyperparameters
learning_rate = 0.001
training_epochs = 15
batch_size = 100


# CNN Model
class CNN(torch.nn.Module):
    """Three conv/pool blocks followed by two fully-connected layers for MNIST.

    Input:  (N, 1, 28, 28) float tensor.
    Output: (N, 10) raw logits — softmax is applied inside CrossEntropyLoss,
    so `forward` must NOT apply it itself.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self._build_net()

    def _build_net(self):
        """Create layers, loss, and optimizer as module attributes."""
        # dropout keep probability while training; Dropout is a no-op in eval mode
        self.keep_prob = 0.5
        # L1 ImgIn shape=(?, 1, 28, 28)
        #    Conv -> (?, 32, 28, 28)
        #    Pool -> (?, 32, 14, 14)
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # L2 ImgIn shape=(?, 32, 14, 14)
        #    Conv -> (?, 64, 14, 14)
        #    Pool -> (?, 64, 7, 7)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # L3 ImgIn shape=(?, 64, 7, 7)
        #    Conv -> (?, 128, 7, 7)
        #    Pool -> (?, 128, 4, 4)   (padding=1 makes ceil(7/2)+... give 4)
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1))
        # L4 FC: 4*4*128 inputs -> 625 outputs, with dropout
        self.fc1 = torch.nn.Linear(4 * 4 * 128, 625, bias=True)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.layer4 = torch.nn.Sequential(
            self.fc1,
            torch.nn.ReLU(),
            torch.nn.Dropout(p=1 - self.keep_prob))
        # L5 Final FC: 625 inputs -> 10 class logits
        self.fc2 = torch.nn.Linear(625, 10, bias=True)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

        # define cost/loss & optimizer
        self.criterion = torch.nn.CrossEntropyLoss()  # Softmax is internally computed.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return (N, 10) logits for a batch of (N, 1, 28, 28) images."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)  # Flatten them for FC
        out = self.layer4(out)
        out = self.fc2(out)
        return out

    def predict(self, x):
        """Run inference in eval mode (disables dropout)."""
        self.eval()
        return self.forward(x)

    def get_accuracy(self, x, y):
        """Return the fraction of inputs `x` whose argmax prediction equals `y`.

        `y` holds integer class indices with the same batch size as `x`.
        """
        prediction = self.predict(x)
        # BUG FIX: the original compared against the module-level global
        # Y_test instead of the `y` argument, which broke the method for any
        # other input (and raised NameError before the test section ran).
        correct_prediction = torch.argmax(prediction, 1) == y
        self.accuracy = correct_prediction.float().mean().item()
        return self.accuracy

    def train_model(self, x, y):
        """One optimizer step on batch (x, y); returns the batch loss tensor."""
        self.train()
        self.optimizer.zero_grad()
        hypothesis = self.forward(x)
        self.cost = self.criterion(hypothesis, y)
        self.cost.backward()
        self.optimizer.step()
        return self.cost


def main():
    """Download MNIST, train the CNN, and report test-set accuracy."""
    # torchvision is only needed for the dataset pipeline, so it is imported
    # here rather than at module level (keeps `import`-ing this file cheap).
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms

    # MNIST dataset
    mnist_train = dsets.MNIST(root='MNIST_data/',
                              train=True,
                              transform=transforms.ToTensor(),
                              download=True)
    mnist_test = dsets.MNIST(root='MNIST_data/',
                             train=False,
                             transform=transforms.ToTensor(),
                             download=True)

    # dataset loader
    data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              drop_last=True)

    # instantiate CNN model
    model = CNN().to(device)

    # train my model
    total_batch = len(data_loader)
    print('Learning started. It takes sometime.')
    for epoch in range(training_epochs):
        avg_cost = 0

        for X, Y in data_loader:
            # image is already size of (28x28), no reshape
            # label is not one-hot encoded
            X = X.to(device)
            Y = Y.to(device)

            cost = model.train_model(X, Y)

            avg_cost += cost / total_batch

        print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

    print('Learning Finished!')

    # Test model and check accuracy.
    # NOTE: `.data` / `.targets` replace the deprecated `.test_data` /
    # `.test_labels` attributes (same tensors, non-deprecated names).
    with torch.no_grad():
        X_test = mnist_test.data.view(len(mnist_test), 1, 28, 28).float().to(device)
        Y_test = mnist_test.targets.to(device)

        print('Accuracy:', model.get_accuracy(X_test, Y_test))


if __name__ == '__main__':
    main()