Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/thinkbayes2
Path: blob/master/soln/utils.py
1901 views
1
import warnings
2
import pandas as pd
3
import numpy as np
4
import matplotlib.pyplot as plt
5
6
from empiricaldist import Pmf
7
8
from scipy.stats import gaussian_kde
9
from scipy.stats import binom
10
from scipy.stats import gamma
11
from scipy.stats import poisson
12
13
14
def values(series):
    """Tabulate the distinct values of a Series and how often each occurs.

    The result is a DataFrame rather than a Series because DataFrames
    render more nicely in Jupyter.

    series: Pandas Series

    returns: Pandas DataFrame with one column, 'counts', indexed by the
             distinct values (index name 'values'), sorted by value;
             missing values are counted too (dropna=False)
    """
    tally = series.value_counts(dropna=False).sort_index()
    labelled = tally.rename_axis('values').rename('counts')
    return labelled.to_frame()
27
28
29
def write_table(table, label, **options):
    """Write a table to tables/<label>.tex in LaTeX format.

    table: DataFrame (only its to_latex method is used)
    label: string used as the file name stem
    options: passed to DataFrame.to_latex

    The 'tables' directory must already exist; it is not created here.
    """
    # Render first, so a failure inside to_latex does not leave behind an
    # empty or truncated .tex file.
    latex = table.to_latex(**options)
    filename = f'tables/{label}.tex'
    # The context manager closes the file even if write() raises.
    with open(filename, 'w') as fp:
        fp.write(latex)
41
42
43
def write_pmf(pmf, label):
    """Write a Pmf as a two-column LaTeX table via write_table.

    pmf: Pmf
    label: string used as the table's file name stem

    The columns are 'qs' (the Pmf's index) and 'ps' (its values); the
    DataFrame's own row index is left out of the output (index=False).
    """
    table = pd.DataFrame({'qs': pmf.index, 'ps': pmf.values})
    write_table(table, label, index=False)
53
54
55
def underride(d, **options):
    """Fill in default values: copy each option into d unless d has that key.

    Existing entries in d always win over the supplied defaults.

    d: dictionary, updated in place
    options: keyword args giving the default values

    returns: d itself
    """
    for key in options:
        if key not in d:
            d[key] = options[key]
    return d
65
66
67
class SuppressWarning:
    """Context manager that silences matplotlib UserWarnings.

    On entry the current warnings-filter state is saved and an "ignore"
    filter for UserWarnings coming from the matplotlib module is installed.
    On exit the saved state is restored, so filters configured before
    entry are left intact. An exception raised in the with-body propagates.

    Note: warnings.catch_warnings changes process-global state and is not
    thread-safe.
    """

    def __enter__(self):
        # Save the global filter state so __exit__ can restore it exactly,
        # rather than overwriting it with a hard-coded "default" action.
        self._saved = warnings.catch_warnings()
        self._saved.__enter__()
        warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the filters saved in __enter__. Returning False lets any
        # exception from the with-body propagate.
        self._saved.__exit__(exc_type, exc_value, traceback)
        return False
73
74
75
def decorate(**options):
    """Apply axis properties to the current axes and tidy the layout.

    Example:

        decorate(title='Title', xlabel='x', ylabel='y')

    Every keyword argument is forwarded to Axes.set, so any settable axes
    property is accepted; see https://matplotlib.org/api/axes_api.html

    A legend is drawn only if get_legend_handles_labels finds any handles.
    tight_layout runs inside SuppressWarning so matplotlib UserWarnings
    it raises are not shown.
    """
    axes = plt.gca()
    axes.set(**options)

    handles, labels = axes.get_legend_handles_labels()
    if len(handles) > 0:
        axes.legend(handles, labels)

    with SuppressWarning():
        plt.tight_layout()
95
96
97
def savefig(root, **options):
    """Save the current figure as figs/<root>.<ext>.

    root: file name stem
    options: forwarded to plt.savefig. A 'format' option selects a single
             output format (used as the file extension); without it, the
             figure is written twice, as PDF and as PNG.
    """
    fmt = options.pop('format', None)
    extensions = [fmt] if fmt else ['pdf', 'png']

    for ext in extensions:
        plt.savefig(f'figs/{root}.{ext}', **options)
112
113
114
def make_die(sides):
    """Build a Pmf for a fair die.

    sides: int, number of faces; the faces are 1 through sides

    returns: Pmf giving every face probability 1/sides
    """
    faces = np.arange(1, sides + 1)
    return Pmf(1 / sides, faces)
124
125
126
def add_dist_seq(seq):
    """Distribution of the sum of the quantities described by several Pmfs.

    seq: non-empty, indexable sequence of Pmf objects (an empty one raises
         IndexError)

    returns: Pmf obtained by chaining add_dist from left to right
    """
    first, rest = seq[0], seq[1:]
    result = first
    for pmf in rest:
        result = result.add_dist(pmf)
    return result
137
138
139
def make_mixture(pmf, pmf_seq):
    """Combine conditional distributions into a weighted mixture.

    pmf: probability of each hypothesis, as a Pmf or a plain sequence of
         probabilities, in the same order as pmf_seq
    pmf_seq: sequence of Pmfs, each the conditional distribution of the
             quantity under one hypothesis

    returns: Pmf representing the mixture
    """
    # After the transpose there is one column per hypothesis and one row
    # per quantity; quantities absent from a component count as 0.
    table = pd.DataFrame(pmf_seq).fillna(0).T
    weights = np.array(pmf)
    weighted = table * weights
    return Pmf(weighted.sum(axis=1))
153
154
155
def summarize(posterior, digits=3, prob=0.9):
    """Print the mean and credible interval of a distribution.

    posterior: Pmf
    digits: number of decimal places to round the mean to
    prob: probability contained in the credible interval

    The credible interval is printed as returned, without rounding.
    """
    # Honor the digits argument (previously hard-coded to 3).
    mean = np.round(posterior.mean(), digits)
    ci = posterior.credible_interval(prob)
    print(mean, ci)
165
166
167
def outer_product(s1, s2):
    """Return the outer product of two Series as a DataFrame.

    Rows are labelled by s1's index and columns by s2's index, so
    result.loc[a, b] == s1[a] * s2[b].

    s1: Series
    s2: Series

    return: DataFrame
    """
    left = s1.to_numpy()
    right = s2.to_numpy()
    products = np.multiply.outer(left, right)
    return pd.DataFrame(products, columns=s2.index, index=s1.index)
180
181
182
def make_uniform(qs, name=None, **options):
    """Make a Pmf that represents a uniform distribution.

    qs: quantities
    name: string name for the quantities, stored as the index name
          (ignored if falsy)
    options: passed to Pmf

    returns: Pmf
    """
    pmf = Pmf(1.0, qs, **options)
    pmf.normalize()
    if name:
        # Install a renamed copy of the index instead of assigning
        # pmf.index.name. When qs is already a pandas Index, the Pmf may
        # share that object, and mutating it in place would also rename
        # the caller's index.
        pmf.index = pmf.index.rename(name)
    return pmf
196
197
198
def make_joint(s1, s2):
    """Compute the outer product of two Series, s1 across and s2 down.

    The columns are labelled by s1's index and the rows by s2's index, so
    result.loc[b, a] == s1[a] * s2[b].

    s1: Series
    s2: Series

    return: DataFrame
    """
    products = np.multiply.outer(np.asarray(s2), np.asarray(s1))
    return pd.DataFrame(products, index=s2.index, columns=s1.index)
211
212
213
def make_mesh(joint):
    """Build mesh grids from the labels of a joint distribution.

    joint: DataFrame representing a joint distribution

    returns: the pair (X, Y) from np.meshgrid, where X repeats the column
             labels down the rows and Y repeats the row labels across
             the columns
    """
    return np.meshgrid(joint.columns, joint.index)
224
225
226
def normalize(joint):
    """Normalize a joint distribution in place so its entries sum to 1.

    joint: DataFrame, modified in place

    returns: the total of all entries before normalization (no check for a
             zero total is made)
    """
    total = np.sum(joint.to_numpy())
    joint /= total
    return total
234
235
236
def marginal(joint, axis):
    """Compute a marginal distribution by summing a joint over one axis.

    axis=0 sums down the rows and gives the distribution of the first
    variable (the one labelling the columns); axis=1 sums across the
    columns and gives the distribution of the second variable (the one
    labelling the rows).

    joint: DataFrame representing a joint distribution
    axis: int axis to sum along

    returns: Pmf
    """
    totals = joint.sum(axis=axis)
    return Pmf(totals)
248
249
250
def pmf_marginal(joint_pmf, level):
    """Compute a marginal distribution from a Pmf with a MultiIndex.

    joint_pmf: Pmf representing a joint distribution, indexed by a MultiIndex
    level: int or name of the index level to keep; probabilities are summed
           over the other level(s)

    returns: Pmf
    """
    # Series.sum(level=...) was removed in pandas 2.0; groupby(level=...)
    # is the supported equivalent. sort=False keeps the level values in
    # first-seen order, matching the removed keyword's behavior.
    return Pmf(joint_pmf.groupby(level=level, sort=False).sum())
259
260
261
def plot_contour(joint, **options):
    """Draw a contour plot of a joint distribution on the current axes.

    joint: DataFrame representing a joint PMF
    options: passed to plt.contour; by default five levels evenly spaced
             between the minimum and maximum of joint (the minimum itself
             is excluded) and linewidths=1

    returns: the object returned by plt.contour
    """
    grid = joint.to_numpy()
    levels = np.linspace(grid.min(), grid.max(), 6)[1:]

    underride(options, levels=levels, linewidths=1)
    contours = plt.contour(joint.columns, joint.index, joint, **options)
    decorate(xlabel=joint.columns.name, ylabel=joint.index.name)
    return contours
276
277
278
def make_binomial(n, p):
    """Build the binomial distribution of the number of successes.

    n: number of trials
    p: probability of success on each trial

    returns: Pmf over k = 0, 1, ..., n
    """
    ks = np.arange(n + 1)
    return Pmf(binom.pmf(ks, n, p), ks)
289
290
291
def make_gamma_dist(alpha, beta):
    """Make a frozen SciPy gamma distribution from shape and rate.

    alpha: shape parameter
    beta: rate parameter (the inverse of the scale), so the mean is
          alpha / beta

    returns: frozen scipy.stats gamma object with the parameters also
             attached as the attributes alpha and beta
    """
    # SciPy parameterizes gamma by scale, so convert the rate.
    dist = gamma(alpha, scale=1/beta)
    dist.alpha = alpha
    dist.beta = beta
    return dist
303
304
305
def make_poisson_pmf(lam, qs):
    """Evaluate a Poisson distribution on a set of counts.

    lam: event rate
    qs: sequence of values for `k`

    returns: Pmf over qs, normalized with Pmf.normalize (any probability
             outside qs is dropped before normalizing)
    """
    probs = poisson.pmf(qs, lam)
    result = Pmf(probs, qs)
    result.normalize()
    return result
317
318
319
def pmf_from_dist(dist, qs):
    """Discretize a continuous distribution on a grid of quantities.

    dist: SciPy distribution object (must provide pdf)
    qs: quantities at which to evaluate the density

    returns: Pmf of the densities at qs, normalized with Pmf.normalize
    """
    densities = dist.pdf(qs)
    approx = Pmf(densities, qs)
    approx.normalize()
    return approx
331
332
333
def kde_from_sample(sample, qs, **options):
    """Estimate a density from a sample with a Gaussian KDE.

    sample: sequence of values
    qs: quantities where the KDE is evaluated
    options: passed to Pmf

    returns: Pmf of the KDE at qs, normalized with Pmf.normalize
    """
    density = gaussian_kde(sample)
    estimate = Pmf(density.evaluate(qs), qs, **options)
    estimate.normalize()
    return estimate
346
347
348
def kde_from_pmf(pmf, n=101, **options):
    """Smooth a Pmf with a weighted Gaussian KDE.

    The KDE uses the Pmf's quantities as data points, weighted by their
    probabilities, and is evaluated on n evenly spaced points from the
    smallest to the largest quantity.

    pmf: Pmf object
    n: number of evaluation points
    options: passed to the new Pmf

    returns: Pmf object, normalized with Pmf.normalize
    """
    # TODO: should this take qs rather than use min-max?
    density = gaussian_kde(pmf.qs, weights=pmf.ps)
    grid = np.linspace(pmf.qs.min(), pmf.qs.max(), n)
    smoothed = Pmf(density.evaluate(grid), grid, **options)
    smoothed.normalize()
    return smoothed
363
364
from statsmodels.nonparametric.smoothers_lowess import lowess
365
366
def make_lowess(series):
    """Fit a LOWESS curve to a Series (values against index).

    series: pd.Series whose index holds the x values and whose values hold
            the y values

    returns: pd.Series of fitted y values indexed by the x values that
             lowess returns
    """
    fitted = lowess(series.values, series.index.values)
    xs = fitted[:, 0]
    ys = fitted[:, 1]
    return pd.Series(ys, index=xs)
380
381
def plot_series_lowess(series, color):
    """Plot a Series as semi-transparent markers plus a LOWESS trend line.

    series: pd.Series
    color: string or tuple, used for both the markers and the line
    """
    # Markers only (line width 0), half transparent.
    series.plot(lw=0, marker='o', color=color, alpha=0.5)
    trend = make_lowess(series)
    # The label '_' keeps the trend line out of any legend.
    trend.plot(label='_', color=color)
390
391
from seaborn import JointGrid
392
393
def joint_plot(joint, **options):
    """Show a joint distribution as contours together with its marginals.

    joint: DataFrame that represents a joint distribution
    options: passed to JointGrid
    """
    # Axis names come from the column/index names, defaulting to 'x'/'y'
    # when a name is None.
    xname = joint.columns.name
    if xname is None:
        xname = 'x'
    yname = joint.index.name
    if yname is None:
        yname = 'y'

    # JointGrid needs a data frame to lay out its axes; a single dummy row
    # suffices because every panel is drawn by hand below.
    placeholder = pd.DataFrame({xname: [0], yname: [0]})
    grid = JointGrid(x=xname, y=yname, data=placeholder, **options)

    # Central panel: contours of the joint distribution.
    grid.ax_joint.contour(joint.columns, joint.index, joint, cmap='viridis')

    # Top panel: marginal of the column variable.
    px = marginal(joint, 0)
    grid.ax_marg_x.plot(px.qs, px.ps)

    # Side panel: marginal of the row variable, axes swapped so it runs
    # vertically.
    py = marginal(joint, 1)
    grid.ax_marg_y.plot(py.ps, py.qs)
422
423
424
# Color palette: RGBA tuples, all with alpha 0.7.
# Within each hue the numeric suffix runs from the darkest shade (20)
# to the lightest (80).

# Grays
Gray20 = (0.162, 0.162, 0.162, 0.7)
Gray30 = (0.262, 0.262, 0.262, 0.7)
Gray40 = (0.355, 0.355, 0.355, 0.7)
Gray50 = (0.44, 0.44, 0.44, 0.7)
Gray60 = (0.539, 0.539, 0.539, 0.7)
Gray70 = (0.643, 0.643, 0.643, 0.7)
Gray80 = (0.757, 0.757, 0.757, 0.7)

# Purples
Pu20 = (0.247, 0.0, 0.49, 0.7)
Pu30 = (0.327, 0.149, 0.559, 0.7)
Pu40 = (0.395, 0.278, 0.62, 0.7)
Pu50 = (0.46, 0.406, 0.685, 0.7)
Pu60 = (0.529, 0.517, 0.742, 0.7)
Pu70 = (0.636, 0.623, 0.795, 0.7)
Pu80 = (0.743, 0.747, 0.866, 0.7)

# Blues
Bl20 = (0.031, 0.188, 0.42, 0.7)
Bl30 = (0.031, 0.265, 0.534, 0.7)
Bl40 = (0.069, 0.365, 0.649, 0.7)
Bl50 = (0.159, 0.473, 0.725, 0.7)
Bl60 = (0.271, 0.581, 0.781, 0.7)
Bl70 = (0.417, 0.681, 0.838, 0.7)
Bl80 = (0.617, 0.791, 0.882, 0.7)

# Greens
Gr20 = (0.0, 0.267, 0.106, 0.7)
Gr30 = (0.0, 0.312, 0.125, 0.7)
Gr40 = (0.001, 0.428, 0.173, 0.7)
Gr50 = (0.112, 0.524, 0.253, 0.7)
Gr60 = (0.219, 0.633, 0.336, 0.7)
Gr70 = (0.376, 0.73, 0.424, 0.7)
Gr80 = (0.574, 0.824, 0.561, 0.7)

# Oranges
# NOTE(review): Or30 is identical to Or20, unlike every other hue;
# possibly a copy-paste slip. Confirm before relying on the distinction.
Or20 = (0.498, 0.153, 0.016, 0.7)
Or30 = (0.498, 0.153, 0.016, 0.7)
Or40 = (0.599, 0.192, 0.013, 0.7)
Or50 = (0.746, 0.245, 0.008, 0.7)
Or60 = (0.887, 0.332, 0.031, 0.7)
Or70 = (0.966, 0.475, 0.147, 0.7)
Or80 = (0.992, 0.661, 0.389, 0.7)

# Reds
Re20 = (0.404, 0.0, 0.051, 0.7)
Re30 = (0.495, 0.022, 0.063, 0.7)
Re40 = (0.662, 0.062, 0.085, 0.7)
Re50 = (0.806, 0.104, 0.118, 0.7)
Re60 = (0.939, 0.239, 0.178, 0.7)
Re70 = (0.985, 0.448, 0.322, 0.7)
Re80 = (0.988, 0.646, 0.532, 0.7)
466
467
from cycler import cycler
468
469
# Default color order for plots; set_pyplot_params installs color_cycle
# as matplotlib's axes property cycle.
color_list = [
    Bl30, Or70, Gr50, Re60, Pu20, Gray70, Re80, Gray50,
    Gr70, Bl50, Re40, Pu70, Or50, Gr30, Bl70, Pu50, Gray30,
]
color_cycle = cycler('color', color_list)
472
473
def set_pyplot_params():
    """Configure matplotlib defaults used by this module's plots.

    Installs color_cycle as the axes property cycle and sets the default
    line width to 3. figure.dpi is left at matplotlib's default; set
    plt.rcParams['figure.dpi'] (e.g. to 300) for high-resolution output.
    """
    plt.rcParams.update({
        'axes.prop_cycle': color_cycle,
        'lines.linewidth': 3,
    })
477
478