# Path: blob/master/deprecated/scripts/berksons_gaussian.py
# (1192 views)
# Illustrate explaining away (Berkson's paradox) for a Gaussian DAG x -> z <- y.
# Based on https://ff13.fastforwardlabs.com/
#
# x and y are sampled independently and z = x + y + noise is their common
# child (a collider). Marginally x and y are uncorrelated; after keeping only
# the samples with large z they appear dependent.

import numpy as np
from numpy.random import randn
import matplotlib.pyplot as plt
import pyprobml_utils as pml
import seaborn as sns

# Means of the two independent parent variables.
X_MEAN = -5
Y_MEAN = 5

# "Condition on" the collider by keeping only samples with z above this value.
Z_THRESHOLD = 2.5


def _std_normal(size=None):
    """Return one standard-normal draw (a float) if size is None, else an array of `size` draws."""
    return randn() if size is None else randn(size)


def x(size=None):
    """Sample the parent x ~ N(X_MEAN, 1): a float by default, an array of length `size` otherwise."""
    return X_MEAN + _std_normal(size)


def y(size=None):
    """Sample the parent y ~ N(Y_MEAN, 1): a float by default, an array of length `size` otherwise."""
    return Y_MEAN + _std_normal(size)


def z(x, y):
    """Sample the collider z ~ N(x + y, 1).

    Accepts scalar or array parents. One independent noise draw is added per
    element of the broadcast shape of x and y (a single float draw for scalars).
    """
    return x + y + randn(*np.broadcast(x, y).shape)


def sample():
    """Draw one joint sample (x, y, z) from the DAG as three floats."""
    x_ = x()
    y_ = y()
    z_ = z(x_, y_)
    return x_, y_, z_


def _plot_histograms(xs, ys, zs, num_bins, fname):
    """Overlay histograms of x, y and z (z with twice the bins), save to `fname`, and show."""
    plt.figure(figsize=(6, 6))
    plt.hist(xs, num_bins, facecolor='#7FDAD9', label='x')
    plt.hist(ys, num_bins, facecolor='#B5C5E2', label='y')
    plt.hist(zs, 2 * num_bins, facecolor='#FAC9AC', alpha=0.8, label='z')
    plt.ylabel('number of samples', fontsize=14)
    plt.xlabel('value', fontsize=14)
    plt.legend()
    pml.savefig(fname, dpi=300)
    plt.show()


def _finish_scatter(fname):
    """Label the current x-y scatter with fixed axis ranges, save to `fname`, and show."""
    plt.ylabel('y', fontsize=14)
    plt.xlabel('x', fontsize=14)
    plt.xlim([-10, 0])
    plt.ylim([0, 10])
    pml.savefig(fname, dpi=300)
    plt.show()


np.random.seed(0)

nsamples = 8000
num_bins = 20

# Vectorized draws through the same samplers that sample() uses.
xs = x(nsamples)
ys = y(nsamples)
zs = z(xs, ys)

# Marginal distributions and the joint of x and y before any selection.
_plot_histograms(xs, ys, zs, num_bins, 'berksons-hist.pdf')

plt.figure(figsize=(6, 6))
plt.scatter(xs, ys, color='#00B6B5', alpha=0.1)
_finish_scatter('berksons-scatter.pdf')

# Select on the collider. A boolean mask keeps the selections 1-D
# (np.argwhere would return an (n, 1) index array and 2-D selections).
selected = zs > Z_THRESHOLD
_plot_histograms(xs[selected], ys[selected], zs[selected], num_bins,
                 'berksons-conditioned-hist.pdf')

plt.figure(figsize=(6, 6))
sns.regplot(x=xs[selected], y=ys[selected], ci=None, color='tab:orange',
            scatter_kws={'alpha': 0.1, "color": '#00B6B5'})
_finish_scatter('berksons-conditioned-scatter.pdf')