# Path: blob/master/models/__init__.py
# 411 views
"""
Author: Nicole Carson

Copied from https://raw.githubusercontent.com/parsing-science/ps-toolkit/master/ps_toolkit/pymc3_models/__init__.py.
"""

import joblib
import matplotlib.pyplot as plt
import pymc3 as pm
import seaborn as sns
from sklearn.base import BaseEstimator


class BayesianModel(BaseEstimator):
    """
    Bayesian model base class.

    Holds the state shared by concrete models (the PyMC3 model, its Theano
    shared variables and the ADVI results) and provides ADVI inference,
    persistence of the trace, and an ELBO plot.

    Subclasses must implement ``create_model``, ``fit``, ``predict`` and
    ``score``; the versions defined here raise ``NotImplementedError``.
    """

    def __init__(self):
        self.advi_hist = None  # The ADVI history (set by _inference from advi.hist).
        self.advi_trace = None  # The ADVI trace (set by _inference or load).
        self.cached_model = None  # The PyMC3 model object.
        self.num_pred = None  # The number of predictor variables (number of columns in X).
        self.shared_vars = None  # A dictionary of Theano shared variables.

    def create_model(self):
        """Build the PyMC3 model. Must be implemented by subclasses."""
        raise NotImplementedError

    def _set_shared_vars(self, shared_vars):
        """
        Sets theano shared variables for the PyMC3 model.

        Parameters
        ----------
        shared_vars: dict mapping a key of ``self.shared_vars`` to the new value
        passed to that shared variable's ``set_value``.

        Raises
        ----------
        KeyError: if a key of ``shared_vars`` is not present in ``self.shared_vars``.
        """
        for key, value in shared_vars.items():
            self.shared_vars[key].set_value(value)

    def _inference(self, minibatches, n=200000, draws=10000):
        """
        Runs minibatch variational ADVI and then samples from those results.

        Stores the sampled trace in ``self.advi_trace`` and the ADVI history
        in ``self.advi_hist``.

        Parameters
        ----------
        minibatches: minibatches for ADVI; forwarded to ``pm.fit`` as ``more_replacements``

        n: number of iterations for ADVI fit, defaults to 200000

        draws: number of samples drawn from the fitted approximation, defaults to 10000
        """
        # self.cached_model must already hold the PyMC3 model; it is used as
        # the model context for the fit.
        with self.cached_model:
            advi = pm.ADVI()
            approx = pm.fit(
                n=n,
                method=advi,
                more_replacements=minibatches,
                callbacks=[pm.callbacks.CheckParametersConvergence()]
            )

        self.advi_trace = approx.sample(draws=draws)

        self.advi_hist = advi.hist

    def fit(self):
        """Fit the model. Must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self):
        """Predict with the fitted model. Must be implemented by subclasses."""
        raise NotImplementedError

    def score(self):
        """Score the fitted model. Must be implemented by subclasses."""
        raise NotImplementedError

    def save(self, file_prefix, custom_params=None):
        """
        Saves the advi_trace and custom params to files with the given file_prefix.

        Parameters
        ----------
        file_prefix: str, path and prefix used to identify where to save the trace for this model.
        Ex: given file_prefix = "path/to/file/"
        This will attempt to save to "path/to/file/advi_trace.pickle"

        custom_params: Dictionary of custom parameters to save. Defaults to None.
        When given and non-empty it is written to file_prefix + "params.pickle";
        None or an empty dictionary writes no params file.
        """
        with open(file_prefix + 'advi_trace.pickle', 'wb') as trace_file:
            joblib.dump(self.advi_trace, trace_file)

        if custom_params:
            with open(file_prefix + 'params.pickle', 'wb') as params_file:
                joblib.dump(custom_params, params_file)

    def load(self, file_prefix, load_custom_params=False):
        """
        Loads a saved version of the advi_trace and, optionally, the custom param file
        with the given file_prefix.

        Parameters
        ----------
        file_prefix: str, path and prefix used to identify where to load the saved trace for this model.
        Ex: given file_prefix = "path/to/file/"
        This will attempt to load "path/to/file/advi_trace.pickle"

        load_custom_params: Boolean flag to indicate whether custom parameters should be loaded
        from file_prefix + "params.pickle". Defaults to False.

        Returns
        ----------
        custom_params: Dictionary of custom parameters, or None when load_custom_params is False
        """
        # NOTE(security): joblib.load unpickles the file contents, which can
        # execute arbitrary code. Only load files from a trusted source.
        self.advi_trace = joblib.load(file_prefix + 'advi_trace.pickle')

        custom_params = None
        if load_custom_params:
            custom_params = joblib.load(file_prefix + 'params.pickle')

        return custom_params

    def plot_elbo(self):
        """
        Plot the ELBO values after running ADVI minibatch.

        Requires ``self.advi_hist`` to be populated (``_inference`` sets it);
        it is ``None`` on a freshly constructed instance.
        """
        sns.set_style("white")
        # The history is negated before plotting; presumably advi.hist records
        # the loss (negative ELBO) -- TODO confirm against the PyMC3 docs.
        plt.plot(-self.advi_hist)
        plt.ylabel('ELBO')
        plt.xlabel('iteration')
        sns.despine()