Path: blob/master/notebooks/book2/03/linreg_post_pred_plot.ipynb
1193 views
Kernel: Python 3 (ipykernel)
Bayesian linear regression vs. plugin approximation
In [1]:
import jax import jax.numpy as jnp import matplotlib.pyplot as plt try: import probml_utils as pml except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml from probml_utils import savefig, latexify, is_latexify_enabled try: from sklearn.linear_model import Ridge except ModuleNotFoundError: %pip install -qq scikit-learn from sklearn.linear_model import Ridge from scipy.stats import multivariate_normal import warnings warnings.filterwarnings("ignore") jax.config.update("jax_platform_name", "cpu")
In [2]:
# Apply probml_utils' latexify figure settings before any figure is created.
# NOTE(review): width_scale_factor=2 presumably sizes each figure for a
# two-panels-per-row layout - confirm against probml_utils.
latexify(width_scale_factor=2)
Generate Data
In [3]:
# --- Experiment configuration ---
polydeg = 2  # Degree of design matrix
alph = 0.001  # Alpha of ridge regression
num_samples = 10  # Number of sample coefficients to draw and use for prediction
visibility = 0.5  # Transparency of plotted lines - in case we wish to plot a bunch.

key = jax.random.PRNGKey(1)
# The training noise below keeps using `key`, so the training set (and every
# figure derived from it) is unchanged. The test noise gets its own derived key:
# drawing both noise vectors from the very same key would make them come from
# the same random stream instead of being independent.
test_key = jax.random.split(key)[1]

xtrain = jnp.array([-3, -2, 0, 2, 3])
xtest = jnp.linspace(-7, 7, 141)
sigma2 = 25  # Variance of the Gaussian observation noise added to fun(x)


def fun(x):
    """Noise-free ground-truth function: 10 + x + x**2."""
    return 10 + x + x**2


ytrain = fun(xtrain) + jax.random.normal(key=key, shape=[xtrain.shape[0]]) * jnp.sqrt(sigma2)
ytest = fun(xtest) + jax.random.normal(key=test_key, shape=[xtest.shape[0]]) * jnp.sqrt(sigma2)


def reshape(x):
    """Return `x` as a JAX array with a single column (shape (-1, 1))."""
    return jnp.asarray(x).reshape(-1, 1)


# Everything downstream expects column vectors.
xtrain = reshape(xtrain)
xtest = reshape(xtest)
ytrain = reshape(ytrain)
ytest = reshape(ytest)
In [4]:
def poly_basis(x, deg):
    """Expand `x` into a polynomial design matrix.

    The columns are x**0, x**1, ..., x**deg (constant term first), joined
    with `jnp.column_stack`.
    """
    powers = []
    for exponent in range(deg + 1):
        powers.append(x**exponent)
    return jnp.column_stack(powers)
Train both models
In [5]:
# Polynomial design matrices for the training and test inputs
xtrainp = poly_basis(xtrain, polydeg)
xtestp = poly_basis(xtest, polydeg)

# Declare and fit linear regression model (the constant column of the design
# matrix plays the role of the intercept, hence fit_intercept=False)
linreg = Ridge(alpha=alph, fit_intercept=False)
linreg.fit(xtrainp, ytrain)

# Determine coefficient distribution
wmle = linreg.coef_.reshape(-1)  # Mean of coefficients
wcov = sigma2 * jnp.linalg.inv(
    jnp.diag(jnp.array([alph] * (polydeg + 1))) + xtrainp.T.dot(xtrainp)
)  # Covariance of coefficients
posterior_bayes = multivariate_normal(mean=wmle, cov=wcov)

# Draw coefficient samples. A fixed random_state makes the sampled curves
# reproducible under Restart & Run All; the reshape keeps one row per sample
# even when num_samples == 1 (rvs would otherwise return a 1-D vector).
samples = posterior_bayes.rvs(num_samples, random_state=0).reshape(num_samples, -1)

# Sample predictions according to samples of coefficients: one column per sample
prediction_samples = xtestp.dot(samples.T)

ypred_mle = linreg.predict(xtestp)  # MLE prediction
noise_mle = jnp.var(ytrain - linreg.predict(xtrainp), ddof=(polydeg + 1))  # MLE noise estimation
Plotting
In [6]:
def make_plot(ypreds, save_name, title, lowerb=None, upperb=None):
    """Plot predictions over the training data and save the figure.

    Reads the module-level `xtrain`, `ytrain`, `xtest` and `visibility`.

    Args:
        ypreds: 2-D array of predictions at `xtest`, one column per curve.
            With bounds only column 0 is drawn; without bounds every column
            is drawn as a separate line.
        save_name: Name handed to `pml.savefig`.
        title: Figure title, drawn only when latexify is disabled.
        lowerb: Optional 1-D array (one entry per row of `ypreds`) with the
            distance from the prediction down to the lower edge of the band.
        upperb: Optional 1-D array with the distance from the prediction up
            to the upper edge of the band.

    If either bound is given, error bars at 30 evenly spaced test points and a
    shaded band are drawn; a missing bound is treated as zero (one-sided band).
    """
    fig, ax = plt.subplots()

    # Smaller markers for the compact latexified figures.
    scatter_size = 40 if is_latexify_enabled() else 140
    ax.scatter(xtrain, ytrain, s=scatter_size, facecolors="none", edgecolors="r", label="training data")

    has_bounds = lowerb is not None or upperb is not None
    if has_bounds:
        # A missing side means "no uncertainty on that side" rather than a crash.
        if lowerb is None:
            lowerb = jnp.zeros(ypreds.shape[0])
        if upperb is None:
            upperb = jnp.zeros(ypreds.shape[0])

        # Indices of the 30 evenly spaced test points that carry an error bar.
        idx = jnp.array([int(round(s)) for s in jnp.linspace(0, xtest.shape[0] - 1, 30)])
        ax.errorbar(
            xtest[idx, 0],
            ypreds[idx, 0],
            yerr=[list(lowerb[idx]), list(upperb[idx])],
            linewidth=0.5,
            label="prediction",
        )
        # Band runs from (prediction - lowerb) to (prediction + upperb), matching
        # the error-bar convention above.
        ax.fill_between(xtest[:, 0], ypreds[:, 0] - lowerb, ypreds[:, 0] + upperb, alpha=0.3)
        ax.legend(loc=9, prop={"size": 7})
    else:
        for j in range(ypreds.shape[1]):
            ax.plot(xtest, ypreds[:, j], color="k", linewidth=1.0, label="prediction", alpha=visibility)

    if not is_latexify_enabled():
        ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    pml.savefig(save_name)
In [7]:
# plot a: plugin approximation - the same noise standard deviation at every test point
noisevec = jnp.full(ypred_mle.shape[0], jnp.sqrt(noise_mle))
make_plot(ypred_mle, "linregPostPredPlugin", "Plugin approximation", noisevec, noisevec)

# plot b: posterior predictive - std is sqrt(sigma2 + x^T wcov x) per test row,
# computed for all rows at once instead of looping over them in Python
postnoise = jnp.sqrt(sigma2 + jnp.einsum("ij,jk,ik->i", xtestp, wcov, xtestp))
make_plot(ypred_mle, "linregPostPredBayes", "Posterior predictive", postnoise, postnoise)

# plot c: the plugin approximation yields a single curve
make_plot(ypred_mle, "linregPostPredSamplesPlugin", "functions sampled from plugin approximation to posterior")

# plot d: one curve per sampled coefficient vector
make_plot(prediction_samples, "linregPostPredSamples", "functions sampled from posterior")

plt.show()
Out[7]:
In [ ]:
In [ ]:
In [ ]:
In [ ]: