Path: blob/master/CNN/lab-10-6-mnist_batchnorm.py
# Lab 10 MNIST and softmax
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(1)
if device == 'cuda':
    torch.cuda.manual_seed_all(1)

# parameters
learning_rate = 0.01
training_epochs = 10
batch_size = 32

# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

# dataset loader
train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True)

test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=True)

# nn layers
linear1 = torch.nn.Linear(784, 32, bias=True)
linear2 = torch.nn.Linear(32, 32, bias=True)
linear3 = torch.nn.Linear(32, 10, bias=True)
relu = torch.nn.ReLU()  # stateless, so one instance can be shared by both models
bn1 = torch.nn.BatchNorm1d(32)
bn2 = torch.nn.BatchNorm1d(32)

nn_linear1 = torch.nn.Linear(784, 32, bias=True)
nn_linear2 = torch.nn.Linear(32, 32, bias=True)
nn_linear3 = torch.nn.Linear(32, 10, bias=True)

# model
bn_model = torch.nn.Sequential(linear1, relu, bn1,
                               linear2, relu, bn2,
                               linear3).to(device)
nn_model = torch.nn.Sequential(nn_linear1, relu,
                               nn_linear2, relu,
                               nn_linear3).to(device)

# define cost/loss & optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)  # Softmax is internally computed (see the sketch at the end of the file).
bn_optimizer = torch.optim.Adam(bn_model.parameters(), lr=learning_rate)
nn_optimizer = torch.optim.Adam(nn_model.parameters(), lr=learning_rate)

# Save losses and accuracies every epoch
# We are going to plot them later
train_losses = []
train_accs = []

valid_losses = []
valid_accs = []

train_total_batch = len(train_loader)
test_total_batch = len(test_loader)
for epoch in range(training_epochs):
    bn_model.train()  # train mode: batchnorm uses per-batch statistics (see the sketch at the end of the file)

    for X, Y in train_loader:
        # reshape input image into [batch_size x 784]
        # label is not one-hot encoded
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        bn_optimizer.zero_grad()
        bn_prediction = bn_model(X)
        bn_loss = criterion(bn_prediction, Y)
        bn_loss.backward()
        bn_optimizer.step()

        nn_optimizer.zero_grad()
        nn_prediction = nn_model(X)
        nn_loss = criterion(nn_prediction, Y)
        nn_loss.backward()
        nn_optimizer.step()

    with torch.no_grad():
        bn_model.eval()  # eval mode: batchnorm uses its running statistics

        # Evaluate the models on the train set
        bn_loss, nn_loss, bn_acc, nn_acc = 0, 0, 0, 0
        for i, (X, Y) in enumerate(train_loader):
            X = X.view(-1, 28 * 28).to(device)
            Y = Y.to(device)

            bn_prediction = bn_model(X)
            bn_correct_prediction = torch.argmax(bn_prediction, 1) == Y
            bn_loss += criterion(bn_prediction, Y)
            bn_acc += bn_correct_prediction.float().mean()

            nn_prediction = nn_model(X)
            nn_correct_prediction = torch.argmax(nn_prediction, 1) == Y
            nn_loss += criterion(nn_prediction, Y)
            nn_acc += nn_correct_prediction.float().mean()

        bn_loss, nn_loss = bn_loss / train_total_batch, nn_loss / train_total_batch
        bn_acc, nn_acc = bn_acc / train_total_batch, nn_acc / train_total_batch

        # Save train losses/acc (as plain floats, so plotting also works when device == 'cuda')
        train_losses.append([bn_loss.item(), nn_loss.item()])
        train_accs.append([bn_acc.item(), nn_acc.item()])
        print(
            '[Epoch %d-TRAIN] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): nn_loss:%.5f(nn_acc:%.2f)' % (
                (epoch + 1), bn_loss.item(), bn_acc.item(), nn_loss.item(), nn_acc.item()))

        # Evaluate the models on the test set
        bn_loss, nn_loss, bn_acc, nn_acc = 0, 0, 0, 0
        for i, (X, Y) in enumerate(test_loader):
            X = X.view(-1, 28 * 28).to(device)
            Y = Y.to(device)

            bn_prediction = bn_model(X)
            bn_correct_prediction = torch.argmax(bn_prediction, 1) == Y
            bn_loss += criterion(bn_prediction, Y)
            bn_acc += bn_correct_prediction.float().mean()

            nn_prediction = nn_model(X)
            nn_correct_prediction = torch.argmax(nn_prediction, 1) == Y
            nn_loss += criterion(nn_prediction, Y)
            nn_acc += nn_correct_prediction.float().mean()

        bn_loss, nn_loss = bn_loss / test_total_batch, nn_loss / test_total_batch
        bn_acc, nn_acc = bn_acc / test_total_batch, nn_acc / test_total_batch

        # Save valid losses/acc
        valid_losses.append([bn_loss.item(), nn_loss.item()])
        valid_accs.append([bn_acc.item(), nn_acc.item()])
        print(
            '[Epoch %d-VALID] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): nn_loss:%.5f(nn_acc:%.2f)' % (
                (epoch + 1), bn_loss.item(), bn_acc.item(), nn_loss.item(), nn_acc.item()))
        print()

print('Learning finished')


def plot_compare(loss_list: list, ylim=None, title=None) -> None:
    bn = [i[0] for i in loss_list]
    nn = [i[1] for i in loss_list]

    plt.figure(figsize=(15, 10))
    plt.plot(bn, label='With BN')
    plt.plot(nn, label='Without BN')
    if ylim:
        plt.ylim(ylim)

    if title:
        plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()


plot_compare(train_losses, title='Training Loss at Epoch')
plot_compare(train_accs, [0, 1.0], title='Training Acc at Epoch')
plot_compare(valid_losses, title='Validation Loss at Epoch')
plot_compare(valid_accs, [0, 1.0], title='Validation Acc at Epoch')
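
# --- Supplementary sketch (editor's addition, not part of the original lab) --
# The comment on `criterion` above says "Softmax is internally computed":
# CrossEntropyLoss expects raw logits and applies log-softmax itself, which is
# why no softmax layer follows linear3 / nn_linear3. A quick sanity check:
_logits = torch.tensor([[2.0, 0.5, -1.0]])       # raw, unnormalized scores
_target = torch.tensor([0])                      # class index, not one-hot
_ce = torch.nn.CrossEntropyLoss()(_logits, _target)
_manual = -torch.log_softmax(_logits, dim=1)[0, 0]  # NLL of the true class
print(_ce.item(), _manual.item())                # the two values match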
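
# --- Supplementary sketch (editor's addition, not part of the original lab) --
# Why the bn_model.train() / bn_model.eval() calls matter: in train mode
# BatchNorm1d normalizes each feature with the current mini-batch's statistics
# and updates its running estimates; in eval mode it normalizes with the
# stored running_mean / running_var instead.
_bn = torch.nn.BatchNorm1d(4)
_x = torch.randn(64, 4) * 3 + 5       # features with mean ~5, std ~3

_bn.train()
_y = _bn(_x)                          # normalized with batch statistics
print(_y.mean(dim=0))                 # ~0 per feature
print(_bn.running_mean)               # nudged from 0 toward the batch mean

_bn.eval()
print(_bn(_x).mean(dim=0))            # no longer ~0: running estimates used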