Path: blob/master/notebooks/book1/03/gauss_infer_1d.ipynb
1193 views
Kernel: Python [conda env:py3713]
Inference about a random variable given a noisy observation
In [1]:
import jax.numpy as jnp from jax import random import matplotlib.pyplot as plt from scipy.stats import multivariate_normal as mvn import seaborn as sns import os try: from probml_utils import savefig, latexify except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils import savefig, latexify
In [2]:
# Environment switches for the figure helpers (presumably read by
# probml_utils' latexify/savefig -- confirm there): LATEXIFY is set to an
# empty string and FIG_DIR names the output directory for saved figures.
os.environ.update({"LATEXIFY": "", "FIG_DIR": "figures"})
In [3]:
# Apply the probml_utils figure styling once, before any figure is created.
# NOTE(review): the exact meaning of width_scale_factor / fig_height is defined
# in probml_utils, not in this notebook -- confirm there before changing them.
latexify(width_scale_factor=2, fig_height=2)
In [4]:
def plot_data(quantiles, pdf_dict, variance_prior, interact_flag, save_name=""):
    """
    Draw every PDF in ``pdf_dict`` on a single figure and optionally save it.

    Args:
    ----------
    quantiles : JAX array
        x-values at which each PDF was evaluated

    pdf_dict : dictionary
        Maps a legend label to a dict holding the keys "pdf" (y-values),
        "color" and "linestyle"

    variance_prior : int/float
        Variance of the prior; only used in the figure title

    interact_flag : bool
        True when called from the interactive widget; selects the wider
        fixed axis limits used for the interactive plot

    save_name : string, default=''
        Filename for the saved graph; an empty string (or None) skips saving

    Returns:
    ----------
    None
    """
    # Explicit figure/axes handles instead of the implicit pyplot state machine
    fig, ax = plt.subplots()

    # Fixed limits keep the axes stable between redraws / between the two figures
    if interact_flag:
        ax.set_ylim(0, 1)
        ax.set_xlim(-10, 10)
    else:
        ax.set_ylim(0, 0.6)
        ax.set_xlim(-7, 7)

    # One curve per entry; the dict key doubles as the legend label
    for label, spec in pdf_dict.items():
        ax.plot(
            quantiles,
            spec["pdf"],
            color=spec["color"],
            label=label,
            linestyle=spec["linestyle"],
            linewidth=1.5,
        )

    # Labels, legend and title
    ax.set_title(f"Prior with variance of {variance_prior}")
    ax.set_xlabel("$x$")
    ax.set_ylabel("$p(x)$")
    ax.legend(loc="upper left")
    sns.despine()

    # Truthiness check: skips saving for "" and also tolerates None,
    # where len(save_name) would raise TypeError
    if save_name:
        savefig(save_name)
    plt.show()
In [5]:
def generate_PDF(prior_var=[1, 5], prior_mean=[0, 0], observed_var=[1, 1], interact_flag=False):
    """
    Plot the prior, likelihood and posterior PDFs of a 1d Gaussian model to show
    what is learnt about a random variable from a noisy observation under a
    strong prior versus a weak prior. One figure is drawn per
    (prior_mean, prior_var, observed_var) triple.

    Args:
    ----------
    prior_var : list, default=[1, 5]
        List of variances of the priors

    prior_mean : list, default=[0, 0]
        List of means of the priors

    observed_var : list, default=[1, 1]
        List of variances of the observed values

    interact_flag : bool, default=False
        Interactive plot indicator, forwarded to plot_data

    Returns:
    ----------
    None

    Raises:
    ----------
    ValueError
        If the three input lists do not have the same length
    """
    # Materialise the inputs so their lengths can be compared: zip() would
    # otherwise silently drop the unmatched trailing entries.
    prior_vars = list(prior_var)
    prior_means = list(prior_mean)
    observed_vars = list(observed_var)
    if not (len(prior_vars) == len(prior_means) == len(observed_vars)):
        raise ValueError(
            "prior_var, prior_mean and observed_var must have the same length, got "
            f"{len(prior_vars)}, {len(prior_means)} and {len(observed_vars)}"
        )

    # The noisy observation
    obsrvd_val = 3

    # Quantiles for data
    quantiles = jnp.arange(-10, 10, 0.10)

    # Loop-invariant summaries of the observation: sample mean and sample count
    likelihood_mu = jnp.mean(obsrvd_val)
    obsrvd_val_count = jnp.size(obsrvd_val)

    for mean_prior, variance_prior, variance_observed in zip(prior_means, prior_vars, observed_vars):
        # Prior PDF
        prior_pdf = mvn.pdf(quantiles, mean=mean_prior, cov=variance_prior)

        # Likelihood PDF, centred on the observation
        likelihood_pdf = mvn.pdf(quantiles, mean=likelihood_mu, cov=variance_observed)

        # Conjugate Gaussian update:
        #   var_post = (var_obs * var_prior) / (N * var_prior + var_obs)
        #   mu_post  = var_post * (mu_prior / var_prior + N * mean_obs / var_obs)
        posterior_var = (variance_observed * variance_prior) / (
            (obsrvd_val_count * variance_prior) + variance_observed
        )
        posterior_mu = posterior_var * (
            (mean_prior / variance_prior) + ((obsrvd_val_count * likelihood_mu) / variance_observed)
        )
        posterior_pdf = mvn.pdf(quantiles, mean=posterior_mu, cov=posterior_var)

        # Setup data and graphing options
        pdf_plot_dict = {
            "prior": {"pdf": prior_pdf, "color": "blue", "linestyle": "-"},
            "likelihood": {"pdf": likelihood_pdf, "color": "red", "linestyle": ":"},
            "posterior": {"pdf": posterior_pdf, "color": "black", "linestyle": "-."},
        }

        # NOTE(review): a save name is passed in interactive mode as well, so a
        # figure file is requested on every redraw -- confirm this is intended.
        plot_data(
            quantiles,
            pdf_plot_dict,
            variance_prior,
            interact_flag,
            f"gauss_infer_1d_prior{variance_prior}",
        )
In [6]:
# Draw the static figures with the default arguments: prior variances 1 and 5,
# prior mean 0 and observation variance 1 for both.
generate_PDF()
Out[6]:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
saving image to figures/gauss_infer_1d_prior1_latexified.pdf
Figure size: [3. 2.]
saving image to figures/gauss_infer_1d_prior5_latexified.pdf
Figure size: [3. 2.]
In [7]:
from ipywidgets import Layout, interact
import ipywidgets as widgets


def _make_slider(label, lower, upper, start, **extra):
    """Build a FloatSlider with the step and style options shared by all controls."""
    return widgets.FloatSlider(
        description=label,
        min=lower,
        max=upper,
        value=start,
        step=0.1,
        style=dict(description_width="initial"),
        continuous_update=False,
        **extra,
    )


# One slider per keyword argument of update(); insertion order is display order
_slider_controls = {
    "prior_var": _make_slider("prior variance", 1.0, 5.0, 1.0),
    "prior_mean": _make_slider("prior mean", -3.0, 3.0, 0.0),
    "observed_var": _make_slider(
        "Variance of Observed values", 1.0, 5.0, 1.0, layout=Layout(width="40%")
    ),
}


@interact(**_slider_controls)
def update(prior_var, prior_mean, observed_var):
    """Redraw the prior/likelihood/posterior figure for the current slider values."""
    # generate_PDF expects lists, so each scalar slider value is wrapped in one
    generate_PDF(
        prior_var=[prior_var],
        prior_mean=[prior_mean],
        observed_var=[observed_var],
        interact_flag=True,
    )
Out[7]:
interactive(children=(FloatSlider(value=1.0, continuous_update=False, description='prior variance', max=5.0, m…
In [ ]: