Kernel: Python 3
Base VAE code: https://github.com/google/flax/tree/main/examples/vae
I followed the original Theano repository of the LVAE paper: https://github.com/casperkaae/LVAE
The authors give further details about the paper in the issues: https://github.com/casperkaae/LVAE/issues/1
Finally, in some parts I followed: https://github.com/AntixK/PyTorch-VAE/blob/master/models/lvae.py
PS: Importance weighting is not implemented.
Firat Oncel / [email protected]
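For orientation, the eval logs below report loss = BCE + KLD, i.e. the negative ELBO (at MNIST epoch 500, 78.9827 + 28.9944 ≈ 107.9770 up to rounding). Here is a minimal sketch of that objective in jax.numpy, assuming a Bernoulli likelihood on flattened binarized images and a single diagonal-Gaussian latent; the LVAE itself sums the KL term over all ladder layers, and the names and shapes here are illustrative assumptions, not the notebook's exact code:

import jax.numpy as jnp

def bce(logits, x):
    # Numerically stable Bernoulli negative log-likelihood, summed over pixels.
    return jnp.sum(jnp.maximum(logits, 0) - logits * x
                   + jnp.log1p(jnp.exp(-jnp.abs(logits))), axis=-1)

def kld(mean, logvar):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, summed over latent dims.
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)

def loss_fn(logits, x, mean, logvar):
    # Per-example negative ELBO; the logs report its batch mean.
    return jnp.mean(bce(logits, x) + kld(mean, logvar))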
Omniglot download:
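The download cell itself is not shown; below is a minimal sketch assuming the binarized 28x28 Omniglot split hosted in the IWAE repository (the split used by the LVAE paper; it matches the 24,345 training samples reported at the bottom of this notebook). The URL and the .mat keys are assumptions based on that repository, not taken from this notebook.

import urllib.request
import numpy as np
from scipy.io import loadmat

# Assumed source: the IWAE repository's binarized Omniglot (chardata.mat).
OMNIGLOT_URL = ("https://raw.githubusercontent.com/yburda/iwae/"
                "master/datasets/OMNIGLOT/chardata.mat")
urllib.request.urlretrieve(OMNIGLOT_URL, "chardata.mat")

raw = loadmat("chardata.mat")
# Assumed keys: 'data' (train) and 'testdata' (test), stored as (784, N).
train_x = raw["data"].T.astype(np.float32)     # expected (24345, 784)
test_x = raw["testdata"].T.astype(np.float32)  # expected (8070, 784)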
This code was created with reference to torchvision/utils.py, with torch.Tensor replaced by jax.numpy.DeviceArray. If you want to know about this file in detail, please visit the original code: https://github.com/pytorch/vision/blob/master/torchvision/utils.py
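A condensed sketch of the adapted helper, assuming NHWC-shaped jax.numpy arrays; the torchvision original also handles normalization and value-range options, which are omitted here:

import math
import jax.numpy as jnp

def make_grid(images, nrow=8, padding=2, pad_value=0.0):
    # Tile a batch of images (N, H, W, C) into a single (H', W', C) grid,
    # with `nrow` images per row and `padding` pixels between tiles.
    n, h, w, c = images.shape
    ncols = min(n, nrow)
    nrows = math.ceil(n / ncols)
    grid = jnp.full((nrows * (h + padding) + padding,
                     ncols * (w + padding) + padding, c),
                    pad_value, dtype=images.dtype)
    for idx in range(n):
        row, col = divmod(idx, ncols)
        y = row * (h + padding) + padding
        x = col * (w + padding) + padding
        grid = grid.at[y:y + h, x:x + w].set(images[idx])
    return grid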
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
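The load call itself is not shown; a plausible form, given the warning above and the 50,000 training samples reported below (i.e. a 50k/10k train/validation split of MNIST's 60k training set, which is an assumption on my part), is:

import tensorflow_datasets as tfds

# Assumed split; only the 50k sample count below is known from the logs.
train_ds = tfds.load('mnist', split='train[:50000]', try_gcs=True)
val_ds = tfds.load('mnist', split='train[50000:]', try_gcs=True)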
Number of train samples: 50000
eval epoch: 25, loss: 168.5342, BCE: 63.0598, KLD: 105.4744
eval epoch: 50, loss: 137.9575, BCE: 66.2254, KLD: 71.7321
eval epoch: 75, loss: 124.7530, BCE: 69.3428, KLD: 55.4102
eval epoch: 100, loss: 117.7395, BCE: 72.3862, KLD: 45.3534
eval epoch: 125, loss: 113.7366, BCE: 74.9316, KLD: 38.8050
eval epoch: 150, loss: 111.8119, BCE: 76.5078, KLD: 35.3041
eval epoch: 175, loss: 110.5554, BCE: 78.0352, KLD: 32.5202
eval epoch: 200, loss: 109.9553, BCE: 80.2538, KLD: 29.7015
eval epoch: 225, loss: 109.5024, BCE: 79.4809, KLD: 30.0215
eval epoch: 250, loss: 109.2625, BCE: 80.4123, KLD: 28.8501
eval epoch: 275, loss: 108.8212, BCE: 79.7398, KLD: 29.0814
eval epoch: 300, loss: 108.6366, BCE: 79.4850, KLD: 29.1516
eval epoch: 325, loss: 108.4265, BCE: 79.4662, KLD: 28.9603
eval epoch: 350, loss: 108.3445, BCE: 79.8699, KLD: 28.4746
eval epoch: 375, loss: 108.2001, BCE: 79.4365, KLD: 28.7635
eval epoch: 400, loss: 108.1749, BCE: 79.2290, KLD: 28.9459
eval epoch: 425, loss: 108.1051, BCE: 79.1918, KLD: 28.9132
eval epoch: 450, loss: 108.0225, BCE: 79.0795, KLD: 28.9431
eval epoch: 475, loss: 107.9945, BCE: 79.2377, KLD: 28.7567
eval epoch: 500, loss: 107.9770, BCE: 78.9827, KLD: 28.9944
Downloading and preparing dataset fashion_mnist/3.0.1 (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /root/tensorflow_datasets/fashion_mnist/3.0.1...
Dataset fashion_mnist downloaded and prepared to /root/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.
Number of train samples: 50000
eval epoch: 25, loss: 277.9286, BCE: 216.7530, KLD: 61.1756
eval epoch: 50, loss: 260.0067, BCE: 216.7073, KLD: 43.2995
eval epoch: 75, loss: 252.6209, BCE: 218.1791, KLD: 34.4418
eval epoch: 100, loss: 249.3030, BCE: 219.7867, KLD: 29.5163
eval epoch: 125, loss: 246.9963, BCE: 220.9972, KLD: 25.9991
eval epoch: 150, loss: 245.6965, BCE: 222.5772, KLD: 23.1192
eval epoch: 175, loss: 244.7735, BCE: 223.3629, KLD: 21.4106
eval epoch: 200, loss: 244.4016, BCE: 224.5442, KLD: 19.8574
eval epoch: 225, loss: 244.1679, BCE: 225.1086, KLD: 19.0593
eval epoch: 250, loss: 243.9687, BCE: 224.3791, KLD: 19.5896
eval epoch: 275, loss: 243.6929, BCE: 224.6925, KLD: 19.0004
eval epoch: 300, loss: 243.5842, BCE: 224.6075, KLD: 18.9767
eval epoch: 325, loss: 243.5537, BCE: 224.2490, KLD: 19.3047
eval epoch: 350, loss: 243.4431, BCE: 224.2475, KLD: 19.1956
eval epoch: 375, loss: 243.3540, BCE: 224.2650, KLD: 19.0890
eval epoch: 400, loss: 243.3495, BCE: 223.8712, KLD: 19.4782
eval epoch: 425, loss: 243.1932, BCE: 224.3737, KLD: 18.8194
eval epoch: 450, loss: 243.1999, BCE: 223.7480, KLD: 19.4519
eval epoch: 475, loss: 243.1479, BCE: 223.8625, KLD: 19.2853
eval epoch: 500, loss: 243.1307, BCE: 224.0523, KLD: 19.0783
Omniglot:
Number of train samples: 24345
eval epoch: 25, loss: 205.9442, BCE: 80.7324, KLD: 125.2118
eval epoch: 50, loss: 173.0699, BCE: 82.3096, KLD: 90.7603
eval epoch: 75, loss: 157.6246, BCE: 86.0767, KLD: 71.5479
eval epoch: 100, loss: 148.7567, BCE: 90.9123, KLD: 57.8445
eval epoch: 125, loss: 143.3184, BCE: 94.5159, KLD: 48.8025
eval epoch: 150, loss: 140.0756, BCE: 98.2680, KLD: 41.8076
eval epoch: 175, loss: 138.1540, BCE: 101.4476, KLD: 36.7064
eval epoch: 200, loss: 137.2294, BCE: 105.6944, KLD: 31.5350
eval epoch: 225, loss: 136.5943, BCE: 104.2408, KLD: 32.3535
eval epoch: 250, loss: 136.3358, BCE: 105.8063, KLD: 30.5295
eval epoch: 275, loss: 136.0621, BCE: 104.4986, KLD: 31.5634
eval epoch: 300, loss: 135.9718, BCE: 105.2400, KLD: 30.7317
eval epoch: 325, loss: 135.7987, BCE: 104.4629, KLD: 31.3358
eval epoch: 350, loss: 135.6864, BCE: 104.8219, KLD: 30.8644
eval epoch: 375, loss: 135.4615, BCE: 104.7266, KLD: 30.7349
eval epoch: 400, loss: 135.4405, BCE: 104.8456, KLD: 30.5949
eval epoch: 425, loss: 135.3270, BCE: 104.6016, KLD: 30.7253
eval epoch: 450, loss: 135.3859, BCE: 104.5079, KLD: 30.8780
eval epoch: 475, loss: 135.2347, BCE: 104.6587, KLD: 30.5760
eval epoch: 500, loss: 135.2847, BCE: 104.4621, KLD: 30.8226