Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/04/samplingDistributionGaussianShrinkage.ipynb
1192 views
Kernel: Python 3

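MAP Estimation (Posterior Mean)

For a Gaussian likelihood with known variance $\sigma^2$ and a conjugate Gaussian prior $\mathcal{N}(\theta_0, \sigma^2/\kappa_0)$, the MAP estimate coincides with the posterior mean $\hat{\theta} = w\,\bar{x} + (1 - w)\,\theta_0$, where $w = n/(n + \kappa_0)$; a larger prior strength $\kappa_0$ shrinks the estimate further towards $\theta_0$.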

import matplotlib.pyplot as plt try: from probml_utils import savefig, latexify, is_latexify_enabled except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils import savefig, latexify, is_latexify_enabled try: import tensorflow_probability.substrates.jax as tfp except ModuleNotFoundError: %pip install -qq tensorflow_probability import tensorflow_probability.substrates.jax as tfp try: import jax import jax.numpy as jnp except ModuleNotFoundError: %pip install -qq jax import jax import jax.numpy as np tfd = tfp.distributions latexify(width_scale_factor=2, fig_height=1.85)
/home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# Per-curve matplotlib color codes and line styles; plotting loops index
# both lists by curve position.
colors = list("brkgcymrbkgcym")
styles = ["-", ":", "-.", "--"] * 4
def gauss_prob(X, mu, Sigma):
    """Evaluate the Normal density N(mu, Sigma**2) at ``X``.

    ``Sigma`` is passed to ``tfd.Normal`` as ``scale``, so it is a standard
    deviation, not a variance.
    """
    return tfd.Normal(loc=mu, scale=Sigma).prob(X)
def plot_posterior_mean(
    colors=colors,
    styles=styles,
    k0=4,
    n=5,
    save_file_name="sampling_distribution_gaussian_shrinkage_latexified",
    fig=None,
    ax=None,
):
    """Plot the sampling distribution of the posterior-mean estimator.

    For each prior strength kappa_0 in 0, 1, ..., k0 - 1 the estimator is
    ``w * xbar + (1 - w) * thetaPrior`` with ``w = n / (n + kappa_0)``. It is
    Gaussian with mean ``w * thetaTrue + (1 - w) * thetaPrior`` and variance
    ``w**2 * sigmaTrue**2 / n``; one density curve is drawn per kappa_0.

    Args:
        colors: Line colors, indexed by prior-strength position.
        styles: Line styles, indexed by prior-strength position.
        k0: Number of prior strengths to plot (kappa_0 = 0, ..., k0 - 1).
        n: Sample size.
        save_file_name: Figure name passed to ``savefig``.
        fig: Figure owning ``ax``; it is only passed through to the return value.
        ax: Axes to draw on. If None, a new figure and axes are created.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    if ax is None:
        # With no axes supplied, ax.plot below would fail; create them instead.
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(1, 1)

    # Prior strengths; kappa_0 = 0 gives w = 1, i.e. the plain sample mean (MLE).
    k0s = jnp.arange(k0)
    thetaTrue = 1
    sigmaTrue = 1
    thetaPrior = 0
    xrange = jnp.arange(-1, 2.55, 0.05)
    names = []
    for ki, kappa in enumerate(k0s):
        w = n / (n + kappa)  # weight placed on the data
        v = w**2 * sigmaTrue**2 / n  # variance of the estimator
        thetaEst = w * thetaTrue + (1 - w) * thetaPrior  # mean of the estimator
        # Raw string: "\k" is an invalid escape (SyntaxWarning on Python 3.12+).
        names.append(r"$\kappa_0 = {0:01d}$".format(kappa))
        ax.plot(
            xrange,
            gauss_prob(xrange, thetaEst, jnp.sqrt(v)),
            color=colors[ki],
            linestyle=styles[ki],
            linewidth=1,
        )
    ax.set_title("Sampling Distribution\n truth = {}, prior = {}, n = {}".format(thetaTrue, thetaPrior, n))
    ax.set_xlabel("$x$")
    ax.tick_params(axis="both", labelsize=7)
    ax.set_ylabel(r"$P(x)\ (Posterior\ mean)$")
    ax.legend(names, loc="upper left", fontsize=7)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    savefig(save_file_name)
    return fig, ax
def plot_relative_mean(
    colors=colors,
    styles=styles,
    k0=4,
    fig=None,
    ax=None,
    save_file_name="sampling_distribution_gaussian_shrinkage_second_latexified",
):
    """Plot MSE(posterior mean) / MSE(MLE) against sample size.

    For sample sizes n = 1, 3, ..., 49 and prior strengths
    kappa_0 = 0, ..., k0 - 1, with shrinkage weight w = n / (n + kappa_0):

    * MLE (sample mean) is unbiased, so MSE = sigmaTrue**2 / n.
    * Posterior mean: MSE = variance + bias**2
      = w**2 * sigmaTrue**2 / n + (1 - w)**2 * (thetaPrior - thetaTrue)**2.

    A ratio below 1 means the posterior mean has lower MSE than the MLE.

    Args:
        colors: Line colors, indexed by prior-strength position.
        styles: Line styles, indexed by prior-strength position.
        k0: Number of prior strengths to plot (kappa_0 = 0, ..., k0 - 1).
        fig: Figure owning ``ax``; it is only passed through to the return value.
        ax: Axes to draw on. If None, a new figure and axes are created.
        save_file_name: Figure name passed to ``savefig``.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    if ax is None:
        # With no axes supplied, ax.plot below would fail; create them instead.
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(1, 1)

    k0s = jnp.arange(k0)
    ns = jnp.arange(1, 50, 2)
    thetaTrue = 1
    sigmaTrue = 1
    thetaPrior = 0

    # Rows index sample size, columns index prior strength.
    n_col = ns[:, None]
    ws = n_col / (n_col + k0s[None, :])
    mseThetaE = sigmaTrue**2 / n_col  # shape (len(ns), 1); broadcasts over columns
    mseThetaB = ws**2 * sigmaTrue**2 / n_col + (1 - ws) ** 2 * (thetaPrior - thetaTrue) ** 2
    ratio = mseThetaB / mseThetaE

    # Raw string: "\k" is an invalid escape (SyntaxWarning on Python 3.12+).
    names = [r"$\kappa_0 = {0:01d}$".format(kappa) for kappa in k0s]
    for ki in range(len(k0s)):
        ax.plot(ns, ratio[:, ki], color=colors[ki], linestyle=styles[ki], linewidth=1)
    ax.legend(names)
    ax.set_ylabel("Relative MSE")
    ax.set_xlabel("Sample Size")
    ax.set_title("MSE of postmean / MSE of MLE")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    savefig(save_file_name)
    return fig, ax
# Draw each static plot on its own freshly created single-axes figure.
fig1, ax1 = plt.subplots(nrows=1, ncols=1)
plot_posterior_mean(fig=fig1, ax=ax1, colors=colors, styles=styles)

fig2, ax2 = plt.subplots(nrows=1, ncols=1)
plot_relative_mean(fig=fig2, ax=ax2, colors=colors, styles=styles)
/home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures") /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
(<Figure size 432x288 with 1 Axes>, <AxesSubplot:title={'center':'MSE of postmean / MSE of MLE'}, xlabel='Sample Size', ylabel='Relative MSE'>)
Image in a Jupyter notebookImage in a Jupyter notebook
from ipywidgets import interact, fixed


# Demonstrates how the sampling distribution of the posterior mean changes with
# the number of samples (n); the slider ranges over n = 2, ..., 80.
@interact(n=(2, 80, 1), colors=fixed(colors), styles=fixed(styles), k0=fixed(4))
def plot_posterior_mean_interactive(n=4, colors=colors, styles=styles, k0=4):
    """Interactive version of the posterior-mean sampling-distribution plot.

    For each prior strength kappa_0 in 0, ..., k0 - 1, draws the Gaussian
    density with mean ``w * thetaTrue + (1 - w) * thetaPrior`` and variance
    ``w**2 * sigmaTrue**2 / n``, where ``w = n / (n + kappa_0)``. It draws on
    the current pyplot figure and then calls ``plt.show()``.
    """
    k0s = jnp.arange(k0)
    thetaTrue = 1
    sigmaTrue = 1
    thetaPrior = 0
    xrange = jnp.arange(-1, 2.55, 0.05)
    names = []
    for ki, kappa in enumerate(k0s):
        w = n / (n + kappa)  # weight placed on the data
        v = w**2 * sigmaTrue**2 / n  # variance of the estimator
        thetaEst = w * thetaTrue + (1 - w) * thetaPrior  # mean of the estimator
        # Raw string: "\k" is an invalid escape (SyntaxWarning on Python 3.12+).
        names.append(r"$\kappa_0 = {0:01d}$".format(kappa))
        plt.plot(
            xrange,
            gauss_prob(xrange, thetaEst, jnp.sqrt(v)),
            color=colors[ki],
            linestyle=styles[ki],
            linewidth=1,
        )
    plt.title("Sampling Distribution, truth = {}, prior = {}, n = {}".format(thetaTrue, thetaPrior, n))
    plt.xlabel("$x$")
    # Mathtext ignores plain spaces inside $...$, so "mean" needs an explicit "\ ".
    plt.ylabel(r"$P(x)\ (Posterior\ mean)$")
    plt.legend(names, loc="upper left")
    plt.show()
interactive(children=(IntSlider(value=4, description='n', max=80, min=2), Output()), _dom_classes=('widget-int…