# Path: blob/master/deprecated/scripts/beta_binom_post_plot.py
import superimport

import jax.numpy as jnp
import matplotlib.pyplot as plt
import pyprobml_utils as pml
from jax.scipy.stats import beta, bernoulli

# Points where we evaluate the pdf: an evenly spaced grid strictly inside (0, 1).
x = jnp.linspace(0.001, 0.999, 100)


def _conjugate_posterior(prior, n_0, n_1):
    """Return Beta posterior parameters for a Beta prior updated with Bernoulli data.

    A Beta(a, b) prior combined with ``n_1`` observed ones and ``n_0`` observed
    zeros yields the posterior Beta(a + n_1, b + n_0).
    """
    return {"a": prior["a"] + n_1, "b": prior["b"] + n_0}


# Forms graph given the parameters of the prior, likelihood and posterior:
def make_graph(data, save_name):
    """Plot a Beta prior, Bernoulli likelihood and Beta posterior, then save the figure.

    Args:
        data: Mapping with the keys
            ``"prior"``: ``{"a": ..., "b": ...}``, parameters of the Beta prior;
            ``"likelihood"``: ``{"n_0": ..., "n_1": ...}``, the number of observed
            zeros and ones;
            ``"posterior"`` (optional): ``{"a": ..., "b": ...}``. If absent, it is
            derived from the conjugate update ``Beta(a + n_1, b + n_0)``.
        save_name: File name handed to ``pml.savefig``.

    Prior and posterior densities share the left y-axis. The likelihood is a
    product of Bernoulli pmfs viewed as a function of p and is not normalised,
    so it is drawn against its own right-hand y-axis (``twinx``).
    """
    prior_params = data["prior"]
    n_0 = data["likelihood"]["n_0"]
    n_1 = data["likelihood"]["n_1"]
    posterior_params = data.get("posterior")
    if posterior_params is None:
        posterior_params = _conjugate_posterior(prior_params, n_0, n_1)

    prior = beta.pdf(x, a=prior_params["a"], b=prior_params["b"])
    posterior = beta.pdf(x, a=posterior_params["a"], b=posterior_params["b"])

    # Observed data: n_0 zeros followed by n_1 ones.
    samples = jnp.concatenate([jnp.zeros(n_0), jnp.ones(n_1)])

    def _likelihood_at(p):
        # Joint likelihood of all samples at success probability p, formed as
        # exp(sum of log-pmfs) instead of multiplying the pmfs directly.
        return jnp.exp(bernoulli.logpmf(samples, p).sum())

    likelihood = jnp.vectorize(_likelihood_at)(x)

    _, ax = plt.subplots()
    axt = ax.twinx()
    prior_label = f"prior Beta({prior_params['a']}, {prior_params['b']})"
    posterior_label = (
        f"posterior Beta({posterior_params['a']}, {posterior_params['b']})"
    )
    lines = ax.plot(x, prior, "k", label=prior_label, linewidth=2.0)
    lines += axt.plot(
        x, likelihood, "r:", label="likelihood Bernoulli", linewidth=2.0
    )
    lines += ax.plot(x, posterior, "b-.", label=posterior_label, linewidth=2.0)
    # The curves live on two different axes, so collect their handles and build
    # one combined legend on the primary axis.
    ax.legend(
        lines,
        [line.get_label() for line in lines],
        loc="upper left",
        shadow=True,
    )
    axt.set_ylabel("Likelihood")
    ax.set_ylabel("Prior/Posterior")
    ax.set_title(f"$N_0$:{n_0}, $N_1$:{n_1}")
    pml.savefig(save_name)


# Each "posterior" entry below equals the conjugate update Beta(a + n_1, b + n_0)
# of its prior; it is kept explicit, but make_graph would derive it if omitted.
data1 = {
    "prior": {"a": 1, "b": 1},
    "likelihood": {"n_0": 1, "n_1": 4},
    "posterior": {"a": 5, "b": 2},
}
make_graph(data1, "betaPostUninfSmallSample.pdf")

data2 = {
    "prior": {"a": 1, "b": 1},
    "likelihood": {"n_0": 10, "n_1": 40},
    "posterior": {"a": 41, "b": 11},
}
make_graph(data2, "betaPostUninfLargeSample.pdf")

data3 = {
    "prior": {"a": 2, "b": 2},
    "likelihood": {"n_0": 1, "n_1": 4},
    "posterior": {"a": 6, "b": 3},
}
make_graph(data3, "betaPostInfSmallSample.pdf")

data4 = {
    "prior": {"a": 2, "b": 2},
    "likelihood": {"n_0": 10, "n_1": 40},
    "posterior": {"a": 42, "b": 12},
}
make_graph(data4, "betaPostInfLargeSample.pdf")


plt.show()