Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AllenDowney
GitHub Repository: AllenDowney/bayesian-analysis-recipes
Path: blob/master/models/__init__.py
411 views
1
"""
2
Author: Nicole Carson
3
4
Copied from https://raw.githubusercontent.com/parsing-science/ps-toolkit/master/ps_toolkit/pymc3_models/__init__.py.
5
"""
6
7
import joblib
8
import matplotlib.pyplot as plt
9
import pymc3 as pm
10
import seaborn as sns
11
from sklearn.base import BaseEstimator
12
13
14
class BayesianModel(BaseEstimator):
    """
    Bayesian model base class.

    Holds the state shared by ADVI-fitted PyMC3 models (the cached model, the
    Theano shared variables, the ADVI history and the sampled trace) and
    provides helpers for inference, persistence and plotting. Subclasses are
    expected to override ``create_model``, ``fit``, ``predict`` and ``score``,
    which raise ``NotImplementedError`` here.
    """

    def __init__(self):
        self.advi_hist = None  # The ADVI history.
        self.advi_trace = None  # The ADVI trace.
        self.cached_model = None  # The PyMC3 model object.
        self.num_pred = None  # The number of predictor variables (number of columns in X).
        self.shared_vars = None  # A dictionary of Theano shared variables.

    def create_model(self):
        """
        Build the PyMC3 model. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _set_shared_vars(self, shared_vars):
        """
        Sets theano shared variables for the PyMC3 model.

        Parameters
        ----------
        shared_vars: dict mapping a key of ``self.shared_vars`` to the new
            value that should be stored in that shared variable.
        """
        for key, value in shared_vars.items():
            self.shared_vars[key].set_value(value)

    def _inference(self, minibatches, n=200000, draws=10000):
        """
        Runs minibatch variational ADVI and then samples from those results.

        The fit history is stored in ``self.advi_hist`` and the sampled trace
        in ``self.advi_trace``.

        Parameters
        ----------
        minibatches: minibatches for ADVI (passed to ``pm.fit`` as
            ``more_replacements``)

        n: number of iterations for ADVI fit, defaults to 200000

        draws: number of samples drawn from the fitted approximation,
            defaults to 10000
        """
        with self.cached_model:
            advi = pm.ADVI()
            approx = pm.fit(
                n=n,
                method=advi,
                more_replacements=minibatches,
                callbacks=[pm.callbacks.CheckParametersConvergence()],
            )

        self.advi_trace = approx.sample(draws=draws)
        self.advi_hist = advi.hist

    def fit(self):
        """
        Fit the model. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def predict(self):
        """
        Predict from the fitted model. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def score(self):
        """
        Score the fitted model. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def save(self, file_prefix, custom_params=None):
        """
        Saves the advi_trace and custom params to files with the given file_prefix.

        Parameters
        ----------
        file_prefix: str, path and prefix used to identify where to save the trace for this model.
            Ex: given file_prefix = "path/to/file/"
            This will attempt to save to "path/to/file/advi_trace.pickle"

        custom_params: Dictionary of custom parameters to save. Defaults to None,
            in which case no "params.pickle" file is written.
        """
        with open(file_prefix + 'advi_trace.pickle', 'wb') as trace_file:
            joblib.dump(self.advi_trace, trace_file)

        # Compare against None rather than truthiness: an empty dict is still
        # a valid set of custom parameters and must be written, otherwise a
        # later load(..., load_custom_params=True) would not find the file
        # (or would pick up a stale one).
        if custom_params is not None:
            with open(file_prefix + 'params.pickle', 'wb') as params_file:
                joblib.dump(custom_params, params_file)

    def load(self, file_prefix, load_custom_params=False):
        """
        Loads a saved version of the advi_trace and, optionally, the custom
        param file with the given file_prefix.

        Parameters
        ----------
        file_prefix: str, path and prefix used to identify where to load the saved trace for this model.
            Ex: given file_prefix = "path/to/file/"
            This will attempt to load "path/to/file/advi_trace.pickle"

        load_custom_params: Boolean flag to indicate whether custom parameters should be loaded
            (from "params.pickle" under the same prefix). Defaults to False.

        Returns
        ----------
        custom_params: Dictionary of custom parameters, or None when
            load_custom_params is False.
        """
        # NOTE(review): joblib.load unpickles the file contents; only load
        # files that come from a trusted source.
        self.advi_trace = joblib.load(file_prefix + 'advi_trace.pickle')

        custom_params = None
        if load_custom_params:
            custom_params = joblib.load(file_prefix + 'params.pickle')

        return custom_params

    def plot_elbo(self):
        """
        Plot the ELBO values after running ADVI minibatch.

        Plots the negated ``self.advi_hist`` against the iteration index, so
        ``self.advi_hist`` must have been populated first (it is set by
        ``_inference``); negating the initial ``None`` raises ``TypeError``.
        """
        sns.set_style("white")
        plt.plot(-self.advi_hist)
        plt.ylabel('ELBO')
        plt.xlabel('iteration')
        sns.despine()
119