Path: blob/master/notebooks/tutorials/numpyro_intro.ipynb
NumPyro is a probabilistic programming language built on top of JAX. It is very similar to Pyro, which is built on top of PyTorch, but the HMC algorithm in NumPyro is much faster.
Both Pyro flavors are usually also faster than PyMC3, and they allow for more complex models, since Pyro is embedded in Python.
Installation
Collecting numpyro
Successfully installed numpyro-0.7.2
Example: 1d Gaussian with unknown mean.
We use the simple example from the Pyro intro. The goal is to infer the weight $\theta$ of an object, given noisy measurements $y$. We assume the following model:

$$\theta \sim \mathcal{N}(\mu_0, \tau^2), \qquad y \mid \theta \sim \mathcal{N}(\theta, \sigma^2)$$

where $\mu_0$ is the initial guess.
Exact inference
By Bayes' rule for Gaussians, we know that the exact posterior, given a single observation $y$, is given by

$$p(\theta \mid y) = \mathcal{N}(\theta \mid m, s^2), \qquad s^2 = \left(\frac{1}{\tau^2} + \frac{1}{\sigma^2}\right)^{-1}, \qquad m = s^2 \left(\frac{\mu_0}{\tau^2} + \frac{y}{\sigma^2}\right)$$

where $\tau^2$ is the prior variance and $\sigma^2$ is the observation noise variance.
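As a sanity check, this update can be computed directly. A minimal sketch; the initial guess 8.5, prior scale 1.0, noise scale 0.75, and measurement 9.5 are illustrative values (following the Pyro intro), not taken from this notebook:

```python
import numpy as np

def gaussian_posterior(y, mu0, tau, sigma):
    """Posterior over theta given one observation y, for the model
    theta ~ N(mu0, tau^2), y | theta ~ N(theta, sigma^2)."""
    s2 = 1.0 / (1.0 / tau**2 + 1.0 / sigma**2)  # posterior variance
    m = s2 * (mu0 / tau**2 + y / sigma**2)      # posterior mean
    return m, np.sqrt(s2)

# Illustrative numbers (assumed): guess 8.5, tau = 1, sigma = 0.75, y = 9.5
m, s = gaussian_posterior(y=9.5, mu0=8.5, tau=1.0, sigma=0.75)
```

With these numbers the posterior mean works out to 9.14, pulled from the guess of 8.5 toward the measurement of 9.5 because the likelihood is more precise than the prior.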
Ancestral sampling
MCMC
Stochastic variational inference
Laplace (quadratic) approximation
Example: Beta-Bernoulli model
This example is from the Pyro SVI tutorial.
The model is $x_i \mid \theta \sim \mathrm{Ber}(\theta)$ for $i = 1:N$, where $\theta \sim \mathrm{Beta}(\alpha, \beta)$. In the code, $\theta$ is called latent_fairness.
Exact inference
The posterior is given by

$$p(\theta \mid \mathcal{D}) = \mathrm{Beta}(\theta \mid \alpha + N_1, \beta + N_0)$$

where $N_1$ is the number of heads and $N_0$ the number of tails.
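The conjugate update is simple arithmetic. A sketch with assumed illustrative values (a Beta(10, 10) prior and 6 heads out of 10 tosses, as in the Pyro SVI tutorial):

```python
# Conjugate update for the Beta-Bernoulli model.
# Assumed values: Beta(10, 10) prior, 6 heads (N1) and 4 tails (N0).
alpha, beta = 10.0, 10.0
N1, N0 = 6, 4

post_alpha, post_beta = alpha + N1, beta + N0
post_mean = post_alpha / (post_alpha + post_beta)  # E[theta | D] = 16/30
```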
Variational inference
MCMC
Distributions
1d Gaussian
Multivariate Gaussian
Shape semantics
NumPyro, Pyro, TFP and Distrax all distinguish between 'event shape' and 'batch shape'. For a D-dimensional Gaussian, the event shape is (D,) and the batch shape is (), meaning we have a single instance of this distribution. If the covariance is diagonal, we can instead view this as D independent 1d Gaussians stored along the batch dimension; this will have event shape () but batch shape (D,).
When we sample from a distribution, we also specify the sample_shape. Suppose we draw N samples from a single D-dimensional diagonal Gaussian, and N samples from a batch of D 1d Gaussians. These samples will have the same shape; however, the semantics of log_prob differ. We illustrate this below.
We can turn a set of independent distributions into a single product distribution using the Independent class.