Path: blob/master/CNN/lab-11-5-mnist_cnn_ensemble.py
618 views
# Lab 11 MNIST and Deep learning CNN
#
# Trains a small ensemble of identical CNNs on MNIST and reports both the
# per-model accuracy and the accuracy of the summed-logit ensemble.
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

# parameters
learning_rate = 0.001
training_epochs = 15
batch_size = 100

# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

# dataset loader
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)


# CNN Model
class CNN(torch.nn.Module):
    """Three conv/pool stages followed by two fully connected layers.

    Each instance owns its own loss function and Adam optimizer, so several
    instances can be trained side by side on the same batches (ensemble).
    The optimizer's learning rate is read from the module-level
    ``learning_rate`` when the instance is built.
    """

    def __init__(self):
        super().__init__()
        self._build_net()

    def _build_net(self):
        """Create the layers, the loss function and the optimizer."""
        # Probability of KEEPING a unit during training; torch.nn.Dropout
        # takes the DROP probability, hence `1 - keep_prob` below. Dropout is
        # only active in train() mode and is a no-op after eval().
        self.keep_prob = 0.5
        # L1 input shape  = (N, 1, 28, 28)
        #    Conv        -> (N, 32, 28, 28)
        #    Pool        -> (N, 32, 14, 14)
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # L2 input shape  = (N, 32, 14, 14)
        #    Conv        -> (N, 64, 14, 14)
        #    Pool        -> (N, 64, 7, 7)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # L3 input shape  = (N, 64, 7, 7)
        #    Conv        -> (N, 128, 7, 7)
        #    Pool        -> (N, 128, 4, 4)   (padding=1 turns 7 into 4)
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1))
        # L4 FC 4x4x128 inputs -> 625 outputs
        self.fc1 = torch.nn.Linear(4 * 4 * 128, 625, bias=True)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.layer4 = torch.nn.Sequential(
            self.fc1,
            torch.nn.ReLU(),
            torch.nn.Dropout(p=1 - self.keep_prob))
        # L5 Final FC 625 inputs -> 10 outputs
        self.fc2 = torch.nn.Linear(625, 10, bias=True)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

        # define cost/loss & optimizer
        self.criterion = torch.nn.CrossEntropyLoss()    # Softmax is internally computed.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of images."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)   # Flatten them for FC
        out = self.layer4(out)
        out = self.fc2(out)
        return out

    def predict(self, x):
        """Switch to evaluation mode (dropout off) and return the logits."""
        self.eval()
        return self.forward(x)

    def get_accuracy(self, x, y):
        """Return the fraction of samples in ``x`` whose argmax equals ``y``.

        The result is also stored on ``self.accuracy``.
        """
        prediction = self.predict(x)
        # Compare against the labels that were passed in, not a global.
        correct_prediction = torch.argmax(prediction, 1) == y
        self.accuracy = correct_prediction.float().mean()
        return self.accuracy

    def train_model(self, x, y):
        """Run one optimizer step on the batch ``(x, y)`` and return the loss.

        The loss tensor is also stored on ``self.cost``.
        """
        self.train()
        self.optimizer.zero_grad()
        hypothesis = self.forward(x)
        self.cost = self.criterion(hypothesis, y)
        self.cost.backward()
        self.optimizer.step()
        return self.cost


# instantiate CNN model
models = []
num_models = 2
for m in range(num_models):
    models.append(CNN().to(device))

# train my model
total_batch = len(data_loader)
print('Learning started. It takes sometime.')
for epoch in range(training_epochs):
    avg_cost_list = np.zeros(len(models))

    for X, Y in data_loader:
        X = X.to(device)
        Y = Y.to(device)
        # image is already size of (28x28), no reshape
        # label is not one-hot encoded

        for m_idx, m in enumerate(models):
            cost = m.train_model(X, Y)
            # .item() yields a plain Python float, so the NumPy accumulator
            # never has to convert an autograd-tracked (possibly CUDA) tensor.
            avg_cost_list[m_idx] += cost.item() / total_batch

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost_list.mean()))

print('Learning Finished!')

# Test model and check accuracy
with torch.no_grad():
    # `.data` holds the raw uint8 pixels (0..255). The training images went
    # through transforms.ToTensor(), which rescales to 0..1, so apply the same
    # scaling here to evaluate on the distribution the models were trained on.
    X_test = mnist_test.data.view(len(mnist_test), 1, 28, 28).float().div(255).to(device)
    Y_test = mnist_test.targets.to(device)
    # Sum of the logits of every model; argmax of the sum is the ensemble vote.
    predictions = torch.zeros([len(mnist_test), 10])
    for m_idx, m in enumerate(models):
        print(m_idx, 'Accuracy:', m.get_accuracy(X_test, Y_test))
        p = m.predict(X_test)
        predictions += p.cpu()

    ensemble_correct_prediction = torch.argmax(predictions, 1) == Y_test.cpu()
    ensemble_accuracy = ensemble_correct_prediction.float().mean()
    print('Accuracy:', ensemble_accuracy.item())