Path: blob/master/notebooks/book2/12/neals_funnel.ipynb
Kernel: Python [conda env:py3713_2]
Neal's funnel: Centered and Non-Centered parameterization
author: @karm-patel
A NumPyro implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book2/11/neals_funnel_numpyro.ipynb
In [13]:
In [4]:
Out[4]:
/home/patel_karm/anaconda3/envs/py3713_2/lib/python3.7/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying
warnings.warn("LATEXIFY environment variable not set, not latexifying")
We use the following toy model to create a funnel-like plot (from Section 11.6.4 of book2).
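The model definition itself is not shown in this export. As a minimal sketch, assume the standard two-variable form of Neal's funnel, v ~ N(0, 3^2) and x | v ~ N(0, exp(v)); this assumption is consistent with the potential-energy values printed in the HMC state outputs further down. The function name `sample_funnel` and the NumPy-based ancestral sampling are illustrative choices, not the notebook's own code.

```python
import numpy as np


def sample_funnel(n, seed=0):
    """Draw n ancestral samples (v, x) from the assumed funnel model."""
    rng = np.random.default_rng(seed)
    v = rng.normal(0.0, 3.0, size=n)      # v ~ N(0, 3^2)
    x = rng.normal(0.0, np.exp(v / 2.0))  # x | v ~ N(0, exp(v)), so std = exp(v/2)
    return v, x


v, x = sample_funnel(10_000)

# The spread of x grows sharply with v, which produces the funnel shape:
# a narrow neck for very negative v and a wide mouth for large v.
neck_std = np.std(x[v < -3])
mouth_std = np.std(x[v > 3])
```

Plotting `x` against `v` for these draws gives a scatter version of the funnel; the neck at negative `v` is the region that is hard for HMC to explore with a single step size.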
In [5]:
Out[5]:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [6]:
In [7]:
Out[7]:
/home/patel_karm/anaconda3/envs/py3713_2/lib/python3.7/site-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures
warnings.warn("set FIG_DIR environment variable to save figures")
In [8]:
Funnel samples with centered parameterization
In [9]:
In [10]:
Out[10]:
HMCState(position={'x': 0.0, 'v': 0.0}, potential_energy=DeviceArray(2.9364896, dtype=float32), potential_energy_grad={'v': DeviceArray(0.5, dtype=float32, weak_type=True), 'x': DeviceArray(0., dtype=float32, weak_type=True)})
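The numbers in this initial state can be reproduced by hand. The sketch below assumes the funnel model v ~ N(0, 3^2), x | v ~ N(0, exp(v)) and takes the potential energy to be the negative log joint density; the helper names are hypothetical, and the gradient is derived analytically rather than by automatic differentiation.

```python
import math


def neg_log_normal(y, mean, std):
    """Negative log density of N(mean, std^2) evaluated at y."""
    return 0.5 * math.log(2.0 * math.pi) + math.log(std) + 0.5 * ((y - mean) / std) ** 2


def centered_potential(x, v):
    """U(x, v) = -log p(v) - log p(x | v) under the assumed funnel model."""
    return neg_log_normal(v, 0.0, 3.0) + neg_log_normal(x, 0.0, math.exp(v / 2.0))


def centered_grad(x, v):
    """Analytic gradient of U with respect to v and x."""
    dv = v / 9.0 + 0.5 - 0.5 * x**2 * math.exp(-v)
    dx = x * math.exp(-v)
    return {"v": dv, "x": dx}


u0 = centered_potential(x=0.0, v=0.0)  # about 2.93649, matching potential_energy above
g0 = centered_grad(x=0.0, v=0.0)       # {'v': 0.5, 'x': 0.0}, matching potential_energy_grad
```

The `x` component of the gradient scales with `exp(-v)`, so the curvature in `x` changes by orders of magnitude across the range of `v`. That varying scale is what makes the centered form difficult to sample.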
In [14]:
Out[14]:
CPU times: user 1.97 s, sys: 21 ms, total: 1.99 s
Wall time: 1.94 s
In [15]:
Out[15]:
In [18]:
Out[18]:
In [19]:
Out[19]:
arviz - WARNING - Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)
Non-centered
We use the reparameterized (non-centered) form to draw samples.
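A minimal sketch of the reparameterization, assuming the funnel model v ~ N(0, 3^2), x | v ~ N(0, exp(v)): instead of sampling `x` directly, sample a standard normal `z` that is independent of `v`, then recover `x = exp(v / 2) * z`. The function name `to_centered` is hypothetical.

```python
import numpy as np


def to_centered(v, z):
    """Map non-centered draws (v, z) back to the original variable x."""
    return np.exp(v / 2.0) * z


rng = np.random.default_rng(0)
v = rng.normal(0.0, 3.0, size=10_000)  # v ~ N(0, 3^2)
z = rng.normal(0.0, 1.0, size=10_000)  # z ~ N(0, 1), independent of v
x = to_centered(v, z)                  # x | v ~ N(0, exp(v)), as in the centered model

# In (v, z) space the target is an axis-aligned Gaussian with no funnel,
# so the sample correlation between v and z is close to zero.
vz_corr = np.corrcoef(v, z)[0, 1]
```

The sampler then works in the well-conditioned (v, z) space, and the funnel only reappears when the draws are mapped back to `x` for plotting.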
In [20]:
In [21]:
Out[21]:
HMCState(position={'v': 5.0, 'z': 1.0}, potential_energy=DeviceArray(4.8253784, dtype=float32), potential_energy_grad={'v': DeviceArray(0.5555555, dtype=float32, weak_type=True), 'z': DeviceArray(1., dtype=float32, weak_type=True)})
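This state can also be checked by hand. The sketch assumes the non-centered joint is a product of independent Gaussians, v ~ N(0, 3^2) and z ~ N(0, 1), with the potential energy equal to the negative log joint density; the helper names are hypothetical.

```python
import math


def neg_log_normal(y, mean, std):
    """Negative log density of N(mean, std^2) evaluated at y."""
    return 0.5 * math.log(2.0 * math.pi) + math.log(std) + 0.5 * ((y - mean) / std) ** 2


def noncentered_potential(v, z):
    """U(v, z) = -log p(v) - log p(z) for independent v and z."""
    return neg_log_normal(v, 0.0, 3.0) + neg_log_normal(z, 0.0, 1.0)


def noncentered_grad(v, z):
    """Analytic gradient of U; each component depends on one variable only."""
    return {"v": v / 9.0, "z": z}


u0 = noncentered_potential(v=5.0, z=1.0)  # about 4.82538, matching potential_energy above
g0 = noncentered_grad(v=5.0, z=1.0)       # v: 5/9 (about 0.5556), z: 1.0
```

Because each gradient component depends only on its own variable, the geometry has the same scale everywhere, in contrast to the centered form.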
In [22]:
Out[22]:
CPU times: user 1.78 s, sys: 23.2 ms, total: 1.8 s
Wall time: 1.75 s
In [23]:
In [24]:
Out[24]:
In [26]:
Out[26]:
In [27]:
Out[27]:
arviz - WARNING - Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)