Path: blob/master/deprecated/scripts/armijo_mnist_demo.py
# We compare armijo line search to fixed learning rate SGD
# when used to fit a CNN / MLP to MNIST

# Linesearch code is from
# https://github.com/IssamLaradji/stochastic_line_search/blob/master/main.py
import superimport

from armijo_sgd import SGD_Armijo, ArmijoModel

# Neural net code is based on various tutorials
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
# https://github.com/CSCfi/machine-learning-scripts/blob/master/notebooks/pytorch-mnist-mlp.ipynb

import numpy as np
np.set_printoptions(precision=3)
import matplotlib.pyplot as plt
import pyprobml_utils as pml
import warnings
warnings.filterwarnings('ignore')

import torch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True
print('Using PyTorch version:', torch.__version__, ' Device:', device)

figdir = "../figures"
import os

############
# Get data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

batch_size = 32
train_dataset = datasets.MNIST('./data',
                               train=True,
                               download=True,
                               transform=transforms.ToTensor())

test_dataset = datasets.MNIST('./data',
                              train=False,
                              transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Inspect the shape and type of one batch.
for (X_train, y_train) in train_loader:
    print('X_train:', X_train.size(), 'type:', X_train.type())
    print('y_train:', y_train.size(), 'type:', y_train.type())
    break

bs, ncolors, height, width = X_train.shape
nclasses = 10
N_train = train_dataset.data.shape[0]

#####
# Define model

import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss(reduction='mean')
# https://pytorch.org/docs/stable/nn.html#crossentropyloss
# This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class,
# therefore we don't need the LogSoftmax on the final layer.
# But we do need it if we use NLLLoss.
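# A quick numerical sanity check of the equivalence described above
# (an illustrative addition, not part of the training demo):
# CrossEntropyLoss on raw logits matches NLLLoss on log-softmax outputs.
_logits = torch.randn(4, nclasses)
_targets = torch.randint(0, nclasses, (4,))
_ce = criterion(_logits, _targets)
_nll = nn.NLLLoss(reduction='mean')(F.log_softmax(_logits, dim=1), _targets)
assert torch.allclose(_ce, _nll), 'the two losses should agree'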
# The Armijo method assumes gradient noise goes to zero,
# so it is important that we don't have dropout layers.
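# For intuition, the backtracking rule behind SGD_Armijo can be sketched as
# follows. This is an illustrative, hypothetical helper (not called below;
# the real implementation lives in armijo_sgd.SGD_Armijo): starting from a
# trial step size eta, shrink it by a factor beta until the
# sufficient-decrease (Armijo) condition
#   f(w - eta*g) <= f(w) - c * eta * ||g||^2
# holds for the current mini-batch loss f, parameters w, and gradient g.
def armijo_backtrack(f, w, g, eta=1.0, c=0.1, beta=0.5, max_iters=50):
    fw = f(w)
    g_norm_sq = float((g * g).sum())
    for _ in range(max_iters):
        if f(w - eta * g) <= fw - c * eta * g_norm_sq:
            break  # sufficient decrease achieved
        eta *= beta  # shrink the step and retry
    return eta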
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(ncolors, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        #self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # input is 28x28x1
        # conv1(kernel=5, filters=10) 28x28x1 -> 24x24x10
        # max_pool(kernel=2) 24x24x10 -> 12x12x10
        x = F.relu(F.max_pool2d(self.conv1(x), 2))

        # conv2(kernel=5, filters=20) 12x12x10 -> 8x8x20
        # max_pool(kernel=2) 8x8x20 -> 4x4x20
        #x = F.relu(F.max_pool2d(self.dropout(self.conv2(x)), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))

        # flatten 4x4x20 = 320
        x = x.view(-1, 320)

        # 320 -> 50
        x = F.relu(self.fc1(x))
        #x = F.dropout(x, training=self.training)

        # 50 -> 10
        x = self.fc2(x)

        return x
        #return F.log_softmax(x)

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(ncolors*height*width, 50)
        #self.fc1_drop = nn.Dropout(0.2)
        self.fc2 = nn.Linear(50, 50)
        #self.fc2_drop = nn.Dropout(0.2)
        self.fc3 = nn.Linear(50, nclasses)

    def forward(self, x):
        x = x.view(-1, ncolors*height*width)
        x = F.relu(self.fc1(x))
        #x = self.fc1_drop(x)
        x = F.relu(self.fc2(x))
        #x = self.fc2_drop(x)
        x = self.fc3(x)
        #return F.log_softmax(x, dim=1)
        return x

class Logreg(nn.Module):
    def __init__(self):
        super(Logreg, self).__init__()
        self.fc1 = nn.Linear(ncolors*height*width, nclasses)

    def forward(self, x):
        x = x.view(-1, ncolors*height*width)
        x = self.fc1(x)
        #return F.log_softmax(x, dim=1)
        return x

def make_model(name, seed=0):
    np.random.seed(seed)
    torch.manual_seed(seed)  # numpy seeding alone does not make PyTorch weight init reproducible
    if name == 'CNN':
        net = CNN()
    elif name == 'MLP':
        net = MLP()
    else:
        net = Logreg()
    net = net.to(device)
    return net

###############

# Define each experimental configuration
expts = []
ep = 4
#model = 'Logreg'
model = 'MLP'
#model = 'CNN'
bs = 10
expts.append({'lr':'armijo', 'bs':bs, 'epochs':ep, 'model': model})
expts.append({'lr':0.01, 'bs':bs, 'epochs':ep, 'model': model})
expts.append({'lr':0.1, 'bs':bs, 'epochs':ep, 'model': model})
#expts.append({'lr':0.5, 'bs':bs, 'epochs':ep, 'model': model})

@torch.no_grad()
def eval_loss(model, loader):
    avg_loss = 0.0
    model.eval()
    for step, (x_batch, y_batch) in enumerate(loader):
        # Copy data to GPU if needed
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)
        avg_loss += loss.item()
    # Compute average loss per example.
    # Note that the criterion already averages within each batch.
    n_batches = len(loader)
    avg_loss /= n_batches
    return avg_loss

def fit_epoch(model, optimizer, train_loader, loss_history):
    epoch_loss = 0.0
    model.train()
    for step, (x_batch, y_batch) in enumerate(train_loader):
        # Copy data to GPU if needed
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        # Function to (re)evaluate loss and its gradient for this step.
        def closure():
            optimizer.zero_grad()
            y_pred = model(x_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            return loss
        loss = optimizer.step(closure)
        batch_loss = loss.item()
        epoch_loss += batch_loss
        loss_history.append(batch_loss)
    # Compute average loss per example for this epoch.
    # Note that the criterion already averages within each batch.
    n_batches = len(train_loader)
    epoch_loss /= n_batches
    return epoch_loss

def fit_epoch_armijo(model, optimizer, train_loader, loss_history, step_size_history):
    # Here model is an ArmijoModel wrapper, whose step method evaluates the
    # batch loss and chooses the step size by backtracking line search.
    epoch_loss = 0.0
    for step, (x_batch, y_batch) in enumerate(train_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        batch_loss, step_size = model.step((x_batch, y_batch))
        epoch_loss += batch_loss
        loss_history.append(batch_loss)
        step_size_history.append(step_size)
    n_batches = len(train_loader)
    epoch_loss /= n_batches
    return epoch_loss


results_dict = {}
for expt in expts:
    lr = expt['lr']
    bs = expt['bs']
    max_epochs = expt['epochs']
    model_name = expt['model']
    model = make_model(model_name)
    model.train() # set to training mode
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs,
                                               shuffle=True, num_workers=2)
    n_batches = len(train_loader)
    batch_loss_history = []
    epoch_loss_history = []
    step_size_history = []
    print_every = max(1, int(0.1*max_epochs))
    if lr == 'armijo':
        name = '{}-armijo-bs{}'.format(model_name, bs)
        model = ArmijoModel(model, criterion)
        optimizer = SGD_Armijo(model, batch_size=bs, dataset_size=N_train)
        model.opt = optimizer
        armijo = True
    else:
        name = '{}-lr{:0.3f}-bs{}'.format(model_name, lr, bs)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        armijo = False

    print('starting {}'.format(name))
    for epoch in range(max_epochs):
        if armijo:
            avg_batch_loss = fit_epoch_armijo(model, optimizer, train_loader,
                                              batch_loss_history, step_size_history)
        else:
            avg_batch_loss = fit_epoch(model, optimizer, train_loader,
                                       batch_loss_history)
        epoch_loss = eval_loss(model, train_loader)
        epoch_loss_history.append(epoch_loss)
        if epoch % print_every == 0:
            print("epoch {}, loss {}".format(epoch, epoch_loss))

    label = '{}-final-loss{:0.3f}'.format(name, epoch_loss)
    results = {'label': label, 'batch_loss_history': batch_loss_history,
               'epoch_loss_history': epoch_loss_history,
               'step_size_history': step_size_history}
    results_dict[name] = results


# Plot the step sizes chosen by the Armijo line search.
plt.figure()
name = 'MLP-armijo-bs10'
results = results_dict[name]
plt.plot(results['step_size_history'])
plt.ylabel('stepsize')
pml.savefig('armijo-mnist-stepsize.pdf')
plt.show()

# Plot the per-epoch training loss of each method.
plt.figure()
for name, results in results_dict.items():
    label = results['label']
    y = results['epoch_loss_history']
    plt.plot(y, label=label)
plt.legend()
pml.savefig('armijo-mnist-epoch-loss.pdf')
plt.show()

# Add smoothed version of batch loss history to results dict
import pandas as pd
for name, results in results_dict.items():
    loss_history = results['batch_loss_history']
    df = pd.Series(loss_history)
    nsteps = len(loss_history)
    smoothed = df.ewm(span=0.1*nsteps).mean()
    results['batch_loss_history_smoothed'] = smoothed

# Plot smoothed batch-loss curves on one figure
plt.figure()
for name, results in results_dict.items():
    label = results['label']
    y = results['batch_loss_history_smoothed']
    nsteps = len(y)
    x = np.arange(nsteps)
    ndx = np.arange(int(0.2*nsteps), nsteps) # skip first 20%
    plt.plot(x[ndx], y[ndx], label=label)
plt.legend()
pml.savefig('armijo-mnist-batch-loss.pdf')
plt.show()
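# Note that test_loader is built above but never used. As a final check
# (an illustrative addition, not part of the original demo), the same
# eval_loss helper can score the most recently trained model on the
# held-out test set.
test_loss = eval_loss(model, test_loader)
print('test loss of last trained model: {:0.3f}'.format(test_loss))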