Path: blob/master/notebooks/book1/14/batchnorm_jax.ipynb
Please find the PyTorch implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/14/batchnorm_torch.ipynb
Batch normalization
We implement a batchnorm layer from scratch and add it to the LeNet CNN.
Code based on sec 7.5 of http://d2l.ai/chapter_convolutional-modern/batch-norm.html
Implementation from scratch
For fully connected layers, we compute the mean and variance over the minibatch samples for each feature dimension independently. For 2d convolutional layers, we compute them over the minibatch samples and over the horizontal and vertical locations, for each channel (feature dimension) independently.
When training, we update the estimate of the mean and variance using a moving average. When testing (doing inference), we use the pre-computed values.
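To make this concrete, here is a minimal sketch of such a batch_norm function. The signature and names are illustrative rather than taken from the notebook, and it assumes the channels-last (NHWC) layout used by Flax convolutions; adjust the reduction axes for a channels-first layout.

```python
import jax.numpy as jnp

def batch_norm(x, gamma, beta, moving_mean, moving_var,
               eps=1e-5, momentum=0.9, train=True):
    """Sketch of a from-scratch batch norm (NHWC layout assumed for 4d inputs)."""
    if not train:
        # inference: normalize with the pre-computed moving statistics
        x_hat = (x - moving_mean) / jnp.sqrt(moving_var + eps)
    else:
        assert x.ndim in (2, 4)
        # fully connected: average over the batch axis only;
        # convolutional: also average over the spatial axes, per channel
        axes = (0,) if x.ndim == 2 else (0, 1, 2)
        mean = x.mean(axis=axes, keepdims=True)
        var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
        x_hat = (x - mean) / jnp.sqrt(var + eps)
        # update the running estimates with an exponential moving average
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    y = gamma * x_hat + beta  # scale and shift
    return y, moving_mean, moving_var
```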
Wrap the batch norm function in a layer
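A rough sketch of such a wrapper, assuming the batch_norm function sketched above is in scope. The class name, the num_dims field, and the batch_stats collection are assumptions for illustration, not necessarily what the notebook uses.

```python
import jax.numpy as jnp
import flax.linen as nn

class BatchNormScratch(nn.Module):
    """Hypothetical Flax wrapper around the from-scratch batch_norm function."""
    num_features: int
    num_dims: int  # 2 for fully connected inputs, 4 for convolutional inputs

    @nn.compact
    def __call__(self, x, train: bool = True):
        shape = ((1, self.num_features) if self.num_dims == 2
                 else (1, 1, 1, self.num_features))
        # learnable scale and shift
        gamma = self.param('gamma', nn.initializers.ones, shape)
        beta = self.param('beta', nn.initializers.zeros, shape)
        # running statistics, stored in a separate (non-trainable) collection
        moving_mean = self.variable('batch_stats', 'mean',
                                    lambda: jnp.zeros(shape))
        moving_var = self.variable('batch_stats', 'var',
                                   lambda: jnp.ones(shape))
        y, new_mean, new_var = batch_norm(
            x, gamma, beta, moving_mean.value, moving_var.value, train=train)
        if train:
            moving_mean.value = new_mean
            moving_var.value = new_var
        return y
```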
Applying batch norm to LeNet
We add BN layers after some of the convolutions and fully connected layers, but before the activation functions.
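A rough sketch of this placement, reusing the BatchNormScratch wrapper sketched above (layer sizes follow the classic LeNet architecture; the notebook's exact configuration may differ):

```python
import flax.linen as nn

class LeNetBN(nn.Module):
    """Sketch of LeNet with batch norm inserted before the activations."""
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Conv(features=6, kernel_size=(5, 5))(x)
        x = nn.sigmoid(BatchNormScratch(num_features=6, num_dims=4)(x, train))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(5, 5))(x)
        x = nn.sigmoid(BatchNormScratch(num_features=16, num_dims=4)(x, train))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.sigmoid(BatchNormScratch(num_features=120, num_dims=2)(
            nn.Dense(features=120)(x), train))
        x = nn.sigmoid(BatchNormScratch(num_features=84, num_dims=2)(
            nn.Dense(features=84)(x), train))
        return nn.Dense(features=self.num_classes)(x)
```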
Train the model
We train the model using the same code as in the standard LeNet colab. The only difference is the larger learning rate, which is possible because BN stabilizes training.
Plotting
Training function
We create a subclass of train_state.TrainState to store the auxiliary variables (i.e. gamma and beta) required by BatchNorm.
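The exact contents of that extra state depend on the implementation; as a minimal sketch, assuming the mutable BatchNorm variables are kept in a batch_stats collection, the subclass just adds one extra field:

```python
from typing import Any
from flax.training import train_state

class TrainStateBN(train_state.TrainState):
    # extra collection of mutable BatchNorm variables (field name assumed);
    # passed in at creation time, e.g.
    # TrainStateBN.create(apply_fn=..., params=..., tx=..., batch_stats=...)
    batch_stats: Any
```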
Since the same training procedure needs to be applied to two different networks (i.e. LeNetBN and LeNetBNFlax), we define a train_procedure_builder helper function to create separate procedures for these two networks. Note that we cannot simply pass the model class to the training functions because we are using @jax.jit and nn.Module is not a valid JAX type.
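A hedged sketch of the idea (the function and field names here are hypothetical, and it assumes a TrainState subclass with a batch_stats field as above): the builder closes over the model, so the jitted step function only ever receives valid JAX types.

```python
import jax
import optax

def train_procedure_builder(model):
    """Hypothetical builder: the model is captured in a closure, so the
    jitted step function never takes an nn.Module as an argument."""

    @jax.jit
    def train_step(state, X, y):
        def loss_fn(params):
            # run the network in training mode, letting it update
            # its mutable BatchNorm statistics
            logits, updates = model.apply(
                {'params': params, 'batch_stats': state.batch_stats},
                X, train=True, mutable=['batch_stats'])
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits, y).mean()
            return loss, updates

        (loss, updates), grads = jax.value_and_grad(
            loss_fn, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        state = state.replace(batch_stats=updates['batch_stats'])
        return state, loss

    return train_step
```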
Examine learned parameters
Use Flax's BatchNorm layer
The built-in layer is usually much faster than our from-scratch version, since its implementation is more heavily optimized. Note that instead of specifying ndims=2 for fully connected layers (batch x features) and ndims=4 for convolutional layers (batch x channels x height x width), we simply use BatchNorm and take advantage of shape inference: the layer infers the feature dimension from its input.
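As an illustrative sketch (not the notebook's exact code), the Flax version of the network just calls nn.BatchNorm wherever the from-scratch layer was used, toggling use_running_average between training and inference:

```python
from functools import partial
import flax.linen as nn

class LeNetBNFlax(nn.Module):
    """Sketch of the same architecture using Flax's built-in nn.BatchNorm."""
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, train: bool = True):
        # the built-in layer infers the feature dimension from its input,
        # so the same call works after both Conv (4d) and Dense (2d) layers
        bn = partial(nn.BatchNorm, use_running_average=not train)
        x = nn.Conv(features=6, kernel_size=(5, 5))(x)
        x = nn.avg_pool(nn.sigmoid(bn()(x)),
                        window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(5, 5))(x)
        x = nn.avg_pool(nn.sigmoid(bn()(x)),
                        window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.sigmoid(bn()(nn.Dense(features=120)(x)))
        x = nn.sigmoid(bn()(nn.Dense(features=84)(x)))
        return nn.Dense(features=self.num_classes)(x)
```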