Path: blob/master/notebooks/book2/12/mcmc_traceplots_unigauss_numpyro.ipynb
Kernel: Python 3.7.13 ('py3713')
Please find the BlackJAX implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book2/11/mcmc_traceplots_unigauss.ipynb
In [1]:
Out[1]:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax version 0.3.13
jax backend cpu
Note: you may need to restart the kernel to use updated packages.
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/ipykernel_launcher.py:67: UserWarning: There are not enough devices to run parallel chains: expected 3 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(3)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
sample: 100%|██████████| 1000/1000 [00:02<00:00, 406.53it/s, 5 steps of size 1.51e-01. acc. prob=0.88]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1447.00it/s, 7 steps of size 8.36e-02. acc. prob=0.92]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1276.03it/s, 327 steps of size 5.80e-03. acc. prob=0.97]
           mean       std    median      2.5%     97.5%     n_eff     r_hat
alpha     24.92    441.83      0.22   -983.78    867.54     52.04      1.02
sigma    735.89   2041.89    133.56      1.01   3162.79    163.15      1.03
Number of divergences: 28
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures
warnings.warn("set FIG_DIR environment variable to save figures")
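The `r_hat` column compares the spread between chains with the spread within each chain; values close to 1.00 suggest the chains agree. In this first run the values of 1.02 (`alpha`) and 1.03 (`sigma`) come together with a posterior `std` of about 442 for `alpha` and 28 divergences, which points to poor mixing. Since `alpha` also reaches 1.02 in the second run, R-hat is best read alongside `n_eff` and the divergence count. A minimal NumPy sketch of split R-hat, using the standard within/between-chain variance definition, could look like this; numpyro ships its own version as `numpyro.diagnostics.split_gelman_rubin`, which may differ in small details such as how odd-length chains are handled.

```python
import numpy as np


def split_rhat(chains):
    """Split R-hat for an array of shape (num_chains, num_samples).

    Each chain is cut in half, so drift within a single chain also inflates R-hat.
    """
    chains = np.asarray(chains, dtype=float)
    n_half = chains.shape[1] // 2
    # Drop the middle draw if the length is odd, then treat each half as its own chain
    halves = np.concatenate([chains[:, :n_half], chains[:, -n_half:]], axis=0)
    n = halves.shape[1]
    chain_means = halves.mean(axis=1)
    within = halves.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)          # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n  # pooled estimate of posterior variance
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(3, 500))                # three chains sampling the same target
stuck = mixed + np.array([[0.0], [0.0], [3.0]])  # third chain stuck in a different region
print(round(split_rhat(mixed), 3), round(split_rhat(stuck), 3))
```

The well-mixed chains give a value very close to 1, while shifting one chain away from the others drives R-hat far above 1.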
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/ipykernel_launcher.py:90: UserWarning: There are not enough devices to run parallel chains: expected 3 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(3)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
sample: 100%|██████████| 1000/1000 [00:02<00:00, 394.62it/s, 3 steps of size 5.46e-01. acc. prob=0.91]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1480.57it/s, 7 steps of size 5.95e-01. acc. prob=0.91]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1482.90it/s, 7 steps of size 4.97e-01. acc. prob=0.91]
           mean       std    median      2.5%     97.5%     n_eff     r_hat
alpha      0.04      1.24      0.01     -2.43      2.87    344.99      1.02
sigma      1.59      0.80      1.42      0.49      3.20    457.20      1.00
Number of divergences: 0
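`n_eff` estimates how many independent draws the autocorrelated samples are worth. It rises from 52 (`alpha`) and 163 (`sigma`) in the first run to about 345 and 457 in this one, with no divergences. The sketch below uses a deliberately simple single-chain estimator that truncates the autocorrelation sum at the first negative lag. This is a simplification: numpyro's `numpyro.diagnostics.effective_sample_size` uses a more careful multi-chain estimator, so the numbers will not match exactly.

```python
import numpy as np


def autocorrelation(x):
    """Normalized autocorrelation of a 1-D series, computed via FFT."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    n = x.size
    # Zero-pad to 2n so the circular FFT correlation equals the linear one
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f), n=2 * n)[:n]
    return acov / acov[0]


def simple_ess(x):
    """ESS = N / (1 + 2 * sum of autocorrelations), truncated at the first negative lag."""
    rho = autocorrelation(x)
    total = 0.0
    for lag in range(1, rho.size):
        if rho[lag] < 0:
            break
        total += rho[lag]
    return rho.size / (1.0 + 2.0 * total)


rng = np.random.default_rng(1)
n = 5000
iid = rng.normal(size=n)  # independent draws: ESS should be close to n
ar = np.empty(n)
ar[0] = rng.normal()
for t in range(1, n):
    # AR(1) with coefficient 0.9: strongly autocorrelated, like a slowly mixing chain
    ar[t] = 0.9 * ar[t - 1] + rng.normal()
print(round(simple_ess(iid)), round(simple_ess(ar)))
```

The independent series keeps most of its nominal sample size, while the autocorrelated one is worth only a small fraction of it, which is the same effect behind the low `n_eff` in the first run.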
/home/patel_karm/anaconda3/envs/py3713/lib/python3.7/site-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures
warnings.warn("set FIG_DIR environment variable to save figures")
In [ ]: