Please find the PyTorch implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/14/batchnorm_torch.ipynb
Batch normalization
We implement a batchnorm layer from scratch and add it to the LeNet CNN.
Code based on sec 7.5 of http://d2l.ai/chapter_convolutional-modern/batch-norm.html
Implementation from scratch
For fully connected layers, we compute the mean and variance over the minibatch samples for each feature dimension independently. For 2d convolutional layers, we compute them over the minibatch samples and over the horizontal and vertical spatial locations, for each channel (feature dimension) independently.
When training, we update the estimate of the mean and variance using a moving average. When testing (doing inference), we use the pre-computed values.
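A minimal sketch of such a function in JAX is shown below (the signature, names, and the channel-last layout are assumptions; the notebook's own code may differ slightly):

```python
import jax.numpy as jnp

def batch_norm(x, gamma, beta, moving_mean, moving_var, eps, momentum, train):
    if not train:
        # Inference: normalize with the pre-computed moving statistics.
        x_hat = (x - moving_mean) / jnp.sqrt(moving_var + eps)
    else:
        assert x.ndim in (2, 4)
        if x.ndim == 2:
            # Fully connected (batch, features): average over the batch axis.
            mean = x.mean(axis=0)
            var = ((x - mean) ** 2).mean(axis=0)
        else:
            # Convolutional, channel-last (batch, height, width, channels):
            # average over batch and spatial axes for each channel.
            mean = x.mean(axis=(0, 1, 2), keepdims=True)
            var = ((x - mean) ** 2).mean(axis=(0, 1, 2), keepdims=True)
        x_hat = (x - mean) / jnp.sqrt(var + eps)
        # Update the moving averages (returned explicitly, since JAX arrays are immutable).
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    y = gamma * x_hat + beta  # scale and shift
    return y, moving_mean, moving_var
```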
Wrap the batch norm function in a layer
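Below is a sketch of how the `batch_norm` function above could be wrapped in a Flax module, assuming the moving statistics are kept in a mutable `batch_stats` collection and gamma/beta are ordinary parameters:

```python
import flax.linen as nn
import jax.numpy as jnp

class BatchNorm(nn.Module):
    num_features: int   # output features of a Dense layer, or channels of a Conv layer
    ndims: int          # 2 for fully connected inputs, 4 for convolutional inputs
    momentum: float = 0.9
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x, train: bool = True):
        # Parameter/statistic shapes that broadcast against the input.
        shape = (1, self.num_features) if self.ndims == 2 else (1, 1, 1, self.num_features)
        gamma = self.param('gamma', nn.initializers.ones, shape)
        beta = self.param('beta', nn.initializers.zeros, shape)
        # The moving statistics live in the mutable 'batch_stats' collection.
        moving_mean = self.variable('batch_stats', 'moving_mean', jnp.zeros, shape)
        moving_var = self.variable('batch_stats', 'moving_var', jnp.ones, shape)
        y, new_mean, new_var = batch_norm(
            x, gamma, beta, moving_mean.value, moving_var.value,
            self.eps, self.momentum, train)
        if train and not self.is_initializing():
            moving_mean.value = new_mean
            moving_var.value = new_var
        return y
```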
Applying batch norm to LeNet
We add BN layers after some of the convolutions and fully connected layers, but before the activation functions.
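For concreteness, here is an illustrative version of such a network using the from-scratch BatchNorm module above (the layer sizes follow the d2l LeNet and are assumptions, not necessarily the exact definition used in this notebook):

```python
class LeNetBN(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = True):
        # Conv block 1: convolution -> batch norm -> activation -> pooling.
        x = nn.Conv(features=6, kernel_size=(5, 5))(x)
        x = nn.sigmoid(BatchNorm(6, ndims=4)(x, train))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Conv block 2.
        x = nn.Conv(features=16, kernel_size=(5, 5))(x)
        x = nn.sigmoid(BatchNorm(16, ndims=4)(x, train))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Classifier head: flatten, then Dense -> batch norm -> activation.
        x = x.reshape((x.shape[0], -1))
        x = nn.sigmoid(BatchNorm(120, ndims=2)(nn.Dense(120)(x), train))
        x = nn.sigmoid(BatchNorm(84, ndims=2)(nn.Dense(84)(x), train))
        return nn.Dense(10)(x)
```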
Train the model
We train the model using the same code as in the standard LeNet colab; the only difference is a larger learning rate, which is possible because BN stabilizes training.
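For example, the optimizer might be constructed as follows (the value 1.0 follows d2l's choice and is only an illustrative assumption, not necessarily the rate used here):

```python
import optax

# A larger SGD learning rate than in the plain LeNet notebook.
tx = optax.sgd(learning_rate=1.0)
```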
Plotting
Training Function
We create a subclass of train_state.TrainState to store the auxiliary variables required by BatchNorm (the moving mean and variance; the learned gamma and beta live in the ordinary parameters).
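A minimal sketch of such a subclass:

```python
from typing import Any
from flax.training import train_state

class TrainState(train_state.TrainState):
    # Extra field for the 'batch_stats' collection (the BatchNorm moving
    # mean and variance); params and optimizer state are handled by the base class.
    batch_stats: Any
```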
Since the same training procedure needs to be applied to two different networks (i.e. LeNetBN and LeNetBNFlax), we define a train_procedure_builder helper function to create separate procedures for these two networks. Note that we cannot simply pass the model class to the training functions because we are using @jax.jit and nn.Module is not a valid JAX type.
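Below is a hedged sketch of this builder pattern; the batch keys (`image`, `label`), the loss function, and the model's `train`/`mutable` signature are assumptions:

```python
import jax
import optax

def train_procedure_builder(model):
    # `model` is captured in the closure, so the jitted function never
    # receives an nn.Module as an argument.
    @jax.jit
    def train_step(state, batch):
        def loss_fn(params):
            logits, updates = model.apply(
                {'params': params, 'batch_stats': state.batch_stats},
                batch['image'], train=True, mutable=['batch_stats'])
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits, batch['label']).mean()
            return loss, updates
        (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        # Keep the moving statistics produced during the forward pass.
        state = state.replace(batch_stats=updates['batch_stats'])
        return state, loss

    return train_step
```

The builder is called once per model, e.g. `train_procedure_builder(LeNetBN())`, and each call returns its own jitted `train_step`.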
Examine learned parameters
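For example, one simple way to inspect the learned parameters (including the per-feature gamma and beta of each BatchNorm layer), assuming a trained `state` as above:

```python
import jax

# Print the shape of every learned parameter; each BatchNorm submodule
# contributes a per-feature 'gamma' (scale) and 'beta' (shift) vector.
print(jax.tree_util.tree_map(lambda p: p.shape, state.params))
```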
Use Flax's BatchNorm layer
The built-in layer is typically faster than our from-scratch code, since it uses an optimized, fused implementation. Note that instead of specifying ndims=2 for fully connected layers (batch x features) and ndims=4 for convolutional layers (batch x height x width x channels), we simply use BatchNorm and take advantage of Flax's shape inference feature.
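As a brief illustration (a sketch, not the exact LeNetBNFlax definition), the same `nn.BatchNorm` call works after either a Dense or a Conv layer, since the feature axis is inferred from the input:

```python
import flax.linen as nn

class DenseBNBlock(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(120)(x)
        # use_running_average=True uses the stored moving statistics (inference);
        # False uses the statistics of the current minibatch (training).
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.sigmoid(x)
```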