Path: blob/master/notebooks/misc/dp_mixgauss_sample.ipynb
1192 views
Kernel: Python 3.8.10 64-bit
In [5]:
try: import probml_utils as pml except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml
In [ ]:
from probml_utils.dp_mixgauss_utils import dp_mixgauss_sample, NormalInverseWishart
In [2]:
import jax.numpy as jnp from jax import random, vmap from scipy.linalg import sqrtm import matplotlib.pyplot as plt
In [ ]:
# Example: sample from a Dirichlet-process Gaussian mixture for several
# concentration parameters and show the samples together with the covariance
# ellipses of the clusters used by the first `s` observations.

# Dimensionality of the observations (the ellipse drawing below is 2-D).
dim = 2
# Set the hyperparameters of the Normal-Inverse-Wishart (NIW) base measure.
hyper_params = dict(loc=jnp.zeros(dim), mean_precision=0.05, df=dim + 5, scale=jnp.eye(dim))
# Generate the NIW object used as the DP base measure.
dp_base_measure = NormalInverseWishart(**hyper_params)

key = random.PRNGKey(0)
num_of_samples = 1000
# One DP mixture is sampled per concentration value (one figure column each).
dp_concentrations = jnp.array([1.0, 2.0])
num_concentrations = len(dp_concentrations)
# One subkey per concentration; derived from the list length instead of a
# hard-coded count so adding/removing concentrations keeps keys and data aligned.
key, *subkeys = random.split(key, num_concentrations + 1)

# Sample all DP mixtures in one vectorized call: map over (subkey, concentration),
# broadcast the sample count and the base measure.
cluster_means, cluster_covs, samples = vmap(dp_mixgauss_sample, in_axes=(0, None, 0, None))(
    jnp.array(subkeys), num_of_samples, dp_concentrations, dp_base_measure
)

# Unit-circle parameterization. linspace includes the 2*pi endpoint so each
# ellipse is closed (arange(0, 2*pi, 0.02) stopped short and left a gap).
theta = jnp.linspace(0, 2 * jnp.pi, 315)
unit_circle = jnp.vstack([jnp.sin(theta), jnp.cos(theta)])  # shape (2, len(theta))
# Number of leading observations shown in each figure row.
sample_size = [50, 500, 1000]

# squeeze=False keeps `axes` 2-D even with a single row or column.
fig, axes = plt.subplots(len(sample_size), num_concentrations, squeeze=False)
plt.setp(axes, xticks=[], yticks=[])
for i in range(num_concentrations):
    cluster_mean = cluster_means[i]
    cluster_cov = cluster_covs[i]
    sample = samples[i]
    for j, s in enumerate(sample_size):
        ax = axes[j, i]
        # Plot the first `s` samples.
        ax.plot(sample[:s, 0], sample[:s, 1], ".", markersize=5)
        ax.set_title(rf"$\alpha$={float(dp_concentrations[i]):g}, n={s}", fontsize=8)
        # Distinct cluster means among the first `s` rows; `indices` picks the
        # matching covariance rows. (Presumably one row per observation, since
        # it is sliced alongside `sample` — confirm against dp_mixgauss_sample.)
        mu_per_cluster, indices = jnp.unique(cluster_mean[:s], return_index=True, axis=0)
        cov_per_cluster = cluster_cov[indices]
        for mu, cov in zip(mu_per_cluster, cov_per_cluster):
            # sqrtm may return a complex array with negligible imaginary parts
            # for nearly singular matrices; keep only the real part for plotting.
            cov_root = jnp.real(sqrtm(cov))
            # Map the unit circle through the covariance square root and shift by mu.
            ellipse = mu[:, None] + cov_root @ unit_circle
            ax.plot(ellipse[0, :], ellipse[1, :], linewidth=2, color="k")
fig.tight_layout()
plt.show()