Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/03/simpsons_paradox.ipynb
1192 views
Kernel: Python [conda env:probml_py3912]

Simpson's Paradox

import jax import jax.numpy as jnp import seaborn as sns import matplotlib.pyplot as plt try: from probml_utils import latexify, savefig, is_latexify_enabled except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils import latexify, savefig, is_latexify_enabled from scipy import stats import itertools
# Configure figure styling for the book via probml_utils.latexify.
# NOTE(review): width_scale_factor=1 presumably means "full text width" — confirm in probml_utils.
# As the captured warning below shows, this is a no-op (with a UserWarning) unless the
# LATEXIFY environment variable is set; make_graph also only saves PDFs in that mode.
latexify(width_scale_factor=1)
/home/patel_zeel/miniconda3/envs/probml_py3912/lib/python3.9/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
def make_graph(df, x_col, y_col, groupby_col, save_name, fig=None):
    """Plot a pooled regression next to per-group regressions (Simpson's paradox).

    The left panel fits a single regression line to every row of ``df``. The
    right panel fits a separate line for each level of ``groupby_col``. Each
    legend entry reports the Pearson correlation ``R`` (``linregress(...).rvalue``)
    of the data behind that line. A pooled ``R`` whose sign differs from the
    per-group ``R`` values illustrates the paradox.

    Args:
        df: DataFrame containing ``x_col``, ``y_col`` and ``groupby_col``.
        x_col: Column plotted on the x axis.
        y_col: Column plotted on the y axis.
        groupby_col: Column whose levels define the groups. It is also used in
            the pooled legend label and as the legend title.
        save_name: File name forwarded to ``savefig``. Saving only happens when
            latexify is enabled and ``save_name`` is non-empty (``None`` is
            treated as empty).
        fig: Optional figure to draw into. If it has no axes, a 1x2 grid with a
            shared y axis is created inside it. If it already has at least two
            axes, the first two are used. When omitted, a new figure is created.

    Returns:
        ``(fig, ax)``, where ``ax[0]`` is the pooled panel and ``ax[1]`` is the
        per-group panel.

    Raises:
        ValueError: ``fig`` was given and contains exactly one axes.
    """
    if fig is None:
        fig, ax = plt.subplots(ncols=2, sharey=True)
    elif not fig.axes:
        # Empty figure supplied: build the same 1x2 layout inside it.
        ax = fig.subplots(ncols=2, sharey=True)
    elif len(fig.axes) >= 2:
        # Previously `ax` was never bound on this path (NameError).
        ax = fig.axes[:2]
    else:
        raise ValueError("fig must be empty or already contain at least two axes")

    palette = itertools.cycle(sns.color_palette())
    scatter_kws = {"s": 10, "alpha": 0.7}

    # Only the plotted columns matter. Dropping NaN there keeps linregress from
    # silently returning NaN, and regplot drops those rows anyway.
    plot_df = df.dropna(subset=[x_col, y_col])

    pooled_r = stats.linregress(plot_df[x_col], plot_df[y_col]).rvalue
    sns.regplot(
        x=x_col,
        y=y_col,
        data=plot_df,
        line_kws={"label": f"All {groupby_col} \n R = {pooled_r:0.2f}"},
        ci=1,  # 1% confidence band, i.e. essentially just the fitted line
        ax=ax[0],
        color=next(palette),
        scatter_kws=scatter_kws,
    )

    # Skip a NaN group level: its equality mask would select no rows.
    for group in plot_df[groupby_col].dropna().unique():
        subset_data = plot_df[plot_df[groupby_col] == group]
        group_r = stats.linregress(subset_data[x_col], subset_data[y_col]).rvalue
        sns.regplot(
            x=x_col,
            y=y_col,
            data=subset_data,
            line_kws={"label": f"{group} \n R = {group_r:0.2f}"},
            ci=1,
            ax=ax[1],
            color=next(palette),
            scatter_kws=scatter_kws,
        )

    # A single figure-level legend, placed outside the axes on the right.
    legend = fig.legend(
        title=groupby_col,
        loc="upper center",
        bbox_to_anchor=(1.15, 1.05),
        ncol=1,
        fancybox=True,
        shadow=False,
    )
    ax[1].set_ylabel("")  # y axis is shared with the left panel
    fig.tight_layout()
    sns.despine(fig=fig)  # target this figure, not whichever one is current
    if is_latexify_enabled() and save_name:
        # bbox_extra_artists keeps the outside legend within the saved bounds.
        # NOTE(review): probml_utils.savefig presumably saves the *current*
        # figure; confirm it matches `fig` when a figure is passed in.
        savefig(save_name, bbox_extra_artists=(legend,), bbox_inches="tight")
    return fig, ax
# Raw seaborn column names -> display labels used on the plot axes.
column_mapping = {
    "penguins": {
        "species": "Species",
        "bill_length_mm": "Bill Length (mm)",
        "bill_depth_mm": "Bill Depth (mm)",
    },
    "iris": {
        "species": "Species",
        "sepal_length": "Sepal Length",
        "sepal_width": "Sepal Width",
    },
}

# Per dataset: which (already renamed) columns to plot and to group by.
dataset_cols = {
    "penguins": dict(
        x_col="Bill Length (mm)",
        y_col="Bill Depth (mm)",
        groupby_col="Species",
    ),
    "iris": dict(
        x_col="Sepal Length",
        y_col="Sepal Width",
        groupby_col="Species",
    ),
}

# Draw the pooled-vs-per-species comparison for each dataset. Rows with any
# missing value are dropped before plotting.
for dataset in column_mapping:
    df = sns.load_dataset(dataset).rename(columns=column_mapping[dataset]).dropna()
    make_graph(df, **dataset_cols[dataset], save_name=f"simpson_{dataset}.pdf")
Image in a Jupyter notebook
Image in a Jupyter notebook