GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/vb_gmm_tfp.ipynb


Variational Bayes for Gaussian Mixture Models using TFP

The code was written by Dave Moore, with some tweaks by Kevin Murphy.

We use a diagonal Gaussian approximation to the posterior (after transforming the variables to an unconstrained space), fit by optimizing the stochastic variational inference (SVI) objective with full-batch gradient descent. See here for code that implements full-batch VBEM using a conjugate prior.
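For reference, the loss that `tfp.vi.fit_surrogate_posterior` minimizes is the standard SVI objective, the negative evidence lower bound (ELBO); the formula below is added here for clarity rather than taken from the original text:

$$\mathcal{L}(\phi) \;=\; -\,\mathbb{E}_{q_\phi(\theta)}\!\left[\log \tilde p(\theta, \mathcal{D}) - \log q_\phi(\theta)\right],$$

where $\tilde p$ is the unnormalized joint of the pinned model, $\mathcal{D}$ the observed data, and $q_\phi$ the surrogate posterior. It is estimated by Monte Carlo with reparameterized samples from $q_\phi$ and minimized over the variational parameters $\phi$ with full-batch Adam.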

import functools

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

# from matplotlib import pylab as plt
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

print(tf.__version__, tfp.__version__)
2.5.0 0.13.0

Plotting code

from matplotlib.patches import Ellipse


def plot_loc_scale(weight_, loc_, scale_tril_, color, ax):
    cov = np.dot(scale_tril_, scale_tril_.T)
    w, v = np.linalg.eig(cov)
    angle = np.arctan2(v[1, 0], v[1, 1]) * 360 / (2 * np.pi)
    height = 3 * np.sqrt(w[1])  # minor axis
    width = 3 * np.sqrt(w[0])  # major axis
    e = Ellipse(xy=loc_, width=width, height=height, angle=angle)
    ax.add_artist(e)
    e.set_clip_box(ax.bbox)
    e.set_alpha(weight_)
    e.set_facecolor(color)
    e.set_edgecolor("black")


def plot_posterior_with_data(mix_, loc_, scale_tril_, data, ax, facecolors=None):
    ax.plot(data[:, 0], data[:, 1], "k.", markersize=3)
    ax.plot(loc_[:, 0], loc_[:, 1], "r^")
    num_components = len(mix_)
    np.random.seed(420)
    if facecolors is None:
        facecolors = sns.color_palette("deep", n_colors=num_components)
    weights_ = np.power(mix_, 0.8)  # larger power means less emphasis on low weights
    weights_ = weights_ * (0.5 / np.max(weights_))
    for i, (weight_, l_, st_) in enumerate(zip(weights_, loc_, scale_tril_)):
        plot_loc_scale(weight_, l_, st_, color=facecolors[i], ax=ax)


def plot_posterior_sample(surrogate_posterior, data):
    fig = plt.figure(figsize=(10, 6), constrained_layout=True)
    gs = fig.add_gridspec(4, 4)
    mix, loc, _, _, scale_tril = surrogate_posterior.sample()
    num_components = len(mix)
    plot_posterior_with_data(mix.numpy(), loc.numpy(), scale_tril.numpy(), data=data, ax=fig.add_subplot(gs[:, :3]))
    ax = fig.add_subplot(gs[:1, 3])
    sns.barplot(x=np.arange(num_components), y=mix.numpy(), ax=ax, palette="deep")
    ax.set_title("Mixture component weights")

Data

We use a dataset of eruption durations and waiting times from the "Old Faithful" geyser in Yellowstone National Park.

url = "https://raw.githubusercontent.com/probml/probml-data/main/data/faithful.txt"
# df = pd.read_csv(url, sep='\t', header=None, names=['eruptions', 'waiting'])
!wget $url

data = np.array(np.loadtxt("faithful.txt"))
print(data.shape)
--2021-07-14 20:06:38--  https://raw.githubusercontent.com/probml/probml-data/main/data/faithful.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5167 (5.0K) [text/plain]
Saving to: ‘faithful.txt.2’

faithful.txt.2      100%[===================>]   5.05K  --.-KB/s    in 0s

2021-07-14 20:06:38 (67.9 MB/s) - ‘faithful.txt.2’ saved [5167/5167]

(272, 2)
plt.figure()
plt.scatter(data[:, 0], data[:, 1])
plt.xlabel("Eruption duration (mins)")
plt.ylabel("Waiting time (mins)")
Text(0, 0.5, 'Waiting time (mins)')
[Figure: scatter plot of eruption duration (mins) vs. waiting time (mins)]
# Standardize the data (to simplify model fitting)
mean = np.mean(data, axis=0)
std = np.std(data, axis=0)
data_normalized = (data - mean) / std
plt.figure()
plt.scatter(data_normalized[:, 0], data_normalized[:, 1])
<matplotlib.collections.PathCollection at 0x7ff566d730d0>
[Figure: scatter plot of the standardized data]

Model

We put a Dirichlet prior on the mixture weights, a Gaussian prior on each mean vector, an LKJ prior on each correlation matrix (parameterized by its Cholesky factor), and a half-normal prior on each scale vector. (This is not a conjugate prior.)

def bayesian_gaussian_mixture_model(num_observations, dims, components):
    mixture_probs = yield tfd.Dirichlet(
        concentration=tf.ones(components, dtype=tf.float32) / components, name="mixture_probs"
    )
    loc = yield tfd.Normal(loc=tf.zeros([components, dims]), scale=1, name="loc")
    scale = yield tfd.HalfNormal(scale=2 * tf.ones([components, dims]), name="scale")
    correlation_tril = yield tfd.CholeskyLKJ(
        dimension=dims, concentration=tf.ones([components]), name="correlation_tril"
    )
    scale_tril = yield tfd.Deterministic(scale[..., tf.newaxis] * correlation_tril, name="scale_tril")
    observations = yield tfd.Sample(
        tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=mixture_probs),
            components_distribution=tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril),
        ),
        sample_shape=num_observations,
        name="observations",
    )
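As a side note (a minimal NumPy check added here, using illustrative values not taken from the notebook), the line `scale_tril = scale[..., tf.newaxis] * correlation_tril` scales the rows of each component's correlation Cholesky factor, so the implied covariance is $\Sigma_k = \mathrm{diag}(\text{scale}_k)\,\mathrm{Corr}_k\,\mathrm{diag}(\text{scale}_k)$:

# Illustrative check of how one component's covariance is assembled.
scale_k = np.array([2.0, 0.5])                          # per-dimension scales of one component
corr_tril_k = np.array([[1.0, 0.0],
                        [0.3, np.sqrt(1.0 - 0.3**2)]])  # Cholesky factor of a valid correlation matrix
scale_tril_k = scale_k[:, None] * corr_tril_k           # same broadcast as scale[..., tf.newaxis] * correlation_tril
cov_k = scale_tril_k @ scale_tril_k.T                   # Sigma = diag(scale) Corr diag(scale)
print(np.sqrt(np.diag(cov_k)))                          # recovers scale_k: [2.0, 0.5]
print(cov_k[0, 1] / (scale_k[0] * scale_k[1]))          # recovers the correlation: 0.3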
ncomponents = 10
ndata = data.shape[0]
ndims = data.shape[1]

bgmm = tfd.JointDistributionCoroutineAutoBatched(
    functools.partial(bayesian_gaussian_mixture_model, dims=ndims, components=ncomponents, num_observations=ndata)
)
print(bgmm.event_shape)
print(bgmm.event_shape._fields)
StructTuple( mixture_probs=TensorShape([10]), loc=TensorShape([10, 2]), scale=TensorShape([10, 2]), correlation_tril=TensorShape([10, 2, 2]), scale_tril=TensorShape([10, 2, 2]), observations=TensorShape([272, 2]) ) ('mixture_probs', 'loc', 'scale', 'correlation_tril', 'scale_tril', 'observations')
# Sample from the prior predictive joint distribution
x = bgmm.sample()
print(type(x))
# print(x)
print(x.mixture_probs.shape)
print(x.mixture_probs)
print(x.loc.shape)
print(x.scale.shape)
print(x.correlation_tril.shape)
print(x.scale_tril.shape)
print(x.observations.shape)

print("sample data")
print(x.observations[:5, :])
<class 'tensorflow_probability.python.internal.structural_tuple.structtuple.<locals>.StructTuple'> (10,) tf.Tensor( [1.1842325e-04 1.6617058e-08 6.2336113e-08 1.1148515e-02 6.3603826e-02 1.6459078e-01 2.1702325e-10 7.5996721e-01 2.9287461e-04 2.7822447e-04], shape=(10,), dtype=float32) (10, 2) (10, 2) (10, 2, 2) (10, 2, 2) (272, 2) sample data tf.Tensor( [[ 1.0435464 -2.0615597 ] [ 1.0982577 1.6835704 ] [ 1.0092466 -4.3569803 ] [-0.09726715 -0.74421793] [-1.9014795 0.5790981 ]], shape=(5, 2), dtype=float32)
print(bgmm.log_prob(x))
tf.Tensor(-1111.151, shape=(), dtype=float32)
# Clamp the observations
pinned = bgmm.experimental_pin(observations=data_normalized)
print(type(pinned))
print(pinned)
<class 'tensorflow_probability.python.experimental.distributions.joint_distribution_pinned.JointDistributionPinned'> tfp.distributions.JointDistributionPinned("PinnedJointDistributionCoroutineAutoBatched", batch_shape=[], event_shape=StructTuple( mixture_probs=[10], loc=[10, 2], scale=[10, 2], correlation_tril=[10, 2, 2], scale_tril=[10, 2, 2] ), dtype=StructTuple( mixture_probs=float32, loc=float32, scale=float32, correlation_tril=float32, scale_tril=float32 ))
# Sample from the clamped model
# x = pinned.sample()  # does not work
x = pinned.sample_unpinned()  # sample from unnormalized joint
print(x._fields)  # observations is excluded
print(x.mixture_probs)
print(pinned.unnormalized_log_prob(x))
# print(bgmm.unnormalized_log_prob(x))
('mixture_probs', 'loc', 'scale', 'correlation_tril', 'scale_tril') tf.Tensor( [2.4685903e-06 2.2794875e-06 2.6798571e-04 1.3366579e-01 7.0801721e-04 8.5429221e-01 2.4315801e-13 1.0049632e-02 1.8950657e-04 8.2220143e-04], shape=(10,), dtype=float32) tf.Tensor(-1037.1965, shape=(), dtype=float32)

Fitting a point mass posterior (MAP estimate)

This marginalizes over the discrete latent indicators (as part of the MixtureSameFamily log-prob computation), but uses point estimates for the model parameters, similar to standard EM. Thus there is no "Bayesian Occam's razor" penalty when choosing too many mixture components.
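The sketch below (added for illustration; it is not TFP's actual implementation, and it reuses the `tf`/`tfd` imports from the top of the notebook) shows the marginalization that `MixtureSameFamily.log_prob` performs over the discrete component indicator:

# Illustrative sketch of log p(x) = logsumexp_k [ log pi_k + log N(x | mu_k, Sigma_k) ].
def mixture_marginal_log_prob(x, mixture_probs, loc, scale_tril):
    components = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)  # batch of K Gaussians
    per_component_lp = components.log_prob(x[..., tf.newaxis, :])            # shape [..., K]
    return tf.reduce_logsumexp(tf.math.log(mixture_probs) + per_component_lp, axis=-1)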

def trainable_point_estimate(initial_loc, initial_scale, event_ndims, validate_args):
    return tfd.Independent(
        tfd.Deterministic(tf.Variable(initial_loc), validate_args=validate_args),
        reinterpreted_batch_ndims=event_ndims,
    )


point_mass_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    pinned.event_shape,
    bijector=pinned.experimental_default_event_space_bijector(),
    trainable_distribution_fn=trainable_point_estimate,
)
import time

t0 = time.time()
num_steps = 1000
losses = tfp.vi.fit_surrogate_posterior(
    pinned.unnormalized_log_prob, point_mass_posterior, optimizer=tf.optimizers.Adam(3e-2), num_steps=int(num_steps)
)
t1 = time.time()
print("{} variational steps finished in {:.3f}s".format(num_steps, t1 - t0))
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/vectorization_util.py:97: UserWarning:
Saw Tensor seed Tensor("monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched/unnormalized_log_prob/Const:0", shape=(2,), dtype=int32),
implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current
implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better
performance from stateful sampling, which you can invoke by passing a Python `int` seed.
1000 variational steps finished in 4.806s
# The training loss decreases steadily, analogous to the monotone decrease
# of the negative log-likelihood in EM
plt.plot(losses)
plt.title("Training loss curve")
Text(0.5, 1.0, 'Training loss curve')
[Figure: training loss curve for the point-mass (MAP) fit]
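As a quick check (added here as a sketch, not part of the original run), we can evaluate the unnormalized log-joint of the pinned model at the fitted point estimate; since the surrogate is a point mass, every draw returns the same parameters and hence the same value.

# Added sketch: score the fitted MAP point under the unnormalized pinned joint.
theta_hat = point_mass_posterior.sample()
print(pinned.unnormalized_log_prob(theta_hat))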
print(type(point_mass_posterior))
<class 'tensorflow_probability.python.distributions.transformed_distribution.TransformedDistribution'>
print(point_mass_posterior)
tfp.distributions.TransformedDistribution("build_factored_surrogate_posterior_default_joint_bijectorbuild_factored_surrogate_posterior_JointDistributionNamed", batch_shape=StructTuple( mixture_probs=[], loc=[], scale=[], correlation_tril=[], scale_tril=[] ), event_shape=StructTuple( mixture_probs=[10], loc=[10, 2], scale=[10, 2], correlation_tril=[10, 2, 2], scale_tril=[10, 2, 2] ), dtype=StructTuple( mixture_probs=float32, loc=float32, scale=float32, correlation_tril=float32, scale_tril=float32 ))
# Unconstrained parameters (before applying the bijectors, e.g., Softplus for a positive-valued scale parameter)
print(point_mass_posterior.trainable_variables)
(<tf.Variable 'build_factored_surrogate_posterior/Variable:0' shape=(9,) dtype=float32, numpy= array([-1.0691624 , -5.114904 , 1.28985 , 0.4173861 , -5.221218 , 0.13155887, -1.0381907 , -5.1708126 , -0.77442 ], dtype=float32)>, <tf.Variable 'build_factored_surrogate_posterior/Variable:0' shape=(10, 2) dtype=float32, numpy= array([[ 1.62215978e-01, 8.23917150e-01], [-1.12164579e-03, 1.60692434e-03], [ 7.20961392e-01, 7.50016987e-01], [-1.26306820e+00, -1.21620023e+00], [-9.37398523e-04, 1.31279277e-03], [ 9.82987285e-01, 5.39859354e-01], [-1.04261601e+00, -5.71361661e-01], [-9.66875756e-04, 1.39059790e-03], [-1.22851975e-01, -2.77763426e-01], [-1.44873738e+00, -1.43472803e+00]], dtype=float32)>, <tf.Variable 'build_factored_surrogate_posterior/Variable:0' shape=(10, 2) dtype=float32, numpy= array([[-1.6801108 , -1.4628218 ], [ 1.0032085 , 1.0036786 ], [-1.0639086 , -0.56256396], [-1.6406152 , -0.83363366], [ 1.0148695 , 1.0218562 ], [-1.604622 , -1.3254735 ], [-1.8308544 , -1.1224767 ], [ 1.0103079 , 1.0190338 ], [-0.597313 , -0.15598181], [-2.888337 , -0.8903793 ]], dtype=float32)>, <tf.Variable 'build_factored_surrogate_posterior/Variable:0' shape=(10, 1) dtype=float32, numpy= array([[-0.28085336], [ 0.00737957], [ 0.5580046 ], [-0.2497556 ], [ 0.0064737 ], [ 0.01802182], [ 0.2677456 ], [ 0.00684568], [ 1.3021624 ], [-0.2106765 ]], dtype=float32)>, <tf.Variable 'build_factored_surrogate_posterior/Variable:0' shape=(10, 2, 2, 0) dtype=float32, numpy=array([], shape=(10, 2, 2, 0), dtype=float32)>)
mix_log_weights = point_mass_posterior.trainable_variables[0]
print(tf.nn.softmax(mix_log_weights))
tf.Tensor( [0.04597949 0.00080449 0.48648587 0.20331244 0.00072335 0.15276743 0.04742584 0.00076075 0.06174036], shape=(9,), dtype=float32)

Samples from the point-mass surrogate posterior are identical across sampling runs, since all of its mass sits on a single point estimate of the parameters.

params = point_mass_posterior.sample()
print(params.mixture_probs)
tf.Tensor( [0.0405486 0.00070947 0.42902428 0.17929807 0.00063791 0.13472322 0.04182411 0.00067089 0.05444786 0.11811556], shape=(10,), dtype=float32)
params = point_mass_posterior.sample()
print(params.mixture_probs)
tf.Tensor( [0.0405486 0.00070947 0.42902428 0.17929807 0.00063791 0.13472322 0.04182411 0.00067089 0.05444786 0.11811556], shape=(10,), dtype=float32)
plot_posterior_sample(point_mass_posterior, data=data_normalized)
plt.savefig("vb_gmm_map_sample.pdf")
plt.show()
[Figure: standardized data overlaid with the MAP mixture components, plus a bar plot of the mixture component weights]
plot_posterior_sample(point_mass_posterior, data=data_normalized)
[Figure: the same plot regenerated from a second draw; identical, since the posterior is a point mass]

Fitting a diagonal Gaussian posterior

Construct and fit a surrogate posterior using stochastic gradient VI. The surrogate is a diagonal Gaussian over an unconstrained vector space: each sample is split into one tensor per model RV and pushed through a constraining bijector into the support of the corresponding parameter. For details, see https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/vi/build_affine_surrogate_posterior
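Schematically (our notation, a sketch rather than TFP's internal construction), a draw from this surrogate with `operators="diag"` is

$$\theta \;=\; b\!\left(\mu + \sigma \odot \epsilon\right), \qquad \epsilon \sim \mathcal{N}(0, I),$$

where $\mu$ and $\sigma$ are the trainable location and (diagonal) scale of the unconstrained Gaussian, and $b$ is the event-space bijector that splits the vector into per-RV tensors and applies the constraining transforms (e.g., a softmax-style bijector for the simplex weights, softplus for the positive scales, and a correlation-Cholesky bijector for `correlation_tril`).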

The surrogate's event space is derived from the pinned distribution. For details, see https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/distributions/JointDistributionPinned

surrogate_posterior = tfp.experimental.vi.build_affine_surrogate_posterior(
    pinned.event_shape, bijector=pinned.experimental_default_event_space_bijector(), operators="diag"
)
# Use operators='tril' for a full-covariance Gaussian
import time

t0 = time.time()
num_steps = 1000
losses = tfp.vi.fit_surrogate_posterior(
    pinned.unnormalized_log_prob, surrogate_posterior, optimizer=tf.optimizers.Adam(2e-2), num_steps=int(num_steps)
)
t1 = time.time()
print("{} variational steps finished in {:.3f}s".format(num_steps, t1 - t0))
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/vectorization_util.py:97: UserWarning:
Saw Tensor seed Tensor("monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched/unnormalized_log_prob/Const:0", shape=(2,), dtype=int32),
implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current
implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better
performance from stateful sampling, which you can invoke by passing a Python `int` seed.
1000 variational steps finished in 7.253s
plt.plot(losses)
plt.title("Training loss curve")
Text(0.5, 1.0, 'Training loss curve')
[Figure: training loss curve for the diagonal Gaussian surrogate]
plot_posterior_sample(surrogate_posterior, data=data_normalized)
plt.savefig("vb_gmm_bayes_sample1.pdf")
plt.show()
[Figure: posterior sample 1 from the diagonal Gaussian surrogate, with mixture component weights]
plot_posterior_sample(surrogate_posterior, data=data_normalized)
plt.savefig("vb_gmm_bayes_sample2.pdf")
plt.show()
[Figure: posterior sample 2 from the diagonal Gaussian surrogate, with mixture component weights]
plot_posterior_sample(surrogate_posterior, data=data_normalized)
plt.savefig("vb_gmm_bayes_sample3.pdf")
plt.show()
[Figure: posterior sample 3 from the diagonal Gaussian surrogate, with mixture component weights]
params = surrogate_posterior.sample()
print(params.mixture_probs)
tf.Tensor( [2.23936467e-03 3.64494860e-01 5.86357564e-02 1.12465797e-02 9.28402471e-04 6.39106001e-05 5.53885102e-01 6.51227252e-04 1.21413046e-04 7.73340464e-03], shape=(10,), dtype=float32)
params = surrogate_posterior.sample()
print(params.mixture_probs)
tf.Tensor( [1.3826601e-03 4.2224973e-01 1.6281165e-02 6.7551960e-03 1.8640651e-03 3.2661634e-04 5.4102653e-01 1.4823949e-03 2.0768255e-04 8.4240157e-03], shape=(10,), dtype=float32)
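A possible follow-up (a hedged usage sketch added here, not part of the original notebook): summarize the fitted surrogate by Monte Carlo, and map the posterior-mean component locations back to the original units using the `mean` and `std` computed during standardization.

# Added sketch: Monte Carlo summaries of the fitted diagonal Gaussian surrogate.
samples = surrogate_posterior.sample(1000)
print(tf.reduce_mean(samples.mixture_probs, axis=0))      # posterior mean of the mixture weights
print(tf.math.reduce_std(samples.mixture_probs, axis=0))  # posterior std of the mixture weights
loc_mean = tf.reduce_mean(samples.loc, axis=0).numpy()    # component means in standardized units, shape [10, 2]
print(loc_mean * std + mean)                              # back to (duration, waiting) in minutes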