GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/adf_logistic_regression_demo.py
# Online training of a logistic regression model
# using Assumed Density Filtering (ADF).
# We compare the ADF result with the MCMC sampling
# of the posterior distribution.
# For further details, see the ADF paper:
# * O. Zoeter, "Bayesian Generalized Linear Models in a Terabyte World,"
#   2007 5th International Symposium on Image and Signal Processing and Analysis, 2007,
#   pp. 435-440, doi: 10.1109/ISPA.2007.4383733.
# Dependencies:
#   !pip install jax_cosmo

# Author: Gerardo Durán-Martín (@gerdm)

import superimport

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pyprobml_utils as pml
from jax import random
from jax.scipy.stats import norm
from jax_cosmo.scipy import integrate
from functools import partial
from jsl.demos import logreg_biclusters_demo as demo

# jax_cosmo seems to only support numerical integration in CPU mode
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

figures, data = demo.main()

X = data["X"]
y = data["y"]
Phi = data["Phi"]
Xspace = data["Xspace"]
Phispace = data["Phispace"]
w_laplace = data["w_laplace"]
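
# Shapes inferred from the einsum calls further below (an assumption; the
# demo itself does not assert them):
#   Phi:      (n_datapoints, ndims)  basis-expanded inputs
#   Phispace: (ndims, nx, ny)        basis expansion over the plotting grid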

def sigmoid(z): return jnp.exp(z) / (1 + jnp.exp(z))
def log_sigmoid(z): return z - jnp.log1p(jnp.exp(z))
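
# Quick sanity check (illustrative, not part of the original script): the
# hand-rolled activations should match jax.nn.sigmoid / jax.nn.log_sigmoid,
# which are the numerically stable library versions.
_z = jnp.linspace(-5.0, 5.0, 7)
assert jnp.allclose(sigmoid(_z), jax.nn.sigmoid(_z), atol=1e-6)
assert jnp.allclose(log_sigmoid(_z), jax.nn.log_sigmoid(_z), atol=1e-6)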

def Zt_func(eta, y, mu, v):
    # Unnormalized 1-d posterior over eta: Bernoulli likelihood times the
    # Gaussian prior predictive; `v` is the standard deviation.
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)
    return jnp.exp(log_term)


def mt_func(eta, y, mu, v, Zt):
    # Integrand for the first moment of the moment-matched Gaussian
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)
    return eta * jnp.exp(log_term) / Zt


def vt_func(eta, y, mu, v, Zt):
    # Integrand for the second (raw) moment of the moment-matched Gaussian
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)
    return eta ** 2 * jnp.exp(log_term) / Zt
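
# Together, the three integrands above implement ADF moment matching for a
# single observation. With the 1-d prior predictive N(eta | m, v) for
# eta = Phi_t^T w and Bernoulli likelihood
# p(y | eta) = sigmoid(eta)^y (1 - sigmoid(eta))^(1 - y):
#   Z_t = ∫ p(y | eta) N(eta | m, v) d eta                        (normalizer)
#   m_t = (1 / Z_t) ∫ eta p(y | eta) N(eta | m, v) d eta          (first moment)
#   v_t = (1 / Z_t) ∫ eta² p(y | eta) N(eta | m, v) d eta - m_t²  (variance)
# The subtraction of m_t² happens inside adf_step below.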


def adf_step(state, xs, prior_variance, lbound, ubound):
    mu_t, tau_t = state
    Phi_t, y_t = xs

    # Predict step: random-walk dynamics on the weights
    mu_t_cond = mu_t
    tau_t_cond = tau_t + prior_variance

    # Prior predictive distribution of eta = Phi_t^T w
    m_t_cond = (Phi_t * mu_t_cond).sum()
    v_t_cond = (Phi_t ** 2 * tau_t_cond).sum()

    v_t_cond_sqrt = jnp.sqrt(v_t_cond)

    # Moment-matched Gaussian approximation elements
    Zt = integrate.romb(lambda eta: Zt_func(eta, y_t, m_t_cond, v_t_cond_sqrt), lbound, ubound)
    mt = integrate.romb(lambda eta: mt_func(eta, y_t, m_t_cond, v_t_cond_sqrt, Zt), lbound, ubound)
    vt = integrate.romb(lambda eta: vt_func(eta, y_t, m_t_cond, v_t_cond_sqrt, Zt), lbound, ubound)
    vt = vt - mt ** 2

    # Posterior estimation: map the 1-d moment corrections back onto the weights
    delta_m = mt - m_t_cond
    delta_v = vt - v_t_cond
    a = Phi_t * tau_t_cond / (Phi_t ** 2 * tau_t_cond).sum()
    mu_t = mu_t_cond + a * delta_m
    tau_t = tau_t_cond + a ** 2 * delta_v

    return (mu_t, tau_t), (mu_t, tau_t)
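

# Minimal single-update usage sketch (the toy inputs are assumptions for
# illustration; they are not part of the demo data): one ADF step on a
# 2-d weight vector, starting from a standard normal prior.
_toy_state = (jnp.zeros(2), jnp.ones(2))
_toy_obs = (jnp.array([1.0, 0.5]), 1.0)
(_toy_mu, _toy_tau), _ = adf_step(_toy_state, _toy_obs, prior_variance=0.0, lbound=-20, ubound=20)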

# ** ADF inference **
prior_variance = 0.0
n_datapoints, ndims = Phi.shape
# Lower and upper bounds of integration. Ideally, we would like to
# integrate from -inf to inf, but we run into numerical issues.
lbound, ubound = -20, 20
mu_t = jnp.zeros(ndims)
tau_t = jnp.ones(ndims)

init_state = (mu_t, tau_t)
xs = (Phi, y)

adf_loop = partial(adf_step, prior_variance=prior_variance, lbound=lbound, ubound=ubound)
(mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs)
print("ADF weights")
print(mu_t)
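
# For reference, also print the batch Laplace point estimate computed by the
# demo, so the two solutions can be compared directly in the console output.
print("Laplace weights")
print(w_laplace)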

# ADF posterior predictive distribution, estimated by Monte Carlo:
# sample weights from the Gaussian ADF posterior (diagonal covariance)
# and average the resulting class probabilities over the plotting grid.
n_samples = 5000
key = random.PRNGKey(3141)
adf_samples = random.multivariate_normal(key, mu_t, jnp.diag(tau_t), (n_samples,))
Z_adf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, adf_samples))
Z_adf = Z_adf.mean(axis=0)
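
# Alternative sketch (an approximation, not part of the original demo): the
# Monte Carlo average above can be replaced with the classic probit
# approximation sigmoid(m / sqrt(1 + pi * v / 8)), applied to the ADF
# predictive mean m and variance v on the plotting grid.
m_space = jnp.einsum("mij,m->ij", Phispace, mu_t)
v_space = jnp.einsum("mij,m->ij", Phispace ** 2, tau_t)
Z_adf_probit = sigmoid(m_space / jnp.sqrt(1 + jnp.pi * v_space / 8))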

# ** Plotting predictive distribution **
colors = ["black" if el else "white" for el in y]

## Add posterior marginal for ADF-estimated weights
for i in range(ndims):
    mean, std = mu_t[i], jnp.sqrt(tau_t[i])
    fig = figures[f"logistic_regression_weights_marginals_w{i}"]
    ax = fig.gca()
    x = jnp.linspace(mean - 4 * std, mean + 4 * std, 500)
    ax.plot(x, norm.pdf(x, mean, std), label="posterior (ADF)", linestyle="dashdot")
    ax.legend()

fig_adf, ax = plt.subplots()
title = "ADF Predictive distribution"
demo.plot_posterior_predictive(ax, X, Xspace, Z_adf, title, colors)
pml.savefig("logistic_regression_surface_adf.pdf")

# Posterior marginals vs time: online ADF estimates against the batch Laplace fit

lcolors = ["black", "tab:blue", "tab:red"]
elements = mu_t_hist.T, tau_t_hist.T, w_laplace, lcolors
timesteps = jnp.arange(n_datapoints) + 1

for k, (wk, Pk, wk_laplace, c) in enumerate(zip(*elements)):
    fig_weight_k, ax = plt.subplots()
    ax.errorbar(timesteps, wk, jnp.sqrt(Pk), c=c, label=f"$w_{k}$ online (ADF)")
    ax.axhline(y=wk_laplace, c=c, linestyle="dotted", label=f"$w_{k}$ batch (Laplace)", linewidth=3)

    ax.set_xlim(1, n_datapoints)
    ax.legend(framealpha=0.7, loc="upper right")
    ax.set_xlabel("number of samples")
    ax.set_ylabel("weights")
    plt.tight_layout()
    pml.savefig(f"logistic_regression_hist_adf_w{k}")

plt.show()