Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/berksons_gaussian.py
1192 views
1
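# Illustrate explaining away (Berkson's paradox) for a Gaussian DAG x -> z <- y.
# x and y are sampled independently, but selecting samples on their common
# child z (here z > 2.5) makes x and y dependent within the selected subset.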
2
# Based on https://ff13.fastforwardlabs.com/
3
4
import numpy as np
5
from numpy.random import randn
6
import matplotlib.pyplot as plt
7
import pyprobml_utils as pml
8
import seaborn as sns
9
10
# Seed NumPy's global RNG for reproducibility. The same seed is applied again
# immediately before the dataset is drawn, so this call only matters if the
# samplers are used before that point.
np.random.seed(seed=0)
11
12
def x(size=None):
    """Sample the root node x ~ N(-5, 1).

    Args:
        size: None (default) for a single float draw, or an int/tuple giving
            the shape of an array of independent draws.
    """
    return -5 + np.random.standard_normal(size)


def y(size=None):
    """Sample the root node y ~ N(5, 1), independently of x.

    Args:
        size: None (default) for a single float draw, or an int/tuple giving
            the shape of an array of independent draws.
    """
    return 5 + np.random.standard_normal(size)


def z(x, y):
    """Sample the collider z ~ N(x + y, 1) given its parents x and y.

    Args:
        x, y: floats, or NumPy arrays of broadcast-compatible shape.

    Returns:
        A float for scalar inputs, otherwise an array shaped like ``x + y``.
        Each element gets its own independent unit-variance noise draw.
    """
    mean = x + y
    # For scalar inputs np.shape(mean) == (), so `or None` requests a single
    # float draw. That is the same draw the scalar `randn()` made. For arrays
    # it requests one draw per element, instead of one noise value shared by
    # all elements.
    return mean + np.random.standard_normal(np.shape(mean) or None)


def sample(size=None):
    """Draw ancestral samples (x, y, z) from the DAG x -> z <- y.

    Args:
        size: None (default) for one joint sample of floats, or an int/tuple
            for arrays of that many independent joint samples.

    Returns:
        Tuple ``(x, y, z)`` of floats or equally-shaped arrays.
    """
    x_ = x(size)
    y_ = y(size)
    z_ = z(x_, y_)
    return x_, y_, z_


# Reseed right before drawing so the dataset is reproducible.
np.random.seed(0)

nsamples = 8000
# Generate the dataset through the same samplers that define the model.
# The draw order (all x, then all y, then all z-noise) matches separate
# vectorized `randn(nsamples)` calls, so the values are identical.
xs, ys, zs = sample(nsamples)
33
34
35
# Marginal histograms of x, y and z before any conditioning.
# z is binned twice as finely as x and y.
num_bins = 20
plt.figure(figsize=(6, 6))
hist_specs = [
    (xs, num_bins, {'facecolor': '#7FDAD9', 'label': 'x'}),
    (ys, num_bins, {'facecolor': '#B5C5E2', 'label': 'y'}),
    (zs, 2 * num_bins, {'facecolor': '#FAC9AC', 'alpha': 0.8, 'label': 'z'}),
]
for values, bins, style in hist_specs:
    plt.hist(values, bins, **style)
plt.ylabel('number of samples', fontsize=14)
plt.xlabel('value', fontsize=14)
plt.legend()
pml.savefig('berksons-hist.pdf', dpi=300)
plt.show()
45
46
# Joint scatter of all (x, y) samples; x and y are drawn independently.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(xs, ys, color='#00B6B5', alpha=0.1)
ax.set_xlabel('x', fontsize=14)
ax.set_ylabel('y', fontsize=14)
ax.set_xlim(-10, 0)
ax.set_ylim(0, 10)
pml.savefig('berksons-scatter.pdf', dpi=300)
plt.show()
54
55
# Condition on the collider: keep only the samples whose z exceeds a
# threshold, then re-plot the histograms of that subset.
z_threshold = 2.5
# np.flatnonzero returns a 1-D index array, so xs[indices] etc. stay 1-D.
# np.argwhere returned shape (k, 1), which turned every selection into a
# (k, 1) column array.
indices = np.flatnonzero(zs > z_threshold)
num_bins = 20
plt.figure(figsize=(6, 6))
plt.hist(xs[indices], num_bins, facecolor='#7FDAD9', label='x')
plt.hist(ys[indices], num_bins, facecolor='#B5C5E2', label='y')
plt.hist(zs[indices], 2 * num_bins, facecolor='#FAC9AC', alpha=0.8, label='z')
plt.ylabel('number of samples', fontsize=14)
plt.xlabel('value', fontsize=14)
plt.legend()
pml.savefig('berksons-conditioned-hist.pdf', dpi=300)
plt.show()
67
68
# Scatter of (x, y) restricted to the conditioned subset `indices`, with a
# seaborn regression line overlaid (ci=None: no confidence band).
fig, ax = plt.subplots(figsize=(6, 6))
sns.regplot(
    x=xs[indices],
    y=ys[indices],
    ci=None,
    color='tab:orange',
    scatter_kws={'alpha': 0.1, 'color': '#00B6B5'},
    ax=ax,
)
ax.set_xlabel('x', fontsize=14)
ax.set_ylabel('y', fontsize=14)
ax.set_xlim(-10, 0)
ax.set_ylim(0, 10)
pml.savefig('berksons-conditioned-scatter.pdf', dpi=300)
plt.show()
76