Kernel: Python 3
Illustrating SMC with an HMC kernel on a 1D bimodal distribution
The code is taken from https://github.com/blackjax-devs/blackjax/blob/main/notebooks/TemperedSMC.ipynb
In [1]:
Out[1]:
Collecting git+https://github.com/blackjax-devs/blackjax
Cloning https://github.com/blackjax-devs/blackjax to /tmp/pip-req-build-juma940_
Running command git clone -q https://github.com/blackjax-devs/blackjax /tmp/pip-req-build-juma940_
Building wheels for collected packages: blackjax
Building wheel for blackjax (setup.py) ... done
Created wheel for blackjax: filename=blackjax-0.2.1-py3-none-any.whl size=71257 sha256=3463c47d56d4535a1867d2bd72c1a41028b720c2174af157ee64d56d6b722686
Stored in directory: /tmp/pip-ephem-wheel-cache-fo5zm9_r/wheels/d3/42/75/b8e1ec1f9f837fdd16abb96cb47725ff083f5f0774610070e4
Successfully built blackjax
Installing collected packages: blackjax
Successfully installed blackjax-0.2.1
In [2]:
Target distribution
In [3]:
Out[3]:
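The cell that defines the target is not reproduced above. As an assumed example (not necessarily the exact target used in the cells), a 1D bimodal posterior can be built from a standard normal prior and a "double-well" log-likelihood that peaks at x = ±1. The names log_prior, log_likelihood and log_target are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Assumed example target: N(0, 1) prior times a double-well likelihood.

def log_prior(x):
    """Standard normal log-density, summed over dimensions."""
    return jnp.sum(norm.logpdf(x))

def log_likelihood(x):
    """Double-well term: zero where sum(x**2) == 1, strongly negative near x = 0."""
    return -5.0 * jnp.square(jnp.sum(x ** 2) - 1.0)

def log_target(x):
    """Unnormalized log-density of the bimodal posterior."""
    return log_prior(x) + log_likelihood(x)

print(log_target(jnp.array([1.0])), log_target(jnp.array([0.0])))
```

Because the density is symmetric in x, a sampler that settles in one well and never crosses the low-density region around 0 will miss half of the mass.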
Tempered distribution
In [4]:
Out[4]:
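For reference, the standard tempering construction interpolates between the prior and the posterior with an inverse temperature λ (stored as lmbda in the blackjax state shown further down). The notation p_0 (prior) and L (likelihood) is ours. λ = 0 gives the prior, λ = 1 gives the target, and moving particles from λ_t to λ_{t+1} requires the incremental importance weights on the last line.

```latex
\pi_\lambda(x) \propto p_0(x)\, L(x)^{\lambda}, \qquad 0 = \lambda_0 < \lambda_1 < \dots < \lambda_T = 1

\log \pi_\lambda(x) = \log p_0(x) + \lambda \log L(x) + \text{const}

w_i \propto \frac{\pi_{\lambda_{t+1}}(x_i)}{\pi_{\lambda_t}(x_i)} = L(x_i)^{\lambda_{t+1} - \lambda_t}
```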
HMC
In [6]:
In [7]:
Out[7]:
CPU times: user 1.75 s, sys: 57.3 ms, total: 1.81 s
Wall time: 1.77 s
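The HMC cell above calls blackjax, and its code is not reproduced here. To show what one HMC transition does, here is a minimal sketch in plain JAX with an identity mass matrix and a leapfrog integrator. It is not blackjax's implementation; hmc_step, step_size=0.1 and num_integration_steps=10 are illustrative choices, and the usage example samples a standard normal rather than the bimodal target.

```python
import jax
import jax.numpy as jnp

def hmc_step(key, x, logdensity_fn, step_size=0.1, num_integration_steps=10):
    """One HMC transition: fresh momentum, leapfrog trajectory, Metropolis test."""
    grad_fn = jax.grad(logdensity_fn)
    key_momentum, key_accept = jax.random.split(key)
    p0 = jax.random.normal(key_momentum, x.shape)

    def leapfrog(_, state):
        q, p = state
        p = p + 0.5 * step_size * grad_fn(q)  # half step in momentum
        q = q + step_size * p                 # full step in position
        p = p + 0.5 * step_size * grad_fn(q)  # half step in momentum
        return q, p

    q, p = jax.lax.fori_loop(0, num_integration_steps, leapfrog, (x, p0))
    # Hamiltonian = potential energy (-log density) + kinetic energy.
    h_old = -logdensity_fn(x) + 0.5 * jnp.sum(p0 ** 2)
    h_new = -logdensity_fn(q) + 0.5 * jnp.sum(p ** 2)
    accept = jnp.log(jax.random.uniform(key_accept)) < h_old - h_new
    return jnp.where(accept, q, x)

def log_std_normal(x):
    return -0.5 * jnp.sum(x ** 2)

def run_chain(key, x0, n_steps=200):
    """Run n_steps HMC transitions from x0 and return the final state."""
    def body(x, k):
        return hmc_step(k, x, log_std_normal), None
    x, _ = jax.lax.scan(body, x0, jax.random.split(key, n_steps))
    return x

# 1000 independent chains, each started at 0.
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
samples = jax.jit(jax.vmap(run_chain))(keys, jnp.zeros((1000, 1)))
print(samples.mean(), samples.std())
```

On a bimodal target, short local trajectories like these can rarely cross the low-density barrier between modes, which is what tempering is meant to address.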
NUTS
In [8]:
Out[8]:
CPU times: user 8.26 s, sys: 19.5 ms, total: 8.28 s
Wall time: 8.25 s
SMC
In [9]:
Out[9]:
Number of steps in the adaptive algorithm: 8
CPU times: user 5.37 s, sys: 1.64 s, total: 7.01 s
Wall time: 5.07 s
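The adaptive algorithm reported 8 steps. To our understanding, blackjax's adaptive tempered SMC chooses each next temperature from an effective-sample-size (ESS) criterion. The sketch below illustrates that idea with plain bisection: it returns the largest λ' whose incremental weights keep the ESS fraction at or above target_ess. The helper names, target_ess=0.5 and the bisection solver are our assumptions, not blackjax's solver.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def ess_fraction(log_weights):
    """Effective sample size 1 / sum(w**2) of the normalized weights, divided by N."""
    log_w = log_weights - logsumexp(log_weights)
    return jnp.exp(-logsumexp(2.0 * log_w)) / log_weights.shape[0]

def next_temperature(loglik, lmbda, target_ess=0.5, n_bisection=50):
    """Largest lmbda' in (lmbda, 1] with ESS fraction >= target_ess (plain Python, not jittable)."""
    def ess_at(new_lmbda):
        # Incremental log-weights: (lmbda' - lmbda) * log L(x_i).
        return ess_fraction((new_lmbda - lmbda) * loglik)

    if ess_at(1.0) >= target_ess:  # we can jump straight to the target
        return 1.0
    lo, hi = lmbda, 1.0            # invariant: ess_at(lo) >= target_ess > ess_at(hi)
    for _ in range(n_bisection):
        mid = 0.5 * (lo + hi)
        if ess_at(mid) >= target_ess:
            lo = mid
        else:
            hi = mid
    return lo
```

Usage: given the log-likelihood of every particle and the current lmbda, next_temperature(loglik, lmbda) gives the next rung of the ladder. Repeating this until it returns 1.0 yields the number of adaptive steps.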
SMC modified
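We modify the code to record the temperature at each step. First, though, we illustrate how to append to a list inside a JAX while loop. Arrays inside jax.lax.while_loop must keep a fixed shape, so the trick is to preallocate a fixed-size buffer and keep a counter of the next free slot.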
In [5]:
Out[5]:
{'current': DeviceArray(5, dtype=int32, weak_type=True), 'data': DeviceArray([1., 2., 3., 4., 5., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
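A minimal sketch of the fixed-size buffer trick, mirroring the dict printed above with its 'current' counter and 100-slot 'data' array. The names append and fill_buffer, and the capacity of 100, are hypothetical.

```python
import jax
import jax.numpy as jnp

def append(buffer, value):
    """Write value at position buffer['current'] and advance the counter."""
    data = buffer["data"].at[buffer["current"]].set(value)
    return {"current": buffer["current"] + 1, "data": data}

def fill_buffer(n_items, capacity=100):
    """Append 1, 2, ..., n_items to a preallocated buffer inside a while_loop."""
    buffer = {"current": jnp.int32(0), "data": jnp.zeros(capacity)}

    def cond(state):
        i, _ = state
        return i < n_items

    def body(state):
        i, buf = state
        buf = append(buf, (i + 1).astype(jnp.float32))
        return i + 1, buf

    # The carry keeps the same shapes and dtypes on every iteration.
    _, buffer = jax.lax.while_loop(cond, body, (jnp.int32(0), buffer))
    return buffer

buffer = fill_buffer(5)
print(buffer)
```

JAX silently drops out-of-bounds updates, so a write past capacity is lost rather than raising an error. Size the buffer for the worst case, for example the maximum number of tempering steps.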
If necessary, we can grow the buffer on demand.
In [35]:
Out[35]:
[1. 2. 3. 4. 5. 0.]
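Shapes must be static under jit, so growing a buffer has to happen at the Python level, between traced calls. The sketch below assumes capacity doubling, which is our choice, and uses a hypothetical helper, append_growable. Starting from 3 slots, appending five values forces one doubling and leaves six slots.

```python
import jax.numpy as jnp

def append_growable(buffer, value):
    """Append outside jit, doubling the capacity when the buffer is full."""
    data, current = buffer["data"], buffer["current"]  # current is a Python int
    if current >= data.shape[0]:
        # Concatenate a block of zeros; this changes the array's shape,
        # so it cannot run inside a traced loop.
        extra = jnp.zeros(max(1, data.shape[0]), dtype=data.dtype)
        data = jnp.concatenate([data, extra])
    data = data.at[current].set(value)
    return {"current": current + 1, "data": data}

buffer = {"current": 0, "data": jnp.zeros(3)}
for v in [1.0, 2.0, 3.0, 4.0, 5.0]:
    buffer = append_growable(buffer, v)
print(buffer["data"])
```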
Now we return to SMC.
In [27]:
Out[27]:
CPU times: user 4.79 s, sys: 730 ms, total: 5.52 s
Wall time: 4.65 s
In [28]:
Out[28]:
[0.00538135 0.01532278 0.04021964 0.0823563 0.16662113 0.31522763
0.60703117 1. ]
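The schedule above comes from blackjax's SMC loop modified to record λ. The end-to-end sketch below is self-contained and not the notebook's code. It combines the fixed-size buffer with an adaptive tempered SMC loop inside jax.lax.while_loop. Several pieces are assumptions: the 1D bimodal target (N(0, 1) prior, double-well log-likelihood), n_particles=1000, target_ess=0.5, max_steps=100, multinomial resampling, and a random-walk Metropolis move in place of the HMC kernel, chosen to keep the example short. Particles here have shape (n,) rather than the (n, 1) used above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

def log_prior(x):
    return norm.logpdf(x)                 # assumed N(0, 1) prior, elementwise

def log_lik(x):
    return -5.0 * (x ** 2 - 1.0) ** 2     # assumed double-well log-likelihood

def ess_fraction(log_w):
    log_w = log_w - logsumexp(log_w)
    return jnp.exp(-logsumexp(2.0 * log_w)) / log_w.shape[0]

def next_lambda(ll, lmbda, target_ess):
    """Bisection for the next temperature, written with fori_loop so it can be traced."""
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        ok = ess_fraction((mid - lmbda) * ll) >= target_ess
        return jnp.where(ok, mid, lo), jnp.where(ok, hi, mid)
    lo, _ = jax.lax.fori_loop(0, 50, body, (lmbda, jnp.float32(1.0)))
    can_finish = ess_fraction((1.0 - lmbda) * ll) >= target_ess
    return jnp.where(can_finish, 1.0, lo)

def rwm_move(key, x, lmbda, n_steps=10, scale=0.5):
    """Random-walk Metropolis on log_prior + lmbda * log_lik (stand-in for HMC)."""
    def logp(z):
        return log_prior(z) + lmbda * log_lik(z)
    def step(x, k):
        k1, k2 = jax.random.split(k)
        prop = x + scale * jax.random.normal(k1, x.shape)
        log_u = jnp.log(jax.random.uniform(k2, x.shape))
        return jnp.where(log_u < logp(prop) - logp(x), prop, x), None
    x, _ = jax.lax.scan(step, x, jax.random.split(key, n_steps))
    return x

def tempered_smc(key, n_particles=1000, target_ess=0.5, max_steps=100):
    key, init_key = jax.random.split(key)
    x = jax.random.normal(init_key, (n_particles,))  # exact draws from the prior
    schedule = jnp.zeros(max_steps)                  # fixed-size buffer for lambdas
    state = (key, x, jnp.float32(0.0), jnp.int32(0), schedule)

    def cond(state):
        _, _, lmbda, step, _ = state
        return (lmbda < 1.0) & (step < max_steps)

    def body(state):
        key, x, lmbda, step, schedule = state
        key, k_resample, k_move = jax.random.split(key, 3)
        ll = log_lik(x)
        new_lmbda = next_lambda(ll, lmbda, target_ess)
        log_w = (new_lmbda - lmbda) * ll             # incremental weights
        probs = jnp.exp(log_w - logsumexp(log_w))
        idx = jax.random.choice(k_resample, n_particles, (n_particles,), p=probs)
        x = rwm_move(k_move, x[idx], new_lmbda)
        schedule = schedule.at[step].set(new_lmbda)  # record this temperature
        return key, x, new_lmbda, step + 1, schedule

    _, x, _, n_steps, schedule = jax.lax.while_loop(cond, body, state)
    return x, schedule, n_steps

particles, schedule, n_steps = tempered_smc(jax.random.PRNGKey(0))
print(int(n_steps), schedule[: int(n_steps)])
```

The recorded schedule should increase to exactly 1.0. Because the prior is symmetric and the loop moves particles gradually, both modes stay populated.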
In [32]:
Out[32]:
In [33]:
Out[33]:
In [ ]:
TemperedSMCState(particles=DeviceArray([[1.146692 ],
[1.146692 ],
[1.1463876 ],
...,
[0.57010907],
[0.70353454],
[0.7037285 ]], dtype=float32), lmbda=DeviceArray(1., dtype=float32, weak_type=True))
Let's track the adaptive temperature.
In [ ]:
(10000, 1)
In [ ]: