Path: blob/master/notebooks/book2/13/smc_tempered_1d_bimodal.ipynb
Kernel: Python [conda env:py3713]
Illustrate sequential Monte Carlo (SMC) with a Hamiltonian Monte Carlo (HMC) kernel on a 1D bimodal distribution.
Code is from https://blackjax-devs.github.io/blackjax/examples/TemperedSMC.html
In [1]:
In [2]:
In [3]:
In [4]:
Target distribution
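The cell that defines the target is not shown on this page. A minimal sketch, assuming a target of the form N(x; 0, 1) · exp(−V(x)) with a hypothetical double-well potential V(x) = 5 (x² − 1)², checks where the modes of such a target sit:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def potential(x):
    # Assumed (hypothetical) double-well potential with minima at x = -1 and x = +1.
    return 5.0 * jnp.square(x**2 - 1.0)


def target_log_prob(x):
    # Unnormalized log density: N(0, 1) prior times exp(-V(x)).
    return norm.logpdf(x) - potential(x)


# d/dx log p(x) = -x (20 x^2 - 19), so the stationary points are x = 0
# (the dip between the modes) and x = +/- sqrt(0.95) (the two modes,
# pulled slightly inward from +/- 1 by the prior).
mode = jnp.sqrt(0.95)
grad_at_mode = jax.grad(target_log_prob)(mode)

# Evaluate on a symmetric grid, e.g. for plotting.
grid = jnp.linspace(-2.0, 2.0, 401)
log_target = target_log_prob(grid)
```

Under this assumed target, log p(0) is about 4.5 nats below the modes, which is the barrier a local sampler has to cross.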
In [5]:
Out[5]:
Tempered distribution
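A sketch of a tempered family, assuming the geometric path log p_λ(x) = log N(x; 0, 1) − λ V(x) with the hypothetical V(x) = 5 (x² − 1)², so that λ = 0 gives the prior and λ = 1 the target. Each row is normalized numerically on the grid so that the curves can be compared as densities:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def tempered_log_prob(x, lmbda):
    # Assumed path: log N(x; 0, 1) - lmbda * 5 (x^2 - 1)^2, unnormalized.
    return norm.logpdf(x) - lmbda * 5.0 * jnp.square(x**2 - 1.0)


grid = jnp.linspace(-2.0, 2.0, 401)
dx = grid[1] - grid[0]
lambdas = jnp.linspace(0.0, 1.0, 5)

# Broadcasting gives shape (num_lambdas, num_grid_points).
log_unnorm = tempered_log_prob(grid[None, :], lambdas[:, None])

# Normalize each row on the grid (Riemann sum) so it integrates to about 1.
densities = jnp.exp(log_unnorm - logsumexp(log_unnorm, axis=1, keepdims=True)) / dx
```

As λ grows, the single bump of the prior at 0 splits into two peaks near ±0.975, which is why the intermediate distributions give the particles a gradual path to both modes.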
In [6]:
Out[6]:
HMC
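For intuition about the HMC baseline, here is a minimal pure-JAX HMC sketch (leapfrog integrator, identity mass matrix, Metropolis correction). It assumes the target N(0, 1) · exp(−5 (x² − 1)²); `leapfrog`, `hmc_step`, and `run_hmc` are hypothetical names, not BlackJAX's API, and the step size 0.1 with 10 leapfrog steps is an arbitrary choice:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def target_log_prob(x):
    # Assumed bimodal target: N(0, 1) prior times exp(-5 (x^2 - 1)^2).
    return jnp.sum(norm.logpdf(x)) - 5.0 * jnp.square(jnp.sum(x**2) - 1.0)


def leapfrog(x, p, step_size, num_steps):
    # Leapfrog integrator for H(x, p) = -log p(x) + |p|^2 / 2.
    grad = jax.grad(target_log_prob)

    def body(_, state):
        x, p = state
        p = p + 0.5 * step_size * grad(x)
        x = x + step_size * p
        p = p + 0.5 * step_size * grad(x)
        return x, p

    return jax.lax.fori_loop(0, num_steps, body, (x, p))


def hmc_step(key, x, step_size=0.1, num_steps=10):
    key_mom, key_acc = jax.random.split(key)
    p0 = jax.random.normal(key_mom, x.shape)
    x_new, p_new = leapfrog(x, p0, step_size, num_steps)
    h0 = -target_log_prob(x) + 0.5 * jnp.sum(p0**2)
    h1 = -target_log_prob(x_new) + 0.5 * jnp.sum(p_new**2)
    # Metropolis correction for integration error; NaN energies are rejected.
    accept = jnp.log(jax.random.uniform(key_acc)) < h0 - h1
    return jnp.where(accept, x_new, x), accept


def run_hmc(key, x0, num_samples):
    def step(x, key):
        x, accept = hmc_step(key, x)
        return x, (x, accept)

    _, (samples, accepts) = jax.lax.scan(step, x0, jax.random.split(key, num_samples))
    return samples, accepts


samples, accepts = jax.jit(run_hmc, static_argnums=2)(
    jax.random.PRNGKey(0), jnp.array([1.0], dtype=jnp.float32), 1000
)
```

Because momentum is redrawn from N(0, 1) at every step, a chain sitting in one mode reaches x = 0 only after an unusually large momentum draw (kinetic energy of roughly 4.5 nats under this target), so a single chain tends to stay in the mode it started in for long stretches.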
In [7]:
In [8]:
Out[8]:
CPU times: user 1.38 s, sys: 30.1 ms, total: 1.41 s
Wall time: 1.37 s
In [9]:
Out[9]:
(0.0, 4.663308943341834)
NUTS
In [10]:
Out[10]:
CPU times: user 7.57 s, sys: 20.9 ms, total: 7.59 s
Wall time: 7.54 s
In [11]:
Out[11]:
SMC
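Adaptive tempered SMC chooses each next temperature from the current particles. One common criterion is written out here as a sketch. It assumes the tempered family p_λ(x) ∝ p_0(x) exp(−λ V(x)) with prior p_0 and potential V, equally weighted particles x_1, …, x_N after the previous resampling step, and a user-chosen target ESS fraction ρ ∈ (0, 1):

```latex
w_i(\lambda') = \exp\bigl(-(\lambda' - \lambda_t)\, V(x_i)\bigr), \qquad
\mathrm{ESS}(\lambda') = \frac{\bigl(\sum_{i=1}^{N} w_i(\lambda')\bigr)^2}{\sum_{i=1}^{N} w_i(\lambda')^2},

\lambda_{t+1} =
\begin{cases}
1, & \text{if } \mathrm{ESS}(1) \ge \rho N, \\
\text{the } \lambda' \in (\lambda_t, 1) \text{ with } \mathrm{ESS}(\lambda') = \rho N, & \text{otherwise.}
\end{cases}
```

Since ESS(λ_t) = N and the ESS does not increase as λ′ grows, a bisection search finds λ_{t+1}. The particles are then resampled with probabilities proportional to w_i(λ_{t+1}) and moved by an MCMC kernel (HMC in this notebook) that leaves p_{λ_{t+1}} invariant. The loop stops once λ reaches 1.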
In [12]:
In [13]:
Out[13]:
Number of steps in the adaptive algorithm: 9
CPU times: user 3.17 s, sys: 2.34 s, total: 5.5 s
Wall time: 2.63 s
In [14]:
Out[14]:
SMC modified
We modify the code to record the temperature at each step. But first we illustrate how to append to a list inside a JAX while loop. The trick is to use a fixed-size buffer, because every array in the loop state must keep the same shape on every iteration.
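A minimal sketch of the trick. It assumes a dict buffer with keys current and data, like the one printed below; `append` and `fill_buffer` are hypothetical names, and the capacity of 100 mirrors the length of the printed buffer:

```python
import jax
import jax.numpy as jnp


def append(buffer, value):
    # Write `value` into the next free slot. The array keeps its shape, as
    # lax.while_loop requires of its carry; by default JAX drops scatter
    # updates at out-of-bounds indices, so writes past the end are lost.
    data = buffer["data"].at[buffer["current"]].set(value)
    return {"current": buffer["current"] + 1, "data": data}


def fill_buffer(n, capacity=100):
    # Append the values 1, 2, ..., n one at a time inside a while loop.
    init = {
        "current": jnp.array(0, dtype=jnp.int32),
        "data": jnp.zeros(capacity, dtype=jnp.float32),
    }

    def cond_fun(buf):
        return buf["current"] < n

    def body_fun(buf):
        return append(buf, (buf["current"] + 1).astype(jnp.float32))

    return jax.lax.while_loop(cond_fun, body_fun, init)


buf = jax.jit(fill_buffer)(5)
```

Only n is traced here; capacity keeps its Python default, so the buffer shape is known at compile time.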
In [15]:
Out[15]:
{'current': DeviceArray(5, dtype=int32, weak_type=True), 'data': DeviceArray([1., 2., 3., 4., 5., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
If necessary, we can grow the buffer on demand.
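One way to grow such a buffer, sketched under the assumption that growth happens in ordinary Python between loop runs rather than inside `jax.lax.while_loop`, where the carry shape is fixed. The helper `grow` is hypothetical and doubles the capacity, which is one common policy:

```python
import jax.numpy as jnp


def grow(buffer, min_capacity):
    # Hypothetical helper: return a buffer with room for at least
    # `min_capacity` entries, padding the data array with zeros.
    data = buffer["data"]
    capacity = data.shape[0]
    if capacity >= min_capacity:
        return buffer
    # Doubling keeps the number of reallocations logarithmic in the final size.
    new_capacity = max(min_capacity, 2 * capacity)
    padding = jnp.zeros(new_capacity - capacity, dtype=data.dtype)
    return {"current": buffer["current"], "data": jnp.concatenate([data, padding])}


full = {"current": jnp.array(5), "data": jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])}
bigger = grow(full, 6)
```

Each new capacity triggers a fresh compilation of any jitted loop that consumes the buffer, which is another reason to grow geometrically rather than one slot at a time.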
In [16]:
Out[16]:
[1. 2. 3. 4. 5. 0.]
Now we return to SMC.
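A minimal pure-JAX sketch of an adaptive tempered SMC loop that records its temperature schedule in a fixed-size buffer. It assumes the target N(0, 1) · exp(−5 (x² − 1)²), picks each temperature by bisection on the ESS with target fraction 0.5, resamples multinomially, and, to keep the block short, mutates with random-walk Metropolis instead of the HMC kernel the notebook uses. The names `tempered_smc`, `next_temperature`, and `rwmh_move` are hypothetical, not BlackJAX functions:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def potential(x):
    # Assumed potential V(x) = 5 (x^2 - 1)^2 per particle; x has shape (N, 1).
    return 5.0 * jnp.square(jnp.sum(x**2, axis=-1) - 1.0)


def ess_fraction(log_w):
    # ESS / N for unnormalized log weights.
    w = jax.nn.softmax(log_w)
    return 1.0 / (jnp.sum(w**2) * log_w.shape[0])


def next_temperature(v, lmbda, target_ess=0.5, num_iters=30):
    # Bisection for (approximately) the largest lambda' with ESS / N >= target_ess.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        ok = ess_fraction(-(mid - lmbda) * v) >= target_ess
        return jnp.where(ok, mid, lo), jnp.where(ok, hi, mid)

    lo, _ = jax.lax.fori_loop(0, num_iters, body, (lmbda, jnp.float32(1.0)))
    # Jump straight to 1 when that step already satisfies the criterion.
    return jnp.where(ess_fraction(-(1.0 - lmbda) * v) >= target_ess, jnp.float32(1.0), lo)


def rwmh_move(key, x, lmbda, scale=0.5, num_steps=5):
    # Random-walk Metropolis targeting N(x; 0, 1) * exp(-lmbda * V(x)),
    # used here as a stand-in for the HMC mutation kernel.
    def log_p(x):
        return jnp.sum(norm.logpdf(x), axis=-1) - lmbda * potential(x)

    def step(x, key):
        k_prop, k_acc = jax.random.split(key)
        prop = x + scale * jax.random.normal(k_prop, x.shape)
        log_u = jnp.log(jax.random.uniform(k_acc, (x.shape[0],)))
        accept = log_u < log_p(prop) - log_p(x)
        return jnp.where(accept[:, None], prop, x), None

    x, _ = jax.lax.scan(step, x, jax.random.split(key, num_steps))
    return x


def tempered_smc(key, num_particles=1000, max_steps=100):
    key, init_key = jax.random.split(key)
    x0 = jax.random.normal(init_key, (num_particles, 1))  # draws from the N(0, 1) prior
    lambdas = jnp.zeros(max_steps)  # fixed-size buffer for the temperature schedule

    def cond_fun(state):
        _, _, lmbda, i, _ = state
        return (lmbda < 1.0) & (i < max_steps)

    def body_fun(state):
        key, x, lmbda, i, lambdas = state
        key, k_res, k_mv = jax.random.split(key, 3)
        v = potential(x)
        new_lmbda = next_temperature(v, lmbda)
        # Multinomial resampling with probabilities proportional to the incremental weights.
        idx = jax.random.categorical(k_res, -(new_lmbda - lmbda) * v, shape=(num_particles,))
        x = rwmh_move(k_mv, x[idx], new_lmbda)
        return key, x, new_lmbda, i + 1, lambdas.at[i].set(new_lmbda)

    init = (key, x0, jnp.float32(0.0), jnp.int32(0), lambdas)
    _, x, _, num_steps, lambdas = jax.lax.while_loop(cond_fun, body_fun, init)
    return x, lambdas, num_steps


particles, schedule, num_steps = tempered_smc(jax.random.PRNGKey(0))
```

After the call, schedule[:num_steps] holds the visited temperatures, ending at 1.0, and the unused buffer slots stay at zero. The values will not match the BlackJAX schedule printed below, because the mutation kernel and ESS target differ.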
In [17]:
In [18]:
Out[18]:
Number of steps in the adaptive algorithm: 9
CPU times: user 3.24 s, sys: 3.25 s, total: 6.49 s
Wall time: 2.57 s
In [19]:
Out[19]:
[0.00538135 0.01532254 0.04021941 0.08235608 0.15540347 0.2983744
0.5712157 0.9644599 1. ]
In [20]:
Out[20]:
In [21]:
Out[21]:
In [22]:
Out[22]:
TemperedSMCState(particles=DeviceArray([[ 1.1464324 ],
[ 1.1464324 ],
[-1.1070769 ],
...,
[ 0.57004714],
[ 0.70350933],
[ 0.7039775 ]], dtype=float32), lmbda=DeviceArray(1., dtype=float32, weak_type=True))
Let's track the adaptive temperature.
In [23]:
Out[23]:
(10000, 1)