Path: blob/master/notebooks/misc/bnn_hmc_gaussian.ipynb
Kernel: Python 3
(SG)HMC for inferring the parameters of a 2D Gaussian
Based on
In [1]:
Out[1]:
[GpuDevice(id=0, process_index=0)]
In [25]:
Out[25]:
fatal: destination path 'google-research' already exists and is not an empty directory.
In [26]:
Out[26]:
/content/google-research
In [27]:
Out[27]:
core README.md run_sgd.py utils
make_posterior_surface_plot.py requirements.txt run_sgmcmc.py
notebooks run_hmc.py run_vi.py
In [5]:
Out[5]:
Collecting optax
Downloading optax-0.0.9-py3-none-any.whl (118 kB)
Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (0.12.0)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.19.5)
Collecting chex>=0.0.4
Downloading chex-0.0.8-py3-none-any.whl (57 kB)
Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.2.19)
Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.70+cuda110)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)
Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)
Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.1)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.12)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.4.1)
Installing collected packages: chex, optax
Successfully installed chex-0.0.8 optax-0.0.9
Setup
In [28]:
In [29]:
Data and model
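The cells in this section define the data and the model, but their code was not preserved. The sketch below is only one possible setup. It assumes the unknown parameter is the 2D mean `mu` of an isotropic Gaussian with known noise scale and a broad Gaussian prior. The true mean, the noise and prior scales, and every helper name are hypothetical. With those assumptions, the log-posterior that HMC would target could be written as:

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: N points from a 2D Gaussian with unknown mean `mu`
# and known isotropic noise; these constants are made up for illustration.
TRUE_MU = jnp.array([1.0, -2.0])
NOISE_STD = 1.0
PRIOR_STD = 10.0


def make_data(key, n=100):
    """Draw n samples x_i ~ N(TRUE_MU, NOISE_STD^2 I)."""
    return TRUE_MU + NOISE_STD * jax.random.normal(key, (n, 2))


def log_likelihood(mu, data):
    # Sum of isotropic Gaussian log-densities, dropping the constant term.
    return -0.5 * jnp.sum((data - mu) ** 2) / NOISE_STD**2


def log_prior(mu):
    # Zero-mean isotropic Gaussian prior on the 2D mean.
    return -0.5 * jnp.sum(mu**2) / PRIOR_STD**2


def log_posterior(mu, data):
    # Unnormalized log-posterior: the quantity an HMC sampler targets.
    return log_likelihood(mu, data) + log_prior(mu)


data = make_data(jax.random.PRNGKey(0))
```

Both the prior and the likelihood are Gaussian here, so the posterior is also Gaussian. Its precision is `N / NOISE_STD**2 + 1 / PRIOR_STD**2`, and its mean is `data.sum(0) / NOISE_STD**2` divided by that precision. This gives an exact reference for checking sampler output.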
In [7]:
In [8]:
Out[8]:
In [9]:
In [30]:
In [31]:
HMC
In [32]:
Out[32]:
Leapfrog steps per iteration: 16
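The output above reports 16 leapfrog steps per HMC iteration. The following is a minimal pure-JAX sketch of a single iteration. It assumes a unit mass matrix and full-batch gradients, which makes it plain HMC rather than SGHMC. `leapfrog` and `hmc_step` are hypothetical helpers and do not come from the google-research `bnn_hmc` code:

```python
import jax
import jax.numpy as jnp


def leapfrog(position, momentum, grad_log_prob, step_size, num_steps):
    """Integrate Hamiltonian dynamics with the leapfrog scheme (unit mass matrix)."""
    # Initial half step for the momentum.
    momentum = momentum + 0.5 * step_size * grad_log_prob(position)

    def body(_, state):
        q, p = state
        q = q + step_size * p                # full position step
        p = p + step_size * grad_log_prob(q)  # full momentum step
        return q, p

    position, momentum = jax.lax.fori_loop(
        0, num_steps - 1, body, (position, momentum)
    )
    # Last full position step, then a closing half step for the momentum.
    position = position + step_size * momentum
    momentum = momentum + 0.5 * step_size * grad_log_prob(position)
    return position, momentum


def hmc_step(key, position, log_prob, step_size, num_steps=16):
    """One HMC transition: resample momentum, integrate, Metropolis-correct."""
    grad_log_prob = jax.grad(log_prob)
    key_mom, key_acc = jax.random.split(key)
    momentum = jax.random.normal(key_mom, position.shape)
    new_pos, new_mom = leapfrog(position, momentum, grad_log_prob, step_size, num_steps)
    # Hamiltonian H(q, p) = -log p(q) + |p|^2 / 2.
    current_h = -log_prob(position) + 0.5 * jnp.sum(momentum**2)
    proposed_h = -log_prob(new_pos) + 0.5 * jnp.sum(new_mom**2)
    # A NaN log-acceptance compares False here, so such proposals are rejected.
    accept = jnp.log(jax.random.uniform(key_acc)) < current_h - proposed_h
    return jnp.where(accept, new_pos, position), accept
```

Each iteration computes `num_steps + 1` gradients inside `leapfrog` and evaluates the log-density twice more for the Metropolis correction.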
In [33]:
In [34]:
In [35]:
Out[35]:
100%|██████████| 500/500 [03:26<00:00, 2.42it/s]
In [36]:
Out[36]:
0
In [37]:
Out[37]:
(DeviceArray(-4997.733, dtype=float32),
DeviceArray([-1.e+04, -1.e+00], dtype=float32),
DeviceArray(-4997.733, dtype=float32),
ShardedDeviceArray([nan], dtype=float32))
In [38]:
Out[38]:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-38-dde50ab74494> in <module>()
1
----> 2 all_samples_cat = onp.stack(all_samples)
<__array_function__ internals> in stack(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/numpy/core/shape_base.py in stack(arrays, axis, out)
421 arrays = [asanyarray(arr) for arr in arrays]
422 if not arrays:
--> 423 raise ValueError('need at least one array to stack')
424
425 shapes = {arr.shape for arr in arrays}
ValueError: need at least one array to stack
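`onp.stack` raises `ValueError: need at least one array to stack` whenever it receives an empty sequence, so `all_samples` must have been empty at this point. The `NameError` in the next cell comes from the same failure, because `all_samples_cat` was never assigned. The hypothetical helper `stack_samples` below is a small guard that replaces this with an explicit error message. It also rejects non-finite draws before they can reach a plot:

```python
import numpy as onp


def stack_samples(samples):
    """Stack per-iteration samples into one array, failing with a clearer message."""
    # onp.stack raises "need at least one array to stack" on an empty list,
    # which hides the real problem: the sampler collected nothing.
    if len(samples) == 0:
        raise RuntimeError(
            "No samples were collected; check the acceptance logic and whether "
            "the log-density or acceptance ratio evaluates to NaN."
        )
    stacked = onp.stack(samples)
    # Flag NaN/inf draws early instead of discovering them at plot time.
    if not onp.all(onp.isfinite(stacked)):
        raise RuntimeError("Some collected samples contain NaN or inf.")
    return stacked
```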
In [19]:
Out[19]:
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-19-ccd922184a23> in <module>()
----> 1 plt.scatter(all_samples_cat[:, 0], all_samples_cat[:, 1], alpha=0.3)
2 plt.grid()
NameError: name 'all_samples_cat' is not defined
BlackJAX
In [1]:
Out[1]:
Requirement already satisfied: blackjax in /usr/local/lib/python3.7/dist-packages (0.2.1)
In [2]:
In [3]:
Out[3]:
[GpuDevice(id=0, process_index=0)]
In [17]:
In [18]:
Out[18]:
HMCState(position={'params': array([0., 0.])}, potential_energy=DeviceArray(-2.7672932, dtype=float32), potential_energy_grad={'params': DeviceArray([0., 0.], dtype=float32)})
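The BlackJAX `HMCState` above holds the position, the potential energy (the negative log-density of the target), and the gradient of that potential. The sketch below computes both quantities with `jax.value_and_grad`. It assumes an isotropic 2D Gaussian target with scale `SIGMA = 0.1`. `PotentialState`, `potential_fn`, and `init_state` are hypothetical names and not part of the BlackJAX API. The position is also a plain array here, not the `{'params': ...}` pytree shown above:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

SIGMA = 0.1  # hypothetical scale of the assumed N(0, SIGMA^2 I) target


class PotentialState(NamedTuple):
    """Hypothetical stand-in for the fields printed in the HMCState above."""
    position: jnp.ndarray
    potential_energy: jnp.ndarray
    potential_energy_grad: jnp.ndarray


def log_density(x):
    # Normalized log-density of N(0, SIGMA^2 I) in d dimensions.
    d = x.shape[0]
    return (
        -0.5 * jnp.sum(x**2) / SIGMA**2
        - 0.5 * d * jnp.log(2 * jnp.pi)
        - d * jnp.log(SIGMA)
    )


def potential_fn(x):
    # HMC's potential energy is the negative log-density of the target.
    return -log_density(x)


def init_state(position):
    # Evaluate the potential and its gradient in a single pass.
    energy, grad = jax.value_and_grad(potential_fn)(position)
    return PotentialState(position, energy, grad)


state = init_state(jnp.zeros(2))
```

Under this assumed target, the potential at the origin is `log(2π) + 2 log(0.1) ≈ -2.767` and the gradient is zero. Both values agree with the printed state, although the stripped cell may define its target in some other way.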
In [20]:
Out[20]:
CPU times: user 3.59 s, sys: 132 ms, total: 3.72 s
Wall time: 2.07 s
In [21]:
Out[21]:
CPU times: user 1.81 ms, sys: 0 ns, total: 1.81 ms
Wall time: 925 µs
In [22]:
In [23]:
Out[23]:
(500, 2)
CPU times: user 2.54 s, sys: 69.8 ms, total: 2.61 s
Wall time: 1.72 s
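The `(500, 2)` shape above is 500 draws of the 2D position. The sketch below collects an array of that shape in one compiled call by running an HMC kernel inside `jax.lax.scan` under `jax.jit`. The `N(0, 0.1² I)` target, the step size, and the helper names `hmc_kernel` and `sample_chain` are assumptions made for illustration. This is plain JAX and does not reproduce BlackJAX's own sampling loop:

```python
import jax
import jax.numpy as jnp


def hmc_kernel(key, position, log_prob, step_size=0.02, num_steps=16):
    """One HMC transition with a unit mass matrix and leapfrog integration."""
    grad = jax.grad(log_prob)
    key_mom, key_acc = jax.random.split(key)
    p0 = jax.random.normal(key_mom, position.shape)
    q, p = position, p0 + 0.5 * step_size * grad(position)  # half momentum step
    for i in range(num_steps):  # unrolled at trace time; num_steps is a Python int
        q = q + step_size * p
        if i < num_steps - 1:
            p = p + step_size * grad(q)
    p = p + 0.5 * step_size * grad(q)  # closing half momentum step
    h_old = -log_prob(position) + 0.5 * jnp.sum(p0**2)
    h_new = -log_prob(q) + 0.5 * jnp.sum(p**2)
    accept = jnp.log(jax.random.uniform(key_acc)) < h_old - h_new
    return jnp.where(accept, q, position)


def sample_chain(key, init_position, log_prob, num_samples):
    """Run num_samples HMC transitions under lax.scan and return every state."""
    keys = jax.random.split(key, num_samples)

    def step(position, step_key):
        new_position = hmc_kernel(step_key, position, log_prob)
        # Carry the new state forward and also emit it as this step's output.
        return new_position, new_position

    _, samples = jax.lax.scan(step, init_position, keys)
    return samples


log_prob = lambda x: -0.5 * jnp.sum(x**2) / 0.1**2
run = jax.jit(sample_chain, static_argnums=(2, 3))
samples = run(jax.random.PRNGKey(0), jnp.zeros(2), log_prob, 500)
```

Because the chain scans over a pre-split array of keys, the whole run stays inside one traced program. It is compiled once instead of being dispatched from Python at every iteration.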
In [24]:
Out[24]:
In [ ]: