Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/15/logreg_iris_bayes_2d.ipynb
1193 views
Kernel: Python 3.7.13 ('py3713')

Bayesian binary logistic regression in 2D for iris flowers

Code is based on: https://github.com/aloctavodia/BAP/blob/master/code/Chp4/04_Generalizing_linear_models.ipynb

author: @karm-patel

import jax import jax.numpy as jnp import pandas as pd import matplotlib.pyplot as plt from functools import partial import seaborn as sns import os try: import blackjax except: %pip install jaxopt blackjax import blackjax try: import probml_utils as pml except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml from probml_utils.blackjax_utils import arviz_trace_from_states, inference_loop_multiple_chains from sklearn.datasets import load_iris try: from tensorflow_probability.substrates import jax as tfp except ModuleNotFoundError: %pip install -qqq tensorflow_probability from tensorflow_probability.substrates import jax as tfp import arviz as az tfd = tfp.distributions
# os.environ["LATEXIFY"] = "" # os.environ["FIG_DIR"] = "figures" pml.latexify(fig_height=1.75, width_scale_factor=2)
/home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# Load iris and keep the raw feature matrix and integer class labels.
iris = load_iris()
X, y = iris.data, iris.target

# Wrap the features in a DataFrame with readable column names and attach the
# species name of every row as a categorical column.
feature_columns = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
df_iris = pd.DataFrame(data=X, columns=feature_columns)
species_names = iris.target_names[y]
df_iris["species"] = pd.Series(species_names, dtype="category")

# Keep only the two classes used for the binary problem.
df = df_iris.query("species == ('setosa', 'versicolor')")
df

Balanced Dataset

# Positional rows 25..74 of the two-class frame; the labels printed further
# down show this window holds 25 rows of each class.
df_balanced = df.iloc[25:75]
df_balanced.head()
# Features: the two sepal measurements only, as a plain array.
feature_names = ["sepal_length", "sepal_width"]
X = df_balanced[feature_names].to_numpy()
X.shape
(50, 2)
# Integer-encode the species labels via their categorical codes.
species_categorical = pd.Categorical(df_balanced["species"])
y = species_categorical.codes
y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)
# Priors: Normal(0, 20) on the bias and Normal(0, 2) on each weight.
bias_prior = tfd.Normal(0, 20)
weights_prior = tfd.Normal(0, 2)


def joint_logprob_fn(params, inputs=X, outputs=y):
    """Return the joint log probability: log prior plus Bernoulli log likelihood.

    Args:
        params: mapping with entries "weights" and "bias".
        inputs: feature matrix; defaults to the ``X`` that exists when this
            function is defined.
        outputs: binary labels; defaults to the ``y`` that exists when this
            function is defined.

    Returns:
        The sum of the prior log probabilities of the weights and bias and
        the log likelihood of ``outputs`` under a logistic model.
    """
    w, b = params["weights"], params["bias"]

    # Prior contribution.
    prior_term = weights_prior.log_prob(w).sum() + bias_prior.log_prob(b)

    # Likelihood contribution: Bernoulli with a linear predictor as logits.
    linear_predictor = inputs @ w + b
    observation_dist = tfd.Bernoulli(logits=linear_predictor)
    likelihood_term = observation_dist.log_prob(outputs).sum()

    return prior_term + likelihood_term
# Initialize the parameters for two chains by drawing from the priors.
n_chains = 2
rng_key = jax.random.PRNGKey(10)

# Draw the weights and the bias with independent subkeys. Passing the same key
# to two sampling calls reuses the same random bits, so the draws would not be
# independent of each other.
weights_key, bias_key = jax.random.split(rng_key)
initial_params = {
    "weights": weights_prior.sample(seed=weights_key, sample_shape=(n_chains, 2)),
    "bias": bias_prior.sample(seed=bias_key, sample_shape=(n_chains,)),
}
# NUTS sampler settings. The inverse mass matrix has three entries,
# presumably one per scalar parameter (two weights plus the bias).
step_size = 0.01
inverse_mass_matrix = jnp.array([0.5, 0.5, 0.5])
nuts = blackjax.nuts(joint_logprob_fn, step_size, inverse_mass_matrix)
# Build one sampler state per chain by mapping the sampler's init over the
# leading axis of the initial parameters, then JIT-compile the step function.
init_per_chain = jax.vmap(nuts.init)
initial_states = init_per_chain(initial_params)
kernel = jax.jit(nuts.step)
%%time rng_key = jax.random.PRNGKey(1) states, infos = inference_loop_multiple_chains(rng_key, kernel, initial_states, 2000, n_chains)
CPU times: user 8.38 s, sys: 115 ms, total: 8.49 s Wall time: 8.42 s
# Convert the sampler output to an ArviZ trace (burn_in=200) and tabulate
# summary statistics for the sampled parameters.
trace = arviz_trace_from_states(states, infos, burn_in=200)
summary_df = az.summary(trace)
summary_df
# Trace plots for the sampled parameters; tight_layout tidies subplot spacing.
az.plot_trace(trace)
plt.tight_layout()
Image in a Jupyter notebook
# Drop the first `burn` draws, then flatten the remaining draws of each
# parameter (all chains together) into a single row vector.
burn = 200
kept_weights = states.position["weights"][burn:]
kept_bias = states.position["bias"][burn:]

w0 = kept_weights[:, :, 0].reshape(1, -1)
w1 = kept_weights[:, :, 1].reshape(1, -1)
b = kept_bias[:, :].reshape(1, -1)
# Decision boundary: for each posterior draw, the second-feature value at
# which the logit is zero, evaluated at every observed first-feature value.
x0 = X[:, 0]
slope_part = jnp.dot(x0.reshape(-1, 1), w0) / w1
intercept_part = b / w1
bd = -(slope_part + intercept_part)
bd.shape
(50, 3600)
mean_pred = -(x0 * w0.mean() / w1.mean() + b.mean() / w1.mean()) # mean prediction SCATTER_SIZE = 10 idx = y == 0 fig, ax = plt.subplots() ax.scatter(X[idx, 0], X[idx, 1], c="tab:red", label="setosa", s=SCATTER_SIZE) sc = ax.scatter(X[~idx, 0], X[~idx, 1], c="tab:green", label="versicolor", s=SCATTER_SIZE) ax.set_xlabel("sepal_length") ax.set_ylabel("sepal_width") ax.set_ylim(1, 8) ax.plot(x0, mean_pred, label="Posterior mean", lw=1) az.plot_hdi(X[:, 0], bd.T, color="k", ax=ax) ax.legend(frameon=False, bbox_to_anchor=(0.55, 0.6)) sns.despine() pml.savefig("logreg_iris_bayes_2d")
/home/patel_karm/arviz/arviz/plots/hdiplot.py:157: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs) /home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook

Unbalanced Dataset

Now we use an unbalanced dataset that contains only 5 samples of setosa and 45 samples of versicolor.

# Positional rows 45..94 of the two-class frame; the labels printed further
# down show 5 rows of class 0 followed by 45 rows of class 1.
df_un_balanced = df.iloc[45:95]
df_un_balanced.sample(5)
# Features: the two sepal measurements only, as a plain array.
feature_names = ["sepal_length", "sepal_width"]
X = df_un_balanced[feature_names].to_numpy()
X.shape
(50, 2)
# Integer-encode the species labels via their categorical codes.
species_categorical = pd.Categorical(df_un_balanced["species"])
y = species_categorical.codes
y
array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)
# NUTS sampler settings (same values as for the balanced run).
step_size = 0.01
inverse_mass_matrix = jnp.array([0.5, 0.5, 0.5])

# Default argument values are evaluated when a function is defined, so the
# current X and y must be passed explicitly to target the unbalanced data.
joint_logprob_fn_partial = partial(joint_logprob_fn, inputs=X, outputs=y)
nuts = blackjax.nuts(joint_logprob_fn_partial, step_size, inverse_mass_matrix)
# Build one sampler state per chain by mapping the sampler's init over the
# leading axis of the initial parameters, then JIT-compile the step function.
init_per_chain = jax.vmap(nuts.init)
initial_states = init_per_chain(initial_params)
kernel = jax.jit(nuts.step)
%%time rng_key = jax.random.PRNGKey(1) states, infos = inference_loop_multiple_chains(rng_key, kernel, initial_states, 2000, n_chains)
CPU times: user 8.7 s, sys: 103 ms, total: 8.8 s Wall time: 8.73 s
# Convert the sampler output to an ArviZ trace (burn_in=200) and tabulate
# summary statistics for the sampled parameters.
trace = arviz_trace_from_states(states, infos, burn_in=200)
summary_df = az.summary(trace)
summary_df
# Trace plots for the sampled parameters; tight_layout tidies subplot spacing.
az.plot_trace(trace)
plt.tight_layout()
Image in a Jupyter notebook
# Drop the first `burn` draws, then flatten the remaining draws of each
# parameter (all chains together) into a single row vector.
burn = 200
kept_weights = states.position["weights"][burn:]
kept_bias = states.position["bias"][burn:]

w0 = kept_weights[:, :, 0].reshape(1, -1)
w1 = kept_weights[:, :, 1].reshape(1, -1)
b = kept_bias[:, :].reshape(1, -1)
# Decision boundary: for each posterior draw, the second-feature value at
# which the logit is zero, evaluated at every observed first-feature value.
x0 = X[:, 0]
slope_part = jnp.dot(x0.reshape(-1, 1), w0) / w1
intercept_part = b / w1
bd = -(slope_part + intercept_part)
bd.shape
(50, 3600)
mean_pred = -(x0 * w0.mean() / w1.mean() + b.mean() / w1.mean()) # mean prediction SCATTER_SIZE = 10 idx = y == 0 fig, ax = plt.subplots() ax.scatter(X[idx, 0], X[idx, 1], c="tab:red", label="setosa", s=SCATTER_SIZE) sc = ax.scatter(X[~idx, 0], X[~idx, 1], c="tab:green", label="versicolor", s=SCATTER_SIZE) ax.set_xlabel("sepal_length") ax.set_ylabel("sepal_width") ax.set_ylim(1, 8) ax.plot(x0, mean_pred, label="Posterior mean", lw=1) az.plot_hdi(X[:, 0], bd.T, color="k", ax=ax) # ax.legend(frameon=False, bbox_to_anchor=(0.55, 0.6)) sns.despine() pml.savefig("logreg_iris_bayes_2d_unbalanced")
/home/patel_karm/arviz/arviz/plots/hdiplot.py:157: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs) /home/patel_karm/sendbox/probml-utils/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook