Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/14/batchnorm_torch.ipynb
1192 views
Kernel: Python 3

Open In Colab

Batch normalization

We implement a batchnorm layer from scratch and add it to the LeNet CNN.

Code based on sec 7.5 of http://d2l.ai/chapter_convolutional-modern/batch-norm.html

import numpy as np import matplotlib.pyplot as plt import math from IPython import display try: import torch except ModuleNotFoundError: %pip install -qq torch import torch try: import torchvision except ModuleNotFoundError: %pip install -qq torchvision import torchvision from torch import nn from torch.nn import functional as F from torch.utils import data from torchvision import transforms import random import os import time np.random.seed(seed=1) torch.manual_seed(1) !mkdir figures # for saving plots

Implementation from scratch

For fully connected layers, we take the average along minibatch samples for each dimension independently. For 2d convolutional layers, we take the average along minibatch samples, and along horizontal and vertical locations, for each channel (feature dimension) independently.

When training, we update the estimate of the mean and variance using a moving average. When testing (doing inference), we use the pre-computed values.

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch-normalize `X` and return updated running statistics.

    The mode is chosen by `torch.is_grad_enabled()`, not by a module flag:
    with gradients enabled the statistics of the current minibatch are used
    and the moving averages are updated; with gradients disabled (e.g. inside
    `torch.no_grad()`) the supplied moving averages are used and returned
    unchanged.

    Args:
        X: Input tensor. In training mode it must have rank 2 (statistics
            taken over dim 0) or rank 4 (statistics taken over dims 0, 2, 3).
        gamma: Scale applied to the normalized input.
        beta: Shift added after scaling.
        moving_mean: Running estimate of the mean.
        moving_var: Running estimate of the variance.
        eps: Constant added to the variance before the square root.
        momentum: Weight kept on the previous running estimate; the new batch
            statistic gets weight `1 - momentum`.

    Returns:
        Tuple `(Y, moving_mean, moving_var)`; the two statistics are detached
        from the autograd graph.

    Raises:
        ValueError: In training mode, if `X` is neither rank 2 nor rank 4.
    """
    if not torch.is_grad_enabled():
        # Prediction mode: normalize with the stored moving averages.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        if X.dim() not in (2, 4):
            raise ValueError(f"batch_norm expects a 2-D or 4-D input, got {X.dim()}-D")
        if X.dim() == 2:
            # Fully connected input: one statistic per feature column.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional input: one statistic per channel (dim 1);
            # keepdim=True keeps the result broadcastable against `X`.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training mode: standardize with the current batch statistics.
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # The running averages are bookkeeping only, so keep them out of the
        # autograd graph instead of building it and discarding it afterwards.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean.detach(), moving_var.detach()

Wrap the batch norm function in a layer

class BatchNorm(nn.Module):
    """Batch-normalization layer wrapping the file's `batch_norm` function.

    Args:
        num_features: Number of outputs of a fully connected layer, or number
            of output channels of a convolutional layer.
        num_dims: 2 selects statistics of shape `(1, num_features)` (fully
            connected input); any other value selects
            `(1, num_features, 1, 1)` (convolutional input).

    Note:
        `batch_norm` picks training vs. prediction behaviour from
        `torch.is_grad_enabled()`, so calling `.eval()` on this layer does not
        by itself switch it to the moving averages.
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale and shift, initialized to 1 and 0 respectively.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are state but not parameters. Registering them
        # as buffers makes them follow `.to(device)` and appear in
        # `state_dict()`, so a saved model keeps its statistics.
        self.register_buffer("moving_mean", torch.zeros(shape))
        self.register_buffer("moving_var", torch.ones(shape))

    def forward(self, X):
        """Normalize `X` and store the updated running statistics."""
        # Assigning a tensor to a registered buffer name replaces the buffer,
        # so the updated statistics stay part of the module state.
        Y, self.moving_mean, self.moving_var = batch_norm(
            X,
            self.gamma,
            self.beta,
            self.moving_mean,
            self.moving_var,
            eps=1e-5,
            momentum=0.9,
        )
        return Y

Applying batch norm to LeNet

We add BN layers after some of the convolutions and fully connected layers, but before the activation functions.

# LeNet with the from-scratch BatchNorm layer placed after each convolution
# and each hidden fully connected layer, before the sigmoid activation.
net = nn.Sequential(
    # Convolutional stage 1: 1 -> 6 channels.
    nn.Conv2d(1, 6, kernel_size=5),
    BatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2: 6 -> 16 channels.
    nn.Conv2d(6, 16, kernel_size=5),
    BatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Classifier head. The 16 * 4 * 4 input size presumably corresponds to
    # 28x28 input images -- confirm if `resize` is used when loading data.
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120),
    BatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    BatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)

Train the model

We train the model using the same code as in the standard LeNet colab. The only difference from the previous colab is the larger learning rate (which is possible because BN stabilizes training).

def load_data_fashion_mnist(batch_size, resize=None, num_workers=4, root="../data"):
    """Download Fashion-MNIST and return `(train_loader, test_loader)`.

    Args:
        batch_size: Minibatch size used by both data loaders.
        resize: If truthy, passed to `transforms.Resize`, which is applied
            before the conversion to a tensor.
        num_workers: Number of worker processes for each `DataLoader`.
            Defaults to 4, the value that was previously hard-coded; pass 0
            to load in the main process.
        root: Directory the dataset is downloaded to / read from. Defaults to
            the previously hard-coded "../data".

    Returns:
        A pair of `DataLoader`s: the training split (shuffled) and the test
        split (not shuffled).
    """
    steps = [transforms.ToTensor()]
    if resize:
        steps.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, transform=transform, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, transform=transform, download=True)
    return (
        data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers),
        data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=num_workers),
    )
class Animator:
    """Redraw a set of line plots each time new data points are added."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        legend = [] if legend is None else legend
        display.set_matplotlib_formats("svg")
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # A single subplot comes back bare; wrap it so `axes[0]` works.
            self.axes = [self.axes]

        def configure():
            # Closure over the constructor arguments; re-applied after every
            # redraw because `add` clears the axes first.
            return set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)

        self.config_axes = configure
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per line (skipping `None` entries) and redraw."""
        ys = y if hasattr(y, "__len__") else [y]
        num_lines = len(ys)
        # A scalar x is shared by every line.
        xs = x if hasattr(x, "__len__") else [x] * num_lines
        if not self.X:
            self.X = [[] for _ in range(num_lines)]
        if not self.Y:
            self.Y = [[] for _ in range(num_lines)]
        for line, (x_val, y_val) in enumerate(zip(xs, ys)):
            if x_val is None or y_val is None:
                continue
            self.X[line].append(x_val)
            self.Y[line].append(y_val)
        canvas = self.axes[0]
        canvas.cla()
        for line_x, line_y, fmt in zip(self.X, self.Y, self.fmts):
            canvas.plot(line_x, line_y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)


class Timer:
    """Record multiple running times."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Remember the current time as the start of an interval."""
        self.tik = time.time()

    def stop(self):
        """Record the time elapsed since `start` and return it."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return elapsed

    def avg(self):
        """Return the mean of the recorded times."""
        return self.sum() / len(self.times)

    def sum(self):
        """Return the total of the recorded times."""
        # Inside the method body `sum` resolves to the builtin.
        return sum(self.times)

    def cumsum(self):
        """Return the running totals of the recorded times as a list."""
        return np.cumsum(np.array(self.times)).tolist()


class Accumulator:
    """Keep running sums of `n` quantities."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add each argument (converted with `float`) to its running sum."""
        # `zip` stops at the shorter sequence, exactly as before.
        updated = []
        for total, value in zip(self.data, args):
            updated.append(total + float(value))
        self.data = updated

    def reset(self):
        """Set every running sum back to zero."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        return self.data[idx]
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Apply labels, scales, limits, an optional legend and a grid to `axes`."""
    # Order is kept as-is: labels, then scales, then limits.
    settings = (
        ("set_xlabel", xlabel),
        ("set_ylabel", ylabel),
        ("set_xscale", xscale),
        ("set_yscale", yscale),
        ("set_xlim", xlim),
        ("set_ylim", ylim),
    )
    for setter_name, value in settings:
        getattr(axes, setter_name)(value)
    if legend:
        axes.legend(legend)
    axes.grid()
def try_gpu(i=0):
    """Return the `cuda:i` device when at least `i + 1` GPUs exist, else CPU."""
    available = torch.cuda.device_count()
    if i + 1 > available:
        return torch.device("cpu")
    return torch.device(f"cuda:{i}")
def accuracy(y_hat, y):
    """Return the number of correct predictions as a Python float.

    If `y_hat` has more than one dimension and more than one column, the
    predicted class is the argmax over dim 1; otherwise `y_hat` is compared
    to `y` directly (after casting to `y`'s dtype).
    """
    if y_hat.dim() > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    matches = y_hat.type(y.dtype) == y
    return float(matches.type(y.dtype).sum())


def evaluate_accuracy_gpu(net, data_iter, device=None):
    """Compute the accuracy of `net` over all batches of `data_iter`.

    Args:
        net: Callable model. If it is a `torch.nn.Module` it is switched to
            evaluation mode (and left in that mode).
        data_iter: Iterable of `(X, y)` tensor batches.
        device: Device the batches are moved to. If falsy and `net` is a
            module, the device of the module's first parameter is used.

    Returns:
        Fraction of correctly classified examples.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    correct, total = 0.0, 0
    # Evaluation must run with autograd disabled: the from-scratch
    # `batch_norm` in this file selects prediction mode through
    # `torch.is_grad_enabled()`, so without this it would normalize with
    # test-batch statistics and fold test data into its moving averages.
    # It also avoids building an unused autograd graph.
    with torch.no_grad():
        for X, y in data_iter:
            X = X.to(device)
            y = y.to(device)
            correct += accuracy(net(X), y)
            total += y.numel()
    return correct / total

Training Function

def train(net, train_iter, test_iter, num_epochs, lr, device):
    """Train `net` with SGD and cross-entropy loss, plotting progress.

    Linear and Conv2d weights are re-initialized with Xavier-uniform before
    training. After every epoch the test accuracy is computed with
    `evaluate_accuracy_gpu`; the final loss, accuracies and throughput are
    printed at the end.

    Args:
        net: Model to train; moved to `device`.
        train_iter: Sized iterable of `(X, y)` training batches.
        test_iter: Iterable of `(X, y)` test batches.
        num_epochs: Number of passes over `train_iter`.
        lr: Learning rate for `torch.optim.SGD`.
        device: Device used for the model and the batches.
    """

    def init_weights(m):
        # Exact type match (not isinstance), so subclasses are left alone.
        if type(m) in (nn.Linear, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    print("training on", device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = Animator(xlabel="epoch", xlim=[1, num_epochs], legend=["train loss", "train acc", "test acc"])
    timer, num_batches = Timer(), len(train_iter)
    # Plot roughly five times per epoch. With fewer than five batches
    # `num_batches // 5` is 0 and the modulo below would raise
    # ZeroDivisionError, so clamp the interval to at least one batch.
    plot_every = max(1, num_batches // 5)
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples
        metric = Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            batch_loss = loss(y_hat, y)
            batch_loss.backward()
            optimizer.step()
            with torch.no_grad():
                # The loss is a batch mean; scale by batch size to get a sum.
                metric.add(batch_loss * X.shape[0], accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches, (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f"loss {train_l:.3f}, train acc {train_acc:.3f}, " f"test acc {test_acc:.3f}")
    print(f"{metric[2] * num_epochs / timer.sum():.1f} examples/sec " f"on {str(device)}")
# Hyperparameters for the run with the from-scratch BatchNorm network.
lr = 1.0
num_epochs = 10
batch_size = 256

train_iter, test_iter = load_data_fashion_mnist(batch_size)
train(net, train_iter, test_iter, num_epochs=num_epochs, lr=lr, device=try_gpu())
loss 0.250, train acc 0.908, test acc 0.860 19699.9 examples/sec on cuda:0
Image in a Jupyter notebook

Examine learned parameters

# Flattened scale (gamma) and shift (beta) of the first batch-norm layer.
net[1].gamma.reshape(-1), net[1].beta.reshape(-1)
(tensor([2.0791, 1.5773, 1.7458, 1.8492, 1.6623, 1.6119], device='cuda:0', grad_fn=<ViewBackward>), tensor([ 0.7361, 0.0704, 0.3563, -1.6718, -0.1365, -0.1688], device='cuda:0', grad_fn=<ViewBackward>))

Use PyTorch's batchnorm layer

The built-in layer is much faster than our Python code, since it is implemented in C++. Note that instead of specifying num_dims=2 for a fully connected layer (batch x features) and num_dims=4 for a convolutional layer (batch x channels x height x width), we use BatchNorm1d or BatchNorm2d, respectively.

# Same LeNet layout, using PyTorch's built-in batch-norm layers:
# BatchNorm2d after the convolutions, BatchNorm1d after the linear layers.
net = nn.Sequential(
    # Convolutional stage 1: 1 -> 6 channels.
    nn.Conv2d(1, 6, kernel_size=5),
    nn.BatchNorm2d(6),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2: 6 -> 16 channels.
    nn.Conv2d(6, 16, kernel_size=5),
    nn.BatchNorm2d(16),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Classifier head: 16 * 4 * 4 = 256 flattened features.
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120),
    nn.BatchNorm1d(120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.BatchNorm1d(84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)

Learning Curve

# Re-run training with the same data loaders and hyperparameters, now on the
# network that uses the built-in batch-norm layers.
train(net, train_iter, test_iter, num_epochs=num_epochs, lr=lr, device=try_gpu())
loss 0.247, train acc 0.909, test acc 0.876 30736.3 examples/sec on cuda:0
Image in a Jupyter notebook