Path: blob/master/notebooks/misc/dp_mixgauss_cluster.ipynb
1192 views
Kernel: Python 3.8.10 64-bit
In [ ]:
try: import probml_utils as pml except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml from pml.dp_mixgauss_utils import dp_cluster from pml.multivariate_t_utils import NormalInverseWishart import numpy as np import jax.numpy as jnp from jax import random from jax.scipy.linalg import sqrtm import matplotlib.pyplot as plt
In [ ]:
# Demo: draw data from a K-component Gaussian mixture whose components come
# from a Normal-Inverse-Wishart prior, run `dp_cluster` on the data and on a
# row-shuffled copy, and plot the clusterings found at a few iterations.

# ---------------------------------------------------------------------------
# Sample the generative model.
# ---------------------------------------------------------------------------
dim = 2
niw_params = dict(loc=jnp.zeros(dim), mean_precision=0.05, df=dim + 5, scale=jnp.eye(dim))
niw = NormalInverseWishart(**niw_params)
N = 300
alpha = 1.0
K = 4
key = random.PRNGKey(3)

# Means and covariances of the K mixture components.
key, subkey = random.split(key)
params = niw.sample(seed=subkey, sample_shape=(K,))
Sigma = params["Sigma"]
Mu = params["mu"]

# One vector of mixture weights per observation.
key, subkey = random.split(key)
pi0 = random.dirichlet(subkey, alpha * jnp.ones(K), shape=(N,))

# Component label of every observation. A single batched call draws
# independent noise for each row of logits; calling random.categorical once
# per row with the same key would reuse identical noise for all N draws.
key, subkey = random.split(key)
Z0 = random.categorical(subkey, jnp.log(pi0), axis=-1)

# Observations: one fresh key per data point.
key, *subkeys = random.split(key, N + 1)
X0 = jnp.array([random.multivariate_normal(k, Mu[z], Sigma[z]) for k, z in zip(subkeys, Z0)])

# Same data set with its rows shuffled.
key, subkey = random.split(key)
X1 = random.permutation(subkey, X0, axis=0)

# ---------------------------------------------------------------------------
# Perform the posterior inference.
# ---------------------------------------------------------------------------
hyper_params = niw_params.values()
T = 110
key, subkey = random.split(key)
# NOTE(review): both runs are given `key` (not `subkey`), so they share the
# same seed -- presumably so the runs differ only in data order; confirm.
# Both results are indexed as arrays of per-iteration assignments below, so
# the return value is bound directly for both data sets.
Zs0 = dp_cluster(T, X0, alpha, hyper_params, key)
Zs1 = dp_cluster(T, X1, alpha, hyper_params, key)

# ---------------------------------------------------------------------------
# Plot.
# ---------------------------------------------------------------------------
bb = np.arange(0, 2 * np.pi, 0.02)
# Points on the unit circle, shape (2, len(bb)); each cluster maps them to
# its own ellipse.
unit_circle = jnp.vstack([jnp.sin(bb), jnp.cos(bb)])
ts = [10, 50, 100]
Xs = [X0, X1]
Zs = [Zs0, Zs1]

# Different rows represent different iterations in posterior sampling;
# different columns represent different shufflings of the data.
fig, axes = plt.subplots(len(ts), len(Xs), squeeze=False)
plt.setp(axes, xticks=[], yticks=[])
for i, (x, zs) in enumerate(zip(Xs, Zs)):
    for j, t in enumerate(ts):
        ax = axes[j, i]
        ax.plot(x[:, 0], x[:, 1], ".", markersize=5)
        # The clustering after t iterations.
        Z = zs[t]
        for k in jnp.unique(Z):
            xk = x[Z == k]
            # A single point has no sample covariance (division by n - 1 = 0),
            # so there is no ellipse to draw for singleton clusters.
            if xk.shape[0] < 2:
                continue
            mu_k = jnp.mean(xk, axis=0, keepdims=True)
            centred = xk - mu_k
            Sig_k = centred.T @ centred / (xk.shape[0] - 1)
            # sqrtm may return a complex-typed array; the root of a symmetric
            # PSD covariance is real, so keep only the real part for plotting.
            Sig_root = jnp.real(sqrtm(Sig_k))
            # Ellipse: cluster mean plus the covariance root applied to the
            # unit circle.
            circ = mu_k.T + Sig_root @ unit_circle
            ax.plot(circ[0, :], circ[1, :], linewidth=2, color="k")
pml.savefig("dpmClusterKey%sN%s.pdf" % (key, N))
plt.show()