Path: blob/master/notebooks/book2/19/bnn_mnist_sgld_blackjax.ipynb
Kernel: Python [conda env:py3713_2]
SGD & Stochastic Gradient Langevin Dynamics (SGLD) predictions using a Bayesian neural network
author: @karm-patel
Resources:
In [1]:
In [2]:
Out[2]:
MNIST dataset
In [3]:
Out[3]:
In [4]:
In [5]:
Out[5]:
DeviceArray([ 963, 1119, 962, 1090, 917, 916, 1018, 1042, 997, 976], dtype=int32)
In [6]:
Out[6]:
DeviceArray([4960, 5623, 4996, 5041, 4925, 4505, 4900, 5223, 4854, 4973], dtype=int32)
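The two arrays above are per-class label counts. They sum to 10,000 and 50,000, and class by class they add up to the MNIST training-set counts (for example, 963 + 4960 = 5923 zeros), which suggests the 60,000 training images were split into 50,000 and 10,000. The code cells are not shown in this export; as a minimal sketch, assuming integer labels, such counts can be computed with `jnp.bincount` (the labels below are made up):

```python
import jax.numpy as jnp

# Hypothetical integer labels standing in for the MNIST digit labels (0-9).
labels = jnp.array([0, 1, 1, 2, 9, 9, 9, 3])

# jnp.bincount counts occurrences of each class index; length=10 pads with
# zeros so every digit class gets a slot even if it never appears.
counts = jnp.bincount(labels, length=10)
print(counts)  # → [1 2 1 1 0 0 0 0 0 3]
```

The counts always sum to the number of labels, which is how the totals of 10,000 and 50,000 above can be checked.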
In [7]:
In [8]:
Out[8]:
((50000, 784), (50000, 10))
In [9]:
Out[9]:
((10000, 784), (10000, 10))
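A minimal sketch of preprocessing that would yield shapes like those above, assuming the images are flattened to 784 pixels and scaled to [0, 1], the labels are one-hot encoded, and a shuffled hold-out split is taken. The `split_and_encode` helper is hypothetical, and random arrays stand in for MNIST so the example runs offline:

```python
import jax
import jax.numpy as jnp

def split_and_encode(key, images, labels, n_holdout, num_classes=10):
    """Hypothetical helper: flatten, rescale, one-hot encode, then split."""
    x = images.reshape(images.shape[0], -1).astype(jnp.float32) / 255.0
    y = jax.nn.one_hot(labels, num_classes)
    perm = jax.random.permutation(key, x.shape[0])  # shuffle before splitting
    x, y = x[perm], y[perm]
    n_train = x.shape[0] - n_holdout
    return (x[:n_train], y[:n_train]), (x[n_train:], y[n_train:])

# Random stand-ins for 28x28 uint8 images. With 60,000 images and
# n_holdout=10_000, the same call would give the shapes printed above.
k_img, k_lab, k_split = jax.random.split(jax.random.PRNGKey(0), 3)
images = jax.random.randint(k_img, (600, 28, 28), 0, 256).astype(jnp.uint8)
labels = jax.random.randint(k_lab, (600,), 0, 10)
(x_tr, y_tr), (x_ho, y_ho) = split_and_encode(k_split, images, labels, n_holdout=100)
print(x_tr.shape, y_tr.shape)  # → (500, 784) (500, 10)
print(x_ho.shape, y_ho.shape)  # → (100, 784) (100, 10)
```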
In [10]:
Out[10]:
(60000, 1000)
Flax MLP model
In [11]:
In [12]:
SGD
In [13]:
Out[13]:
{'params': {'0_Dense': {'bias': (300,), 'kernel': (784, 300)},
'1_Dense': {'bias': (100,), 'kernel': (300, 100)},
'2_Dense': {'bias': (10,), 'kernel': (100, 10)}}}
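The printed shapes imply a 784 → 300 → 100 → 10 multilayer perceptron. The notebook builds it with Flax; the sketch below is a framework-free JAX stand-in that reproduces the same parameter tree. The ReLU activations and the scaled-normal initialization are assumptions, since neither is visible in this export:

```python
import jax
import jax.numpy as jnp

# Layer widths inferred from the parameter shapes printed above.
LAYER_SIZES = [784, 300, 100, 10]

def init_mlp(key, sizes=LAYER_SIZES):
    """Build {'params': {'0_Dense': {'kernel', 'bias'}, ...}} like the tree above."""
    params = {}
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        # Normal kernel scaled by 1/sqrt(fan_in) (assumed); zero bias.
        kernel = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params[f"{i}_Dense"] = {"kernel": kernel, "bias": jnp.zeros(n_out)}
    return {"params": params}

def mlp_apply(variables, x):
    """Forward pass: Dense + ReLU on hidden layers, raw logits at the output."""
    layers = variables["params"]
    n = len(layers)
    for i in range(n):
        layer = layers[f"{i}_Dense"]
        x = x @ layer["kernel"] + layer["bias"]
        if i < n - 1:
            x = jax.nn.relu(x)
    return x

variables = init_mlp(jax.random.PRNGKey(0))
print(jax.tree_util.tree_map(lambda a: a.shape, variables))
logits = mlp_apply(variables, jnp.ones((5, 784)))
print(logits.shape)  # → (5, 10)
```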
In [14]:
In [15]:
In [16]:
Out[16]:
CPU times: user 6.92 s, sys: 3.09 s, total: 10 s
Wall time: 2.86 s
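The SGD cell's code is not shown. A minimal sketch of a jitted minibatch SGD update with softmax cross-entropy, assuming a plain learning rate with no momentum or weight decay. `make_sgd_step` and the linear toy model are hypothetical:

```python
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, y_onehot):
    # Mean negative log-likelihood of the one-hot labels under softmax(logits).
    return -jnp.mean(jnp.sum(y_onehot * jax.nn.log_softmax(logits), axis=-1))

def make_sgd_step(apply_fn, lr):
    """Hypothetical builder for a jitted update: params <- params - lr * grad."""
    def loss_fn(params, x, y):
        return softmax_cross_entropy(apply_fn(params, x), y)

    @jax.jit
    def step(params, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return params, loss
    return step

# Toy run: a linear softmax classifier on synthetic, learnable labels.
def linear_apply(params, x):
    return x @ params["w"] + params["b"]

x = jax.random.normal(jax.random.PRNGKey(0), (256, 20))
y = jax.nn.one_hot(jnp.argmax(x[:, :10], axis=1), 10)
params = {"w": jnp.zeros((20, 10)), "b": jnp.zeros(10)}
step = make_sgd_step(linear_apply, lr=0.3)
losses = []
for epoch in range(25):
    for i in range(0, 256, 64):  # minibatches of 64
        params, loss = step(params, x[i:i + 64], y[i:i + 64])
        losses.append(float(loss))
print(losses[0], losses[-1])  # with zero init, the first loss is log(10) ≈ 2.3026
```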
In [17]:
Out[17]:
(DeviceArray([1. , 1. , 1. , 1. , 0.9990001,
1. , 1. , 0.9990001, 1. , 0.9990001], dtype=float32),
DeviceArray([0.9783 , 0.97969997, 0.978 , 0.97789997, 0.97889996,
0.97889996, 0.9783 , 0.9781 , 0.9773 , 0.9781 ], dtype=float32))
In [18]:
Out[18]:
In [19]:
Out[19]:
(10000,)
In [20]:
In [21]:
Out[21]:
1.0 0.9781
1.0 0.9781
1.0 0.9781
1.0 0.9781
0.9995 0.9785893
0.99759996 0.97965115
0.9905 0.98283696
0.98149997 0.98553234
0.9726 0.98797035
0.95739996 0.99101734
0.095699996 1.0
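The two columns above are not labelled in this export. One plausible reading, treated here as an assumption, is that each row corresponds to a confidence threshold: the first number is the fraction of test points whose maximum predicted probability reaches the threshold, and the second is the accuracy on those points. Under that reading, the table could be produced as sketched below (`accuracy_vs_confidence` is hypothetical). With 10 classes the top probability can never fall below 0.1, which would explain why low thresholds keep every point:

```python
import jax.numpy as jnp

def accuracy_vs_confidence(probs, labels, thresholds):
    """For each threshold t, return (fraction of points with max prob >= t,
    accuracy on that retained subset). An assumed reading of the table."""
    conf = jnp.max(probs, axis=-1)                 # confidence per point
    correct = jnp.argmax(probs, axis=-1) == labels
    rows = []
    for t in thresholds:
        keep = conf >= t
        n_keep = jnp.sum(keep)
        frac = n_keep / conf.shape[0]
        # Guard against an empty subset at high thresholds (reported as NaN).
        acc = jnp.where(n_keep > 0,
                        jnp.sum(correct & keep) / jnp.maximum(n_keep, 1),
                        jnp.nan)
        rows.append((float(frac), float(acc)))
    return rows

probs = jnp.array([[0.9, 0.05, 0.05],
                   [0.4, 0.35, 0.25],
                   [0.2, 0.7, 0.1],
                   [0.5, 0.3, 0.2]])
labels = jnp.array([0, 1, 1, 2])
rows = accuracy_vs_confidence(probs, labels, [0.0, 0.5, 0.8])
for frac, acc in rows:
    print(frac, acc)
```

Raising the threshold keeps fewer points, and for a well-calibrated model the accuracy on the kept points rises, which matches the trend in both columns above.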
In [22]:
In [23]:
Out[23]:
SGLD
In [24]:
In [25]:
In [26]:
Out[26]:
{'params': {'0_Dense': {'bias': (300,), 'kernel': (784, 300)},
'1_Dense': {'bias': (100,), 'kernel': (300, 100)},
'2_Dense': {'bias': (10,), 'kernel': (100, 10)}}}
In [27]:
In [28]:
Out[28]:
CPU times: user 2.81 s, sys: 119 ms, total: 2.93 s
Wall time: 6.35 s
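Judging by its filename, the notebook samples with BlackJAX's SGLD, but those calls are not shown here. The following is a from-scratch sketch, not the BlackJAX API. One SGLD step (Welling and Teh, 2011) moves half a step size along a minibatch estimate of the log-posterior gradient and adds Gaussian noise of variance equal to the step size. A toy conjugate-Gaussian model with a known posterior mean serves as a check. The step size, batch size, and burn-in are illustrative choices:

```python
import jax
import jax.numpy as jnp

def sgld_step(key, params, grad_log_post, step_size):
    """One SGLD update on a pytree of parameters:
    theta <- theta + (eps / 2) * grad log p(theta | data) + N(0, eps * I)."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    noise = jax.tree_util.tree_unflatten(
        treedef, [jax.random.normal(k, p.shape, p.dtype) for k, p in zip(keys, leaves)])
    return jax.tree_util.tree_map(
        lambda p, g, z: p + 0.5 * step_size * g + jnp.sqrt(step_size) * z,
        params, grad_log_post, noise)

def minibatch_grad(log_prior, log_lik, params, batch, n_data):
    """Unbiased estimate of grad log p(theta | data):
    grad log prior + (N / n) * sum of per-example log-likelihood gradients."""
    def objective(p):
        return log_prior(p) + (n_data / batch.shape[0]) * jnp.sum(log_lik(p, batch))
    return jax.grad(objective)(params)

# Toy check: unknown mean of unit-variance Gaussian data with prior N(0, 1).
log_prior = lambda theta: -0.5 * theta ** 2
log_lik = lambda theta, x: -0.5 * (x - theta) ** 2  # per example, up to a constant

data = 2.0 + jax.random.normal(jax.random.PRNGKey(0), (1000,))
n_data, batch_size, step_size = data.shape[0], 100, 1e-4

def body(theta, key):
    k_idx, k_noise = jax.random.split(key)
    idx = jax.random.randint(k_idx, (batch_size,), 0, n_data)  # with replacement
    g = minibatch_grad(log_prior, log_lik, theta, data[idx], n_data)
    theta = sgld_step(k_noise, theta, g, step_size)
    return theta, theta

keys = jax.random.split(jax.random.PRNGKey(1), 6000)
_, samples = jax.lax.scan(body, jnp.array(0.0), keys)
samples = samples[1000:]                  # drop burn-in
exact_mean = data.sum() / (n_data + 1)    # conjugate posterior mean
print(samples.mean(), exact_mean)
```

Because minibatch gradients are noisy, SGLD samples are somewhat more spread out than the exact posterior unless the step size shrinks. This is one reason SGLD accuracies can differ from those of a point estimate.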
In [29]:
In [30]:
Out[30]:
CPU times: user 7.66 s, sys: 3.04 s, total: 10.7 s
Wall time: 3.24 s
In [31]:
Out[31]:
(DeviceArray([0.95600003, 0.93900007, 0.961 , 0.9480001 , 0.943 ,
0.95600003, 0.95500004, 0.95400006, 0.938 , 0.95500004], dtype=float32),
DeviceArray([0.93979996, 0.9349 , 0.9414 , 0.9374 , 0.93479997,
0.9411 , 0.9406 , 0.941 , 0.93539995, 0.9395 ], dtype=float32))
In [32]:
Out[32]:
In [33]:
Out[33]:
FrozenDict({
params: {
0_Dense: {
bias: (300, 300),
kernel: (300, 784, 300),
},
1_Dense: {
bias: (300, 100),
kernel: (300, 300, 100),
},
2_Dense: {
bias: (300, 10),
kernel: (300, 100, 10),
},
},
})
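Each parameter leaf now carries a leading axis of length 300, which reads as 300 stacked posterior samples. A minimal sketch of the usual Monte Carlo posterior predictive for such a stack: it averages the softmax outputs of every sample using `jax.vmap`. The linear model and the random "samples" are stand-ins, not the notebook's MLP or its SGLD draws:

```python
import jax
import jax.numpy as jnp

def predictive_probs(apply_fn, stacked_params, x):
    """Monte Carlo posterior predictive: average softmax outputs over the
    leading sample axis shared by every parameter leaf."""
    per_sample = jax.vmap(lambda p: jax.nn.softmax(apply_fn(p, x), axis=-1))(stacked_params)
    return per_sample.mean(axis=0)

# Hypothetical linear model with 300 stacked random parameter draws,
# mirroring the (300, ...) leading dimension printed above.
def linear_apply(params, x):
    return x @ params["kernel"] + params["bias"]

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
stacked = {"kernel": 0.01 * jax.random.normal(key_w, (300, 784, 10)),
           "bias": jnp.zeros((300, 10))}
x = jax.random.normal(key_x, (8, 784))
probs = predictive_probs(linear_apply, stacked, x)
print(probs.shape)  # → (8, 10)
```

Averaging probabilities, rather than parameters, keeps the disagreement between samples visible as lower confidence. That is the quantity the confidence tables below summarize.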
In [34]:
In [35]:
Out[35]:
(10000,)
In [36]:
In [37]:
Out[37]:
1.0 0.92040634
1.0 0.92040634
1.0 0.92040634
0.9995 0.9207711
0.99439996 0.9239689
0.9798 0.9329346
0.95629996 0.94558614
0.93299997 0.9563945
0.9051 0.9683604
0.86139995 0.97970706
0.4529 0.9973504
In [38]:
Out[38]:
Distribution shift
In [39]:
In [40]:
Out[40]:
((10000, 784), (10000, 10))
SGD
In [41]:
Out[41]:
(10000,)
In [42]:
Out[42]:
1.0 0.086799994
1.0 0.086799994
1.0 0.086799994
0.99909997 0.08647783
0.9917 0.08631642
0.9709 0.084251724
0.93399996 0.08072805
0.89159995 0.078398384
0.83889997 0.07283347
0.7646 0.0661784
In [43]:
Out[43]:
SGLD
In [44]:
Out[44]:
(10000,)
In [45]:
Out[45]:
1.0 0.057391662
1.0 0.057391662
1.0 0.057391662
0.9806 0.055688012
0.86289996 0.045777418
0.72609997 0.03393013
0.5959 0.024917493
0.4761 0.018405098
0.3579 0.013786906
0.2383 0.0102923475
0.021699999 0.0046082954
In [46]:
Out[46]: