Annotated MNIST
This tutorial demonstrates how to construct the original convolutional neural network (CNN) proposed by LeCun et al. (1989): http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf
The original PyTorch reproduction is at https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py.
This version is converted to JAX/Flax and is based on Flax's official Annotated MNIST notebook.
1. Imports
Import JAX, Flax, ordinary NumPy, and torchvision datasets. Flax can use any data-loading pipeline and this example demonstrates how to utilize torchvision datasets.
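For reference, a sketch of the import cell that the commentary sketches below assume (the notebook's actual imports may differ slightly; optax is an assumed choice for the optimizer library):

```python
import numpy as np                      # host-side array handling
import jax
import jax.numpy as jnp                 # device arrays / math
import flax.linen as nn                 # the Linen module API
from flax.training import train_state   # TrainState dataclass
import optax                            # gradient transformations / optimizers
from torchvision import datasets        # MNIST download + loading
```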
2. Define network
Create the original convolutional neural network with the Linen API by subclassing Module. Because the architecture in this example is fairly complex (the connectivity between the first and second hidden layers is quite unusual from a modern point of view), you cannot simply define the inlined submodules directly within the __call__ method and wrap it with the @compact decorator.
The most notable difference between LeCun1989 and recent CNNs is that the "units" in the original architecture share their weights but not their biases (thresholds), whereas modern descendants share both weights and biases between units. We define a custom LocalBias layer to capture this peculiarity.
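A minimal sketch of such a layer, assuming NHWC feature maps as input; the notebook's actual LocalBias may differ in naming and initialization:

```python
class LocalBias(nn.Module):
    """Adds a bias that is private to each spatial position and channel,
    i.e. the bias ("threshold") is not shared between the units of a
    feature map."""

    @nn.compact
    def __call__(self, x):
        # One bias per (height, width, channel) unit; the batch dim is excluded.
        bias = self.param('bias', nn.initializers.zeros, x.shape[1:])
        return x + bias
```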
Now we need to write our own __call__ function. In particular, each H2 neuron connects to only 8 of the 12 planes of the previous layer; we implement this with three separate convolutions whose outputs we concatenate. Additionally, we define a custom weight-initializing function lecun1989_uniform and a static method pad that pads images with -1 on the edges.
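Sketches of these two helpers, assuming the initializer draws from a uniform distribution scaled by fan-in (as described in the paper and in Karpathy's repro) and that pad adds a 2-pixel border of -1 around NHWC images; the notebook's exact bounds and padding widths may differ:

```python
def lecun1989_uniform(fan_in):
    """Returns an initializer sampling U(-2.4/sqrt(fan_in), 2.4/sqrt(fan_in))."""
    scale = 2.4 / fan_in ** 0.5
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-scale, maxval=scale)
    return init


def pad(x):
    # Pad the spatial dims of an NHWC batch with the constant -1
    # ("background" value); in the notebook this lives on the module
    # as a static method.
    return jnp.pad(x, ((0, 0), (2, 2), (2, 2), (0, 0)), constant_values=-1.0)
```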
3. Define loss
Define a cross-entropy loss function using just jax.numpy that takes the model's logits and label vectors and returns a scalar loss. The labels can be one-hot encoded with jax.nn.one_hot, as demonstrated below.
Note that for demonstration purposes, we return nn.log_softmax() from the model and then simply multiply these (normalized) logits with the one-hot labels. In Flax's examples/mnist folder the model instead returns unnormalized logits, and optax.softmax_cross_entropy() is used to compute the loss, which gives the same result.
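A minimal sketch of such a loss, assuming integer labels and a model that returns log-softmax outputs. (Section 7 below refers to the loss as mse_loss; whichever loss the notebook actually defines plugs into the same training step.)

```python
def cross_entropy_loss(logits, labels, num_classes=10):
    # `logits` are log-softmax outputs; multiplying by the one-hot labels
    # picks out the log-probability of the correct class for every example.
    one_hot = jax.nn.one_hot(labels, num_classes)
    return -jnp.mean(jnp.sum(one_hot * logits, axis=-1))
```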
4. Metric computation
For loss and accuracy metrics, create a separate function:
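A sketch of such a helper, reusing the loss sketched above (the names are assumptions):

```python
def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}
```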
5. Loading data
Define a function that loads and prepares the MNIST dataset and converts the samples to floating-point numbers.
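A sketch of such a loader built on torchvision, assuming the images are simply scaled to floats; the notebook's preprocessing (e.g. resizing to the 16x16, [-1, 1] format of the 1989 paper) may differ:

```python
def get_datasets(root='./data'):
    splits = {
        'train': datasets.MNIST(root, train=True, download=True),
        'test': datasets.MNIST(root, train=False, download=True),
    }
    ds = {}
    for name, split in splits.items():
        ds[name] = {
            # uint8 images -> float32 in [0, 1], with a trailing channel dim.
            'image': np.asarray(split.data, dtype=np.float32)[..., None] / 255.0,
            'label': np.asarray(split.targets, dtype=np.int32),
        }
    return ds['train'], ds['test']
```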
6. Create train state
A common pattern in Flax is to create a single dataclass that represents the entire training state, including step number, parameters, and optimizer state.
Adding the optimizer and model to this state also has the advantage that we only need to pass a single argument to functions like train_step() (see below).
Because this is such a common pattern, Flax provides the class flax.training.train_state.TrainState that serves most basic use cases. Usually one would subclass it to add more data to be tracked, but in this example we can use it without any modifications.
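A sketch of such a constructor, assuming a module called Net (the LeCun1989 model from section 2), 16x16 single-channel inputs, and plain SGD; the notebook's optimizer and hyperparameters may differ:

```python
def create_train_state(rng, learning_rate):
    model = Net()   # hypothetical name for the module defined in section 2
    params = model.init(rng, jnp.ones([1, 16, 16, 1]))['params']
    tx = optax.sgd(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)
```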
7. Training step
A function that:
Evaluates the neural network given the parameters and a batch of input images with the Module.apply method.
Computes the mse_loss loss function.
Evaluates the loss function and its gradient using jax.value_and_grad.
Applies a pytree of gradients to the optimizer to update the model's parameters.
Computes the metrics using compute_metrics (defined earlier).
Use JAX's @jit decorator to trace the entire train_step function and just-in-time compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.
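Putting the list above together, a sketch of the jitted training step; mse_loss stands in for the loss defined in the notebook's section 3 cell, and compute_metrics is the helper from section 4:

```python
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = mse_loss(logits, batch['label'])   # the loss from section 3
        return loss, logits

    # Differentiate the loss w.r.t. the parameters, keeping the logits
    # around (has_aux) so metrics can be computed without a second pass.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch['label'])
    return state, metrics
```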
8. Evaluation step
Create a function that evaluates your model on the test set with Module.apply.
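A sketch of this evaluation step, assuming the same Net module and compute_metrics helper as above:

```python
@jax.jit
def eval_step(params, batch):
    logits = Net().apply({'params': params}, batch['image'])
    return compute_metrics(logits, batch['label'])
```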
9. Train function
Define a training function that:
Shuffles the training data before each epoch using jax.random.permutation, which takes a PRNGKey as a parameter (see JAX - The Sharp Bits).
Runs an optimization step for each batch.
Retrieves the training metrics from the device with jax.device_get and averages them over the batches in an epoch.
Returns the training state with updated parameters, together with the training loss and accuracy metrics.
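A sketch of such a function, assuming the dataset dictionaries produced by the get_datasets sketch above:

```python
def train_epoch(state, train_ds, batch_size, rng):
    num_train = len(train_ds['image'])
    steps_per_epoch = num_train // batch_size

    # Shuffle with an explicit PRNGKey, then drop the ragged tail batch.
    perm = np.asarray(jax.random.permutation(rng, num_train))
    perm = perm[:steps_per_epoch * batch_size]
    perm = perm.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for idx in perm:
        batch = {k: v[idx, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # Move the metrics to the host and average them over the epoch.
    batch_metrics = jax.device_get(batch_metrics)
    epoch_metrics = {
        k: np.mean([m[k] for m in batch_metrics]) for k in batch_metrics[0]
    }
    return state, epoch_metrics
```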
10. Eval function
Create a model evaluation function that:
Retrieves the evaluation metrics from the device with jax.device_get.
Copies the metrics data stored in a JAX pytree.
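A sketch of such a function, evaluating the whole test set in one jitted call:

```python
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']
```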
11. Download data
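For example, using the get_datasets sketch from section 5:

```python
train_ds, test_ds = get_datasets()
```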
12. Seed randomness
Get one PRNGKey and split it to get a second key that you'll use for parameter initialization. (Learn more about PRNG chains and JAX PRNG design.)
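For example (the seed 0 is an arbitrary choice):

```python
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
```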
13. Initialize train state
Remember that the function initializes both the model parameters and the optimizer state, and puts both into the training state dataclass that it returns.
We can verify that the parameters have the expected shapes.
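For instance, using the create_train_state sketch from section 6 (the learning rate here is an assumed placeholder):

```python
learning_rate = 0.03   # placeholder; the notebook's value may differ
state = create_train_state(init_rng, learning_rate)

# Sanity check: print the shape of every parameter leaf.
print(jax.tree_util.tree_map(jnp.shape, state.params))
```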
14. Train and evaluate
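A sketch of the outer loop tying the pieces above together (the batch size is an assumption; 23 passes follows the original paper and Karpathy's repro):

```python
num_epochs = 23
batch_size = 32      # assumed; the notebook may use a different value

for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_metrics = train_epoch(state, train_ds, batch_size, input_rng)
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(f"epoch {epoch:2d}: "
          f"train loss {train_metrics['loss']:.4f}, "
          f"train acc {train_metrics['accuracy']:.4f}, "
          f"test loss {test_loss:.4f}, test acc {test_accuracy:.4f}")
```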
Once training and testing are done after 23 epochs, the output should show that your model was able to achieve approximately 95% accuracy. This may not seem very impressive, but remember that this network was from 1989!
Congrats! You made it to the end of the annotated LeCun1989 example. You can revisit the underlying MNIST example, structured differently as a collection of Python modules, test modules, config files, another Colab, and documentation, in Flax's Git repo: