| Hosted by CoCalc | Download
# Bayesian simple linear regression with PyStan: simulate data, fit the model,
# then plot the posterior median line and a sample of posterior draws.

# module import
import pystan
import numpy as np
import pylab as py
import pandas as pd

## data simulation
x = np.arange(1, 100, 5)
# one noise term per observation (derived from x rather than hard-coding 20)
y = 2.5 + .5 * x + np.random.randn(len(x)) * 10
# get number of observations
N = len(x)

# plot the data
py.plot(x, y, 'o')
py.show()

# STAN model (this is the most important part)
# NOTE(review): the `real y[N]` array syntax and the `<-` assignment are accepted
# by the Stan versions bundled with PyStan 2 but were removed in newer Stan
# releases (2.33+); port to `array[N] real` and `=` if moving to a newer Stan.
regress_code = """
data {
    int<lower = 0> N; // number of observations
    real y[N];        // response variable
    real x[N];        // predictor variable
}
parameters {
    real a;              // intercept
    real b;              // slope
    real<lower=0> sigma; // standard deviation
}
transformed parameters {
    real mu[N];          // fitted values

    for(i in 1:N)
        mu[i] <- a + b*x[i];
}
model {
    y ~ normal(mu, sigma);
}
"""

# make a dictionary containing all data to be passed to STAN
regress_dat = {'x': x, 'y': y, 'N': N}

# Fit the model.
# NOTE(review): in the recorded run this step failed while gcc compiled the
# model (exit status 4); presumably an environment/resource limit - confirm.
fit = pystan.stan(model_code=regress_code, data=regress_dat, iter=1000, chains=4)

# model summary (call form works on both Python 2 and Python 3)
print(fit)

# show a traceplot of ALL parameters. This is a bear if you have many
fit.traceplot()
py.show()

# Instead, show a traceplot for single parameter
fit.plot(['a'])
py.show()

##### PREDICTION ####

# extract() with permuted=True returns a dict keyed by parameter name, so the
# draw arrays must be pulled out by key before building the frame.
# Result: one row per post-warmup draw, columns 'a' and 'b'.
samples = fit.extract(pars=['a', 'b'], permuted=True)
params = pd.DataFrame({'a': samples['a'], 'b': samples['b']})

# make a prediction vector (the values of X for which you want to predict).
# Defined before stanPred because stanPred reads it as a global.
predX = np.arange(1, 100)


def stanPred(p):
    """Return the regression line a + b * predX for one set of parameters.

    p: a row of ``params`` (or any Series/mapping) holding 'a' (intercept)
       and 'b' (slope).
    Returns a one-element Series whose 'fitted' entry is the array of
    predictions at every value of the global ``predX``.
    """
    # look up by label, not position, so column order can never swap a and b
    fitted = p['a'] + p['b'] * predX
    return pd.Series({'fitted': fitted})


# get the median parameter estimates
medParam = params.median()

# predict
yhat = stanPred(medParam)

# get the predicted values for each posterior draw. This is convenient in
# pandas because a single column can hold one prediction array per row.
chainPreds = params.apply(stanPred, axis=1)

## PLOTTING
# pick up to 50 distinct draws at random. The population size comes from the
# data: a hard-coded 1999 never selected the last of the 2000 draws
# (4 chains x 500 post-warmup iterations) and breaks if iter/chains change.
n_lines = min(50, len(chainPreds))
idx = np.random.choice(len(chainPreds), n_lines, replace=False)

# plot each sampled draw.
# chainPreds.iloc[i, 0] gets predicted values from the ith set of parameter estimates
for i in idx:
    py.plot(predX, chainPreds.iloc[i, 0], color='lightgrey')
# original data
py.plot(x, y, 'ko')
# fitted values
py.plot(predX, yhat['fitted'], 'k')
# supplementals
py.xlabel('X')
py.ylabel('Y')
py.show()
[<matplotlib.lines.Line2D object at 0x7f2fb9f727d0>]
Error in lines 21-23 Traceback (most recent call last): File "/projects/4a5f0542-5873-4eed-a85c-a18c706e8bcd/.sagemathcloud/sage_server.py", line 865, in execute exec compile(block+'\n', '', 'single') in namespace, locals File "", line 2, in <module> File "/usr/local/sage/sage-6.4/local/lib/python2.7/site-packages/pystan/api.py", line 370, in stan save_dso=save_dso, verbose=verbose) File "/usr/local/sage/sage-6.4/local/lib/python2.7/site-packages/pystan/model.py", line 305, in __init__ build_extension.run() File "/usr/local/sage/sage-6.4/local/lib/python/distutils/command/build_ext.py", line 337, in run self.build_extensions() File "/usr/local/sage/sage-6.4/local/lib/python/distutils/command/build_ext.py", line 446, in build_extensions self.build_extension(ext) File "/usr/local/sage/sage-6.4/local/lib/python/distutils/command/build_ext.py", line 496, in build_extension depends=ext.depends) File "/usr/local/sage/sage-6.4/local/lib/python/distutils/ccompiler.py", line 574, in compile self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts) File "/usr/local/sage/sage-6.4/local/lib/python/distutils/unixccompiler.py", line 124, in _compile raise CompileError, msg CompileError: command 'gcc' failed with exit status 4
# make a function that collects, for every prediction point, the predicted
# values from every posterior draw and returns the requested percentile.
def quantileGet(q, preds=None):
    """Return the q-th percentile of the posterior predictions at each point.

    q:     percentile in [0, 100] (numpy convention, e.g. 2.5 for the 2.5% quantile).
    preds: DataFrame whose 'fitted' column holds one prediction array per
           posterior draw; defaults to the module-level ``chainPreds``.
    Returns a list with one value per prediction point.
    """
    if preds is None:
        preds = chainPreds
    # stack into a (draws, points) array so each column is one prediction point;
    # this replaces a Python double loop over points x draws with a single call
    stacked = np.vstack(list(preds['fitted']))
    return np.percentile(stacked, q, axis=0).tolist()

# NOTE THAT NUMPY DOES PERCENTILES, SO MULTIPLY QUANTILE BY 100
# 2.5% quantile
lower = quantileGet(2.5)
# 97.5% quantile
upper = quantileGet(97.5)

# plot this
fig = py.figure()
ax = fig.add_subplot(111)
# shade the central 95% posterior interval of the fitted line (stanPred builds the
# fitted values from a and b only, so observation noise sigma is not included)
ax.fill_between(predX, lower, upper, facecolor='lightgrey', edgecolor='none')
# plot the data
ax.plot(x, y, 'ko')
# plot the fitted line
ax.plot(predX, yhat['fitted'], 'k')
# supplementals
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.grid()
py.show()
Error in lines 17-18 Traceback (most recent call last): File "/projects/4a5f0542-5873-4eed-a85c-a18c706e8bcd/.sagemathcloud/sage_server.py", line 865, in execute exec compile(block+'\n', '', 'single') in namespace, locals File "", line 1, in <module> File "", line 5, in quantileGet NameError: global name 'predX' is not defined