Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/thinkbayes2
Path: blob/master/notebooks/utils.py
1901 views
1
import warnings
2
import pandas as pd
3
import numpy as np
4
import matplotlib.pyplot as plt
5
6
from empiricaldist import Pmf
7
8
from scipy.stats import gaussian_kde
9
from scipy.stats import binom
10
from scipy.stats import gamma
11
from scipy.stats import poisson
12
13
14
def values(series):
    """Tabulate the distinct values in a Series and how often each occurs.

    The result is a DataFrame rather than a Series because DataFrames
    render more nicely in Jupyter.

    series: Pandas Series

    returns: Pandas DataFrame with one column, 'counts', indexed by the
             distinct values (index name 'values') in sorted order;
             missing values are counted too (dropna=False)
    """
    counts = series.value_counts(dropna=False)
    counts = counts.sort_index()
    counts.index.name = 'values'
    counts.name = 'counts'
    return counts.to_frame()
27
28
29
def write_table(table, label, **options):
    """Write a table to `tables/{label}.tex` in LaTeX format.

    table: DataFrame (anything with a `to_latex` method)
    label: string used as the file name root
    options: passed to DataFrame.to_latex

    The `tables` directory must already exist.
    """
    filename = f'tables/{label}.tex'
    # Render first so a rendering error does not leave an empty file behind.
    s = table.to_latex(**options)
    # The context manager closes the file even if the write fails; the
    # explicit encoding keeps non-ASCII LaTeX portable across platforms.
    with open(filename, 'w', encoding='utf-8') as fp:
        fp.write(s)
41
42
43
def write_pmf(pmf, label):
    """Write the quantities and probabilities of a Pmf as a LaTeX table.

    The table has two columns, 'qs' (the Pmf's index) and 'ps' (its
    probabilities), and is written without a row index.

    pmf: Pmf
    label: string used as the file name root (passed to write_table)
    """
    table = pd.DataFrame({'qs': pmf.index, 'ps': pmf.values})
    write_table(table, label, index=False)
53
54
55
def underride(d, **options):
    """Fill in defaults: copy each keyword argument into d unless d already
    has that key.

    d: dictionary, modified in place
    options: keyword args supplying default values

    returns: d (the same object that was passed in)
    """
    missing = {key: val for key, val in options.items() if key not in d}
    d.update(missing)
    return d
65
66
67
class SuppressWarning:
    """Context manager that hides matplotlib UserWarnings inside its block.

    The warning filters in effect before entering are restored on exit,
    so any filters the caller had configured are left untouched.
    """

    def __enter__(self):
        # catch_warnings snapshots the current filter list and restores it
        # in __exit__, instead of overwriting it with a "default" filter.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._catcher.__exit__(exc_type, exc_value, traceback)
        # Never swallow exceptions raised inside the block.
        return False
73
74
75
def decorate(**options):
    """Apply axis properties to the current axes and tidy up the figure.

    Call with keyword arguments such as

        decorate(title='Title',
                 xlabel='x',
                 ylabel='y')

    Any axis property accepted by `Axes.set` may be passed; see
    https://matplotlib.org/api/axes_api.html

    A legend is drawn only when some artist on the axes has a label.
    """
    axes = plt.gca()
    axes.set(**options)

    legend_handles, legend_labels = axes.get_legend_handles_labels()
    if len(legend_handles) > 0:
        axes.legend(legend_handles, legend_labels)

    # Hide matplotlib UserWarnings that tight_layout may emit.
    with SuppressWarning():
        plt.tight_layout()
95
96
97
def savefig(root, **options):
    """Save the current figure in one or more formats.

    root: string filename root; files are written to figs/{root}.{ext}
    options: passed to plt.savefig. The special key `format` selects a
             single output format; without it, both PDF and PNG are written.

    The `figs` directory must already exist.
    """
    # Named `fmt`/`ext` rather than `format` to avoid shadowing the builtin.
    # The format is conveyed by the file extension, so it is not forwarded.
    fmt = options.pop('format', None)
    formats = [fmt] if fmt else ['pdf', 'png']

    for ext in formats:
        fname = f'figs/{root}.{ext}'
        plt.savefig(fname, **options)
112
113
114
def make_die(sides):
    """Make a Pmf for a fair die with faces numbered 1 through `sides`.

    sides: int number of faces

    returns: Pmf giving each face probability 1/sides
    """
    faces = np.arange(1, sides + 1)
    return Pmf(1 / sides, faces)
124
125
126
def add_dist_seq(seq):
    """Distribution of the sum of quantities drawn from several Pmfs.

    seq: non-empty, indexable sequence of Pmf objects

    returns: Pmf of the sum, built by chaining `add_dist` left to right
    """
    first, others = seq[0], seq[1:]
    result = first
    for pmf in others:
        result = result.add_dist(pmf)
    return result
137
138
139
def make_mixture(pmf, pmf_seq):
    """Make a mixture of distributions.

    pmf: mapping from each hypothesis to its probability
         (or a plain sequence of probabilities)
    pmf_seq: sequence of Pmfs, one conditional distribution per hypothesis

    returns: Pmf representing the mixture
    """
    # Rows: the union of quantities over all components (a quantity that is
    # absent from a component contributes 0). Columns: one per hypothesis.
    table = pd.DataFrame(pmf_seq).fillna(0).transpose()
    weights = np.array(pmf)
    weighted = table * weights
    return Pmf(weighted.sum(axis=1))
153
154
155
def summarize(posterior, digits=3, prob=0.9):
    """Print the mean and credible interval of a distribution.

    posterior: Pmf (anything with `mean` and `credible_interval` methods)
    digits: number of decimal places to round the mean to
    prob: probability mass covered by the credible interval
    """
    mean = np.round(posterior.mean(), digits)
    ci = posterior.credible_interval(prob)
    print(mean, ci)
165
166
167
def outer_product(s1, s2):
    """Compute the outer product of two Series as a DataFrame.

    The values of s1 label the rows and the values of s2 label the
    columns, so element [i, j] is s1[i] * s2[j].

    s1: Series
    s2: Series

    return: DataFrame
    """
    rows = s1.to_numpy()[:, np.newaxis]
    cols = s2.to_numpy()[np.newaxis, :]
    return pd.DataFrame(rows * cols, index=s1.index, columns=s2.index)
180
181
182
def make_uniform(qs, name=None, **options):
    """Make a Pmf that assigns equal probability to every quantity.

    qs: quantities
    name: optional string; if truthy, used as the name of the Pmf's index
    options: passed to Pmf

    returns: normalized Pmf
    """
    uniform = Pmf(1.0, qs, **options)
    uniform.normalize()
    if not name:
        return uniform
    uniform.index.name = name
    return uniform
196
197
198
def make_joint(s1, s2):
    """Compute the outer product of two Series as a joint table.

    The values of s1 run across the columns and the values of s2 run down
    the rows, so element [i, j] is s2[i] * s1[j].

    s1: Series
    s2: Series

    return: DataFrame
    """
    grid = np.multiply.outer(np.asarray(s2), np.asarray(s1))
    return pd.DataFrame(grid, columns=s1.index, index=s2.index)
211
212
213
def make_mesh(joint):
    """Build coordinate grids from the labels of a joint distribution.

    joint: DataFrame representing a joint distribution

    returns: mesh grid (X, Y) from np.meshgrid; X repeats the column labels
             along each row and Y repeats the row labels along each column
    """
    return np.meshgrid(joint.columns, joint.index)
224
225
226
def normalize(joint):
    """Scale a joint distribution in place so its elements sum to 1.

    joint: DataFrame, modified in place

    returns: the total before normalization (in a Bayesian update, the
             probability of the data)
    """
    total = joint.to_numpy().sum()
    joint /= total
    return total
234
235
236
def marginal(joint, axis):
    """Compute a marginal distribution by summing a joint table.

    axis=0 sums down the rows, giving the distribution over the column
    labels (the first variable); axis=1 sums across the columns, giving
    the distribution over the row labels (the second variable).

    joint: DataFrame representing a joint distribution
    axis: int axis to sum along

    returns: Pmf
    """
    totals = joint.sum(axis=axis)
    return Pmf(totals)
248
249
250
def conditional(joint, axis, value):
    """Compute a conditional distribution from a joint table.

    joint: DataFrame representing a joint distribution
    axis: int axis to condition on (0 selects a row, 1 selects a column)
    value: label to condition on (row label if axis=0, column label if
           axis=1)

    returns: Pmf, the selected slice rescaled to sum to 1
    """
    selected = joint.xs(value, axis=axis)
    total = selected.sum()
    return Pmf(selected / total)
262
263
264
def pmf_marginal(joint_pmf, level):
    """Compute a marginal distribution from a joint Pmf with a MultiIndex.

    joint_pmf: Pmf representing a joint distribution
    level: int, index level to keep; probabilities are summed over the
           other levels

    returns: Pmf
    """
    # Series.sum(level=...) was removed in pandas 2.0; grouping by the
    # level and summing is the supported equivalent.
    return Pmf(joint_pmf.groupby(level=level).sum())
273
274
275
def plot_contour(joint, **options):
    """Plot a joint distribution as contour lines.

    joint: DataFrame representing a joint PMF; the column labels supply the
           x coordinates and the row labels the y coordinates
    options: passed to plt.contour. By default `levels` is the five upper
             points of six evenly spaced values from the minimum to the
             maximum, and `linewidths` is 1.

    returns: the contour set returned by plt.contour
    """
    grid = joint.to_numpy()
    # Six evenly spaced values from min to max, dropping the minimum itself.
    levels = np.linspace(grid.min(), grid.max(), 6)[1:]

    underride(options, levels=levels, linewidths=1)
    contours = plt.contour(joint.columns, joint.index, joint, **options)
    decorate(xlabel=joint.columns.name, ylabel=joint.index.name)
    return contours
290
291
292
def make_binomial(n, p):
    """Make a binomial distribution over the number of successes.

    n: number of trials
    p: probability of success on each trial

    returns: Pmf mapping each k in 0..n to its binomial probability
    """
    successes = np.arange(n + 1)
    return Pmf(binom.pmf(successes, n, p), successes)
303
304
305
def make_gamma_dist(alpha, beta):
    """Make a frozen SciPy gamma distribution from shape and rate.

    alpha: shape parameter
    beta: rate parameter; SciPy is parameterized by scale, which is 1/beta

    returns: frozen gamma distribution, with the original parameters
             attached as `alpha` and `beta` attributes
    """
    dist = gamma(alpha, scale=1 / beta)
    # Keep the parameters on the object so callers can read them back.
    dist.alpha = alpha
    dist.beta = beta
    return dist
317
318
319
def make_poisson_pmf(lam, qs):
    """Make a Pmf of a Poisson distribution, restricted to `qs`.

    lam: event rate
    qs: sequence of values for `k`

    returns: Pmf, normalized over the given values of k
    """
    probs = poisson.pmf(qs, lam)
    result = Pmf(probs, qs)
    result.normalize()
    return result
331
332
333
def pmf_from_dist(dist, qs):
    """Make a discrete approximation of a continuous distribution.

    dist: SciPy distribution object (must provide `pdf`)
    qs: quantities at which to evaluate the density

    returns: Pmf whose probabilities are the densities, normalized
    """
    densities = dist.pdf(qs)
    approx = Pmf(densities, qs)
    approx.normalize()
    return approx
345
346
347
def kde_from_sample(sample, qs, **options):
    """Make a Gaussian kernel density estimate from a sample.

    sample: sequence of values
    qs: quantities where the KDE is evaluated
    options: passed to Pmf

    returns: normalized Pmf
    """
    estimator = gaussian_kde(sample)
    density = estimator.evaluate(qs)
    result = Pmf(density, qs, **options)
    result.normalize()
    return result
360
361
362
def kde_from_pmf(pmf, n=101, qs=None, **options):
    """Make a Gaussian kernel density estimate from a Pmf.

    pmf: Pmf object; its quantities are the KDE sample points and its
         probabilities are the weights
    n: number of equally spaced points from the smallest to the largest
       quantity, used only when `qs` is not given
    qs: optional sequence of quantities where the KDE is evaluated;
        when given, `n` is ignored
    options: passed to Pmf

    returns: normalized Pmf
    """
    kde = gaussian_kde(pmf.qs, weights=pmf.ps)
    if qs is None:
        qs = np.linspace(pmf.qs.min(), pmf.qs.max(), n)
    ps = kde.evaluate(qs)
    # Use a new name rather than rebinding the `pmf` argument.
    smooth = Pmf(ps, qs, **options)
    smooth.normalize()
    return smooth
377
378
from statsmodels.nonparametric.smoothers_lowess import lowess
379
380
def make_lowess(series):
    """Fit a LOWESS curve to a Series.

    series: pd.Series whose index supplies the x values and whose values
            supply the y values

    returns: pd.Series of smoothed y values, indexed by x as returned by
             lowess
    """
    fitted = lowess(series.values, series.index.values)
    xs = fitted[:, 0]
    ys = fitted[:, 1]
    return pd.Series(ys, index=xs)
394
395
def plot_series_lowess(series, color):
    """Plot a series as markers, overlaid with its LOWESS curve.

    series: pd.Series
    color: string or tuple
    """
    # Raw data: semi-transparent markers with no connecting line.
    series.plot(lw=0, marker='o', color=color, alpha=0.5)
    # The label '_' keeps the smoothed line out of the legend.
    make_lowess(series).plot(label='_', color=color)
404
405
from seaborn import JointGrid
406
407
def joint_plot(joint, **options):
    """Show a joint distribution as contours, with its marginals on the sides.

    joint: DataFrame that represents a joint distribution; column labels are
           the x values and row labels are the y values
    options: passed to JointGrid
    """
    # Parameter names for the axes, falling back to 'x'/'y' when unnamed.
    x_name = 'x' if joint.columns.name is None else joint.columns.name
    y_name = 'y' if joint.index.name is None else joint.index.name

    # JointGrid is set up from a DataFrame; a single dummy row is enough,
    # since the panels are drawn explicitly below.
    placeholder = pd.DataFrame({x_name: [0], y_name: [0]})
    grid = JointGrid(x=x_name, y=y_name, data=placeholder, **options)

    grid.ax_joint.contour(joint.columns, joint.index, joint, cmap='viridis')

    # Top panel: distribution over x. Right panel: distribution over y,
    # drawn with probability on the horizontal axis.
    px = marginal(joint, 0)
    grid.ax_marg_x.plot(px.qs, px.ps)

    py = marginal(joint, 1)
    grid.ax_marg_y.plot(py.ps, py.qs)
436
437
438
# Palette of RGBA colors, all with alpha 0.7. Each name is a family prefix
# (Gray, Pu, Bl, Gr, Or, Re) plus a shade number; lower numbers are darker.
Gray20, Gray30, Gray40, Gray50, Gray60, Gray70, Gray80 = (
    (0.162, 0.162, 0.162, 0.7),  # Gray20
    (0.262, 0.262, 0.262, 0.7),  # Gray30
    (0.355, 0.355, 0.355, 0.7),  # Gray40
    (0.44, 0.44, 0.44, 0.7),  # Gray50
    (0.539, 0.539, 0.539, 0.7),  # Gray60
    (0.643, 0.643, 0.643, 0.7),  # Gray70
    (0.757, 0.757, 0.757, 0.7),  # Gray80
)
Pu20, Pu30, Pu40, Pu50, Pu60, Pu70, Pu80 = (
    (0.247, 0.0, 0.49, 0.7),  # Pu20
    (0.327, 0.149, 0.559, 0.7),  # Pu30
    (0.395, 0.278, 0.62, 0.7),  # Pu40
    (0.46, 0.406, 0.685, 0.7),  # Pu50
    (0.529, 0.517, 0.742, 0.7),  # Pu60
    (0.636, 0.623, 0.795, 0.7),  # Pu70
    (0.743, 0.747, 0.866, 0.7),  # Pu80
)
Bl20, Bl30, Bl40, Bl50, Bl60, Bl70, Bl80 = (
    (0.031, 0.188, 0.42, 0.7),  # Bl20
    (0.031, 0.265, 0.534, 0.7),  # Bl30
    (0.069, 0.365, 0.649, 0.7),  # Bl40
    (0.159, 0.473, 0.725, 0.7),  # Bl50
    (0.271, 0.581, 0.781, 0.7),  # Bl60
    (0.417, 0.681, 0.838, 0.7),  # Bl70
    (0.617, 0.791, 0.882, 0.7),  # Bl80
)
Gr20, Gr30, Gr40, Gr50, Gr60, Gr70, Gr80 = (
    (0.0, 0.267, 0.106, 0.7),  # Gr20
    (0.0, 0.312, 0.125, 0.7),  # Gr30
    (0.001, 0.428, 0.173, 0.7),  # Gr40
    (0.112, 0.524, 0.253, 0.7),  # Gr50
    (0.219, 0.633, 0.336, 0.7),  # Gr60
    (0.376, 0.73, 0.424, 0.7),  # Gr70
    (0.574, 0.824, 0.561, 0.7),  # Gr80
)
# NOTE(review): Or30 is identical to Or20, unlike every other family where
# each shade differs — possibly a copy-paste slip; confirm the intended value.
Or20, Or30, Or40, Or50, Or60, Or70, Or80 = (
    (0.498, 0.153, 0.016, 0.7),  # Or20
    (0.498, 0.153, 0.016, 0.7),  # Or30
    (0.599, 0.192, 0.013, 0.7),  # Or40
    (0.746, 0.245, 0.008, 0.7),  # Or50
    (0.887, 0.332, 0.031, 0.7),  # Or60
    (0.966, 0.475, 0.147, 0.7),  # Or70
    (0.992, 0.661, 0.389, 0.7),  # Or80
)
Re20, Re30, Re40, Re50, Re60, Re70, Re80 = (
    (0.404, 0.0, 0.051, 0.7),  # Re20
    (0.495, 0.022, 0.063, 0.7),  # Re30
    (0.662, 0.062, 0.085, 0.7),  # Re40
    (0.806, 0.104, 0.118, 0.7),  # Re50
    (0.939, 0.239, 0.178, 0.7),  # Re60
    (0.985, 0.448, 0.322, 0.7),  # Re70
    (0.988, 0.646, 0.532, 0.7),  # Re80
)
480
481
from cycler import cycler
482
483
# Order in which colors are assigned to successive plotted lines; wrapped
# in a cycler so it can be installed as matplotlib's axes.prop_cycle.
color_list = [
    Bl30, Or70, Gr50, Re60, Pu20, Gray70, Re80, Gray50,
    Gr70, Bl50, Re40, Pu70, Or50, Gr30, Bl70, Pu50, Gray30,
]
color_cycle = cycler('color', color_list)
486
487
def set_pyplot_params(dpi=75):
488
"""Set the parameters for matplotlib.
489
490
dpi: int, default 75
491
"""
492
plt.rcParams['figure.figsize'] = (6, 4)
493
plt.rcParams['figure.dpi'] = dpi
494
plt.rcParams['axes.prop_cycle'] = color_cycle
495
496
# no spines
497
plt.rcParams['axes.spines.left'] = False
498
plt.rcParams['axes.spines.right'] = False
499
plt.rcParams['axes.spines.top'] = False
500
plt.rcParams['axes.spines.bottom'] = False
501