Path: blob/master/deprecated/notebooks/quantized_autoencoder_mnist.ipynb
Kernel: Python 3
Autoencoder for MNIST using a binary latent code
Uses a straight-through estimator (STE) to approximate the gradient through the non-differentiable binarization step. The code is modified from https://www.hassanaskary.com/python/pytorch/deep learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html
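A minimal sketch of a straight-through estimator as a custom `torch.autograd.Function`, assuming the code is produced by thresholding at zero. The forward pass emits a hard 0/1 code, whose true gradient is zero almost everywhere. The backward pass ignores the threshold and passes the incoming gradient through, clipped to [-1, 1] with `F.hardtanh`. The class names `STEFunction` and `StraightThroughEstimator` are illustrative and not taken from the notebook. A plain identity backward (`return grad_output`) is the simplest STE variant; the clipping here is one design choice among several.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class STEFunction(torch.autograd.Function):
    """Binarize in the forward pass; pass the gradient straight through in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Hard threshold: 1.0 where x > 0, else 0.0 (non-differentiable step).
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat the step as the identity, but clip the
        # incoming gradient to [-1, 1].
        return F.hardtanh(grad_output)


class StraightThroughEstimator(nn.Module):
    """Module wrapper so the binarizer can sit inside nn.Sequential."""

    def forward(self, x):
        return STEFunction.apply(x)


x = torch.tensor([-1.5, -0.2, 0.3, 2.0], requires_grad=True)
b = StraightThroughEstimator()(x)
b.sum().backward()
print(b.tolist())       # [0.0, 0.0, 1.0, 1.0]
print(x.grad.tolist())  # [1.0, 1.0, 1.0, 1.0]
```

Even though the threshold has zero derivative, `x.grad` is populated, which is what lets an encoder placed before the binarizer learn.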
In [1]:
Out[1]:
cuda:0
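The output `cuda:0` shows that a GPU was selected. The cell's code is not shown; a common PyTorch idiom that produces this output, and falls back to the CPU when no GPU is available, is sketched below as an assumption.

```python
import torch

# Use the first CUDA device if PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)  # cuda:0 on a GPU machine, cpu otherwise
```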
In [2]:
Out[2]:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz
Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw
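The MNIST training split holds 60,000 images. The progress bars in the training log report 938 iterations per epoch, which is consistent with a batch size of 64 (the leading dimension of the batch shapes printed in Out[30]) and PyTorch's default `drop_last=False`, under which the final partial batch is kept. The helper `batches_per_epoch` is a hypothetical illustration of that arithmetic, not code from the notebook.

```python
def batches_per_epoch(num_examples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of mini-batches a loader yields in one pass over the data."""
    if drop_last:
        # Incomplete final batch is discarded.
        return num_examples // batch_size
    # Ceiling division: the incomplete final batch is kept.
    return (num_examples + batch_size - 1) // batch_size


print(batches_per_epoch(60_000, 64))                  # 938
print(batches_per_epoch(60_000, 64, drop_last=True))  # 937
```

With `drop_last=False`, the 938th batch contains only 60,000 - 937 * 64 = 32 images.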
In [29]:
Out[29]:
Starting epoch 0 of 5
100%|██████████| 938/938 [00:17<00:00, 53.21it/s]
Loss: 0.048126302659511566
Starting epoch 1 of 5
100%|██████████| 938/938 [00:17<00:00, 53.61it/s]
Loss: 0.03352956101298332
Starting epoch 2 of 5
100%|██████████| 938/938 [00:17<00:00, 53.68it/s]
Loss: 0.03598036617040634
Starting epoch 3 of 5
100%|██████████| 938/938 [00:18<00:00, 51.94it/s]
Loss: 0.029913021251559258
Starting epoch 4 of 5
100%|██████████| 938/938 [00:17<00:00, 54.00it/s]
Loss: 0.02666076086461544
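The log above shows five epochs of 938 batches each, with the printed loss falling from about 0.048 to about 0.027. The notebook does not show which loss function was used, or whether each printed value is the last batch's loss or an epoch average. The sketch below assumes mean-squared error and reports the epoch average. It uses toy stand-ins (a linear encoder/decoder and random "images") so that it runs without downloading MNIST; `train_one_epoch` and the toy models are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()  # binary code in {0, 1}

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)  # straight-through gradient, clipped to [-1, 1]


def train_one_epoch(encoder, decoder, loader, optimizer, device):
    """Encode, binarize, decode, and minimize MSE; returns the mean batch loss."""
    encoder.train()
    decoder.train()
    total = 0.0
    for images, _ in loader:  # labels are ignored by the autoencoder
        images = images.to(device)
        code = STEFunction.apply(encoder(images))
        recon = decoder(code)
        loss = F.mse_loss(recon, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


# Toy stand-ins (hypothetical): linear encoder/decoder and random "images".
torch.manual_seed(0)
device = torch.device("cpu")
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 256)).to(device)
decoder = nn.Sequential(
    nn.Linear(256, 28 * 28), nn.Sigmoid(), nn.Unflatten(1, (1, 28, 28))
).to(device)
data = TensorDataset(torch.rand(256, 1, 28, 28), torch.zeros(256, dtype=torch.long))
loader = DataLoader(data, batch_size=64, shuffle=True)
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3
)

num_epochs = 2
for epoch in range(num_epochs):
    print(f"Starting epoch {epoch} of {num_epochs}")
    print(f"Loss: {train_one_epoch(encoder, decoder, loader, optimizer, device)}")
```

Because the gradient passes straight through the binarizer, the encoder's weights receive gradients even though its output is thresholded before reaching the decoder.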
In [30]:
Out[30]:
[torch.Size([64, 1, 28, 28]), torch.Size([64])]
[(64, 256, 1, 1), (64, 1, 28, 28)]
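Out[30] shows a batch of 64 images with their labels, followed by the shapes (64, 256, 1, 1) and (64, 1, 28, 28), which appear to be the binary code and the reconstruction. Each 784-pixel image is therefore summarized by 256 bits. The notebook's model code is not shown, so the layer choices below are an assumption: a convolutional encoder that reduces 28x28 to 1x1 with 256 channels, and a mirrored transposed-convolution decoder. `BinaryAutoencoder` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()  # binary code in {0, 1}

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)  # straight-through gradient


class BinaryAutoencoder(nn.Module):
    """Hypothetical conv autoencoder with a code of shape (N, 256, 1, 1)."""

    def __init__(self, code_channels: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.ReLU(),
            nn.Conv2d(128, code_channels, kernel_size=7),            # 7 -> 1
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(code_channels, 128, kernel_size=7),            # 1 -> 7
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14 -> 28
            nn.Sigmoid(),  # pixel intensities in (0, 1)
        )

    def forward(self, x):
        code = STEFunction.apply(self.encoder(x))
        return code, self.decoder(code)


model = BinaryAutoencoder()
code, recon = model(torch.rand(64, 1, 28, 28))
print([tuple(code.shape), tuple(recon.shape)])  # [(64, 256, 1, 1), (64, 1, 28, 28)]
```

The spatial sizes follow from the standard formulas: a convolution maps H to (H + 2p - k) // s + 1, and a transposed convolution maps H to (H - 1) * s - 2p + k.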