Kernel: Python 3
A JAX implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/14/lenet_jax.ipynb
The LeNet CNN
Based on sec 6.5 of http://d2l.ai/chapter_convolutional-neural-networks/lenet.html
In [ ]:
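The setup cell is empty in this export; presumably it holds the imports used below and creates a directory for saving figures (hence the message that follows). A minimal sketch of what it likely contains:

import torch
from torch import nn

!mkdir figures  # shell command in the notebook; prints a warning if the directory already exists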
mkdir: cannot create directory ‘figures’: File exists
Make the model
We hard-code the assumption that the input will be 1x28x28 (a single channel of height and width 28), as is the case for (Fashion) MNIST.
In [ ]:
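The model-definition cell is empty in this export. Below is a plausible PyTorch sketch of the network, consistent with the layer names and shapes printed by the next cell; the exact kernel sizes and padding are assumptions based on the classic LeNet-5 recipe.

import torch
from torch import nn

class Reshape(nn.Module):
    # Reshape a batch into N x 1 x 28 x 28 images.
    def forward(self, x):
        return x.view(-1, 1, 28, 28)

net = nn.Sequential(
    Reshape(),
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),  # 1x28x28 -> 6x28x28
    nn.AvgPool2d(kernel_size=2, stride=2),                    # -> 6x14x14
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),            # -> 16x10x10
    nn.AvgPool2d(kernel_size=2, stride=2),                    # -> 16x5x5
    nn.Flatten(),                                             # -> 400
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10),
)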
In [ ]:
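The shape-probing cell is also empty; a sketch that would produce the output below, feeding a single dummy image through the network one layer at a time:

X = torch.randn(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape: \t', X.shape)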
Reshape output shape: torch.Size([1, 1, 28, 28])
Conv2d output shape: torch.Size([1, 6, 28, 28])
Sigmoid output shape: torch.Size([1, 6, 28, 28])
AvgPool2d output shape: torch.Size([1, 6, 14, 14])
Conv2d output shape: torch.Size([1, 16, 10, 10])
Sigmoid output shape: torch.Size([1, 16, 10, 10])
AvgPool2d output shape: torch.Size([1, 16, 5, 5])
Flatten output shape: torch.Size([1, 400])
Linear output shape: torch.Size([1, 120])
Sigmoid output shape: torch.Size([1, 120])
Linear output shape: torch.Size([1, 84])
Sigmoid output shape: torch.Size([1, 84])
Linear output shape: torch.Size([1, 10])
Data
In [ ]:
In [ ]:
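The data cells are not shown in this export; presumably they download Fashion-MNIST and wrap it in DataLoaders (the warning below suggests num_workers=4 was used). A hedged sketch, with the batch size as an assumption:

import torchvision
from torch.utils import data
from torchvision import transforms

def load_data_fashion_mnist(batch_size, resize=None, num_workers=4):
    # Download Fashion-MNIST and return train/test DataLoaders.
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root='./data', train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root='./data', train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=num_workers))

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)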
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
Training and eval
We have to move the model and the data to the GPU; otherwise the code is terribly slow.
In [ ]:
In [ ]:
In [ ]:
In [ ]:
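The cells above are empty in this export; presumably they select a device and define a GPU-aware accuracy evaluator. A minimal sketch in the d2l style, where the helper names (try_gpu, evaluate_accuracy_gpu) are assumptions:

def try_gpu():
    # Use the first GPU if available, otherwise fall back to the CPU.
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def evaluate_accuracy_gpu(net, data_iter, device=None):
    # Compute classification accuracy over data_iter on the model's device.
    net.eval()
    if device is None:
        device = next(net.parameters()).device
    num_correct, num_total = 0, 0
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            preds = net(X).argmax(dim=1)
            num_correct += (preds == y).sum().item()
            num_total += y.numel()
    return num_correct / num_total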
Train function
The training loop is fairly standard. It uses the Animator class to make a "real time" plot of three metrics (training loss, training accuracy, and test accuracy) over time.
In [ ]:
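The cell contents are not shown. Below is a simplified sketch of a d2l-style training function; it prints the three metrics once per epoch instead of drawing the live Animator plot, and the use of Xavier initialization and plain SGD is an assumption. It relies on the evaluate_accuracy_gpu helper sketched above.

import time

def train(net, train_iter, test_iter, num_epochs, lr, device):
    def init_weights(m):
        # Xavier-initialize conv and linear layers (assumed, following the d2l recipe).
        if type(m) in (nn.Linear, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    num_examples, start = 0, time.time()
    for epoch in range(num_epochs):
        net.train()
        total_loss, total_correct, total_n = 0.0, 0, 0
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = net(X)
            loss = loss_fn(y_hat, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * y.numel()
            total_correct += (y_hat.argmax(dim=1) == y).sum().item()
            total_n += y.numel()
        num_examples += total_n
        test_acc = evaluate_accuracy_gpu(net, test_iter, device)
        print(f'epoch {epoch + 1}: loss {total_loss / total_n:.3f}, '
              f'train acc {total_correct / total_n:.3f}, test acc {test_acc:.3f}')
    print(f'{num_examples / (time.time() - start):.1f} examples/sec on {device}')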
Learning curve
In [ ]:
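This cell presumably runs training and produces output like the lines below; a sketch of the call, where the learning rate and number of epochs are assumptions:

lr, num_epochs = 0.9, 10
train(net, train_iter, test_iter, num_epochs, lr, try_gpu())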
loss 0.468, train acc 0.823, test acc 0.820
44107.5 examples/sec on cuda:0