Path: blob/master/notebooks/misc/dp_mixgauss_forward.ipynb
1192 views
Kernel: Python 3.8.10 64-bit
In [ ]:
try: import probml_utils as pml except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml from pml.dp_mixgauss_utils import dp_mixture_simu from pml.multivariate_t_utils import NormalInverseWishart import numpy as np import jax.numpy as jnp from jax import random from jax.scipy.linalg import sqrtm import matplotlib.pyplot as plt
In [ ]:
# Forward simulation from a Dirichlet-process mixture of Gaussians.
# For each concentration value in `alpha`, draw N points with `dp_mixture_simu`
# using a Normal-Inverse-Wishart (NIW) base measure. Then plot the first
# 50 / 500 / 1000 points together with the 1-sigma contour of every component
# those points belong to.
# Figure layout: rows = number of points shown (`ss`), columns = `alpha`.


def _sigma_contour(mu, sig_root, unit_circle):
    """Return the points ``mu + sig_root @ u`` for each column ``u`` of ``unit_circle``.

    Args:
        mu: component mean, shape (2,).
        sig_root: a square root of the component covariance, shape (2, 2).
        unit_circle: points on the unit circle, shape (2, M).

    Returns:
        Array of shape (2, M) tracing the 1-standard-deviation ellipse.
    """
    # Broadcasting the mean over the M columns replaces mu.dot(ones((1, M))).
    return mu[:, None] + sig_root @ unit_circle


# Example
dim = 2  # the contour drawing below only works for 2-D data

# Hyperparameters of the NIW base measure
hyper_params = dict(loc=jnp.zeros(dim), mean_precision=0.05, df=dim + 5, scale=jnp.eye(dim))
# Generate the NIW object
niw = NormalInverseWishart(**hyper_params)

# Plot
N = 1000
alpha = [1.0, 2.0]
bb = np.arange(0, 2 * np.pi, 0.02)  # angles parameterising the unit circle
ss = [50, 500, 1000]  # number of leading samples shown in each row

# The unit circle is the same for every component, so build it once.
unit_circle = jnp.vstack([jnp.sin(bb), jnp.cos(bb)])

fig, axes = plt.subplots(3, 2)
plt.setp(axes, xticks=[], yticks=[])

seed = 3
# The same key is passed for every alpha, so both columns draw from the same
# random stream and differ only through alpha.
key = random.PRNGKey(seed)

for i, a in enumerate(alpha):
    # NOTE(review): assumes Mu is (n_components, dim) and Sigma is a sequence of
    # (dim, dim) covariances, as the indexing below implies — confirm against
    # dp_mixture_simu.
    Z, X, Mu, Sigma = dp_mixture_simu(N, a, niw, key)
    # jax.scipy.linalg.sqrtm returns a complex array. The principal root of a
    # positive-definite covariance is real, so drop the (zero) imaginary part;
    # otherwise the contour points are complex and matplotlib warns when plotting.
    Sig_root = jnp.array([jnp.real(sqrtm(sigma)) for sigma in Sigma])
    for j, s in enumerate(ss):
        ax = axes[j, i]
        ax.plot(X[:s, 0], X[:s, 1], ".", markersize=5)
        # Draw a contour only for components that own at least one shown point.
        for k in jnp.unique(Z[:s]):
            circ = _sigma_contour(Mu[k], Sig_root[k], unit_circle)
            ax.plot(circ[0, :], circ[1, :], linewidth=2, color="k")

# Save once, after every panel is drawn. The integer seed goes in the file name
# because formatting the PRNG key array itself yields e.g. "[0 3]".
pml.savefig(f"dpmForwardKey{seed}N{N}.pdf")
plt.show()