A JAX implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/14/batchnorm_jax.ipynb
Batch normalization
We implement a batchnorm layer from scratch and add it to the LeNet CNN.
Code based on sec 7.5 of http://d2l.ai/chapter_convolutional-modern/batch-norm.html
Implementation from scratch
For fully connected layers, we take the average along minibatch samples for each dimension independently. For 2d convolutional layers, we take the average along minibatch samples, and along horizontal and vertical locations, for each channel (feature dimension) independently.
When training, we update the estimates of the mean and variance using a moving average. When testing (doing inference), we use these pre-computed values.
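Concretely, the normalization just described might look like the following sketch, which follows the d2l.ai implementation this notebook is based on; the argument names (gamma, beta, moving_mean, moving_var, eps, momentum) are conventions of that code rather than fixed by the text.

```python
import torch

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    if not torch.is_grad_enabled():
        # Inference mode: normalize with the pre-computed moving statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected layer: average over the minibatch dimension
            # for each feature independently.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 2d convolutional layer: average over the minibatch and the
            # spatial (horizontal and vertical) dimensions for each channel.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, normalize with the minibatch statistics.
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the moving averages used at test time.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # scale and shift
    return Y, moving_mean.data, moving_var.data
```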
Wrap the batch norm function in a layer
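A sketch of such a layer, wrapping the batch_norm function above; the constructor arguments (num_features, and ndims = 2 for fully connected inputs or 4 for convolutional inputs) are illustrative names, not taken from the text.

```python
import torch
from torch import nn

class BatchNorm(nn.Module):
    def __init__(self, num_features, ndims):
        super().__init__()
        shape = (1, num_features) if ndims == 2 else (1, num_features, 1, 1)
        # Learnable scale and shift parameters, initialized to 1 and 0.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Moving statistics are plain tensors (no gradients are taken).
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # Keep the moving statistics on the same device as the input.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Uses the batch_norm function sketched above.
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
```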
Applying batch norm to LeNet
We add BN layers after some of the convolutions and fully connected layers, but before the activation functions.
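A sketch of what the resulting architecture could look like, mirroring the d2l.ai LeNet variant and using the from-scratch BatchNorm layer sketched above; the layer sizes assume 28x28 (Fashion-MNIST-style) inputs.

```python
from torch import nn

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, ndims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, ndims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120), BatchNorm(120, ndims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, ndims=2), nn.Sigmoid(),
    nn.Linear(84, 10))
```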
Train the model
We train the model using the same code as in the standard LeNet colab. The only difference is the larger learning rate, which is possible because BN stabilizes training.
Training Function
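A minimal training-loop sketch; the notebook itself reuses the d2l-style training function from the LeNet colab, so the Fashion-MNIST loader, batch size 256, and learning rate 1.0 below are illustrative assumptions rather than the notebook's exact code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

def train(net, lr=1.0, num_epochs=10, batch_size=256, device='cpu'):
    # Fashion-MNIST training set (assumed here, following the d2l.ai setup).
    train_set = torchvision.datasets.FashionMNIST(
        root='./data', train=True, transform=transforms.ToTensor(), download=True)
    train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        net.train()  # ensure batch norm uses minibatch statistics
        total_loss, n = 0.0, 0
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(net(X), y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * y.numel()
            n += y.numel()
        print(f'epoch {epoch + 1}, train loss {total_loss / n:.3f}')
```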
Examine learned parameters
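For example, the learned scale (gamma) and shift (beta) of the first batch-norm layer could be inspected as follows, assuming the Sequential layout sketched above (net[1] is the BatchNorm after the first convolution).

```python
print(net[1].gamma.reshape((-1,)))
print(net[1].beta.reshape((-1,)))
```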
Use PyTorch's batchnorm layer
The built-in layer is much faster than our Python code, since it is implemented in C++. Note that instead of specifying ndims=2 for fully connected layers (batch x features) or ndims=4 for convolutional layers (batch x channels x height x width), we use BatchNorm1d or BatchNorm2d, respectively.
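A sketch of the same network built with PyTorch's built-in layers; relative to the from-scratch version above, only the batch-norm modules change (nn.BatchNorm2d after convolutions, nn.BatchNorm1d after fully connected layers).

```python
from torch import nn

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))
```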