Linear regression for predicting height from weight
We illustrate priors for linear and polynomial regression using the example in Section 4.4 of Statistical Rethinking (2nd edition). The NumPyro code is from Du Phan's site.
In [1]:
Out[1]:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax version 0.3.13
jax backend cpu
/home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying
warnings.warn("LATEXIFY environment variable not set, not latexifying")
Data
We use the "Howell" dataset, which contains measurements of height, weight, age, and sex for members of a foraging population (the Dobe area !Kung San), collected by Nancy Howell.
In [2]:
Out[2]:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 544 entries, 0 to 543
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   height  544 non-null    float64
 1   weight  544 non-null    float64
 2   age     544 non-null    float64
 3   male    544 non-null    int64
dtypes: float64(3), int64(1)
memory usage: 17.1 KB
In [3]:
Out[3]:
Prior predictive distribution
In [4]:
Gaussian prior
In [5]:
Out[5]:
/home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures
warnings.warn("set FIG_DIR environment variable to save figures")
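A prior predictive check under a Gaussian prior can be sketched as follows. This is a minimal NumPy sketch, not the notebook's NumPyro code: the priors a ~ Normal(178, 20) and b ~ Normal(0, 10), the weight grid, and the centering value `xbar` are assumptions based on the textbook example and do not appear on this page.

```python
import numpy as np

def sample_prior_lines_gaussian(n_lines=100, seed=0):
    """Draw (a, b) from assumed Gaussian priors: a ~ N(178, 20), b ~ N(0, 10)."""
    rng = np.random.default_rng(seed)
    a = rng.normal(178.0, 20.0, size=n_lines)  # intercept: height (cm) at mean weight
    b = rng.normal(0.0, 10.0, size=n_lines)    # slope (cm per kg), sign unconstrained
    return a, b

weights = np.linspace(30.0, 65.0, 50)  # hypothetical weight grid (kg)
xbar = 45.0                            # hypothetical mean weight (kg)

a, b = sample_prior_lines_gaussian()
# One row per sampled line: mu[i, j] = a[i] + b[i] * (weights[j] - xbar)
mu = a[:, None] + b[:, None] * (weights[None, :] - xbar)

frac_negative = np.mean(b < 0)
print(mu.shape, f"fraction of negative slopes: {frac_negative:.2f}")
```

Because this prior is symmetric around zero, roughly half of the sampled lines predict that height decreases with weight, which is the weakness the log-Gaussian prior addresses.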
Log-Gaussian prior
In [6]:
Out[6]:
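The same check with a log-Gaussian prior on the slope might look like this. Again this is a hedged NumPy sketch: the LogNormal(0, 1) prior on b, the Normal(178, 20) prior on a, and the weight grid are assumptions taken from the textbook example, not values shown on this page.

```python
import numpy as np

rng = np.random.default_rng(0)
n_lines = 100

a = rng.normal(178.0, 20.0, size=n_lines)             # assumed prior on the intercept
b = rng.lognormal(mean=0.0, sigma=1.0, size=n_lines)  # assumed LogNormal(0, 1) prior on the slope

weights = np.linspace(30.0, 65.0, 50)  # hypothetical weight grid (kg)
xbar = 45.0                            # hypothetical mean weight (kg)
mu = a[:, None] + b[:, None] * (weights[None, :] - xbar)

# A log-normal draw is exp(normal draw), so every slope is strictly positive
# and every sampled line increases with weight.
print("all slopes positive:", bool(np.all(b > 0)))
print("median slope:", float(np.median(b)))
```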
Posterior
We use the log-Gaussian prior on the slope, which keeps it positive, and compute a Laplace (Gaussian) approximation to the posterior.
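To make the idea concrete, here is a minimal NumPy sketch of a Laplace approximation: find the posterior mode with Newton's method, then use the inverse Hessian of the negative log posterior at the mode as the Gaussian covariance. It is not the notebook's NumPyro code. The priors (a ~ Normal(178, 20), b ~ LogNormal(0, 1)), the fixed noise scale `sigma`, and the synthetic data are all assumptions made to keep the example short.

```python
import numpy as np

def grad_and_hessian(theta, x, h, sigma):
    """Gradient and Hessian of the negative log posterior U(a, b)."""
    a, b = theta
    r = h - a - b * x  # residuals
    g = np.array([
        -r.sum() / sigma**2 + (a - 178.0) / 20.0**2,        # dU/da
        -(r * x).sum() / sigma**2 + (1.0 + np.log(b)) / b,  # dU/db
    ])
    H = np.array([
        [len(x) / sigma**2 + 1.0 / 20.0**2, x.sum() / sigma**2],
        [x.sum() / sigma**2, (x**2).sum() / sigma**2 - np.log(b) / b**2],
    ])
    return g, H

def laplace_fit(x, h, sigma=5.0, n_iter=20):
    """Return the posterior mode and the Gaussian covariance at the mode."""
    theta = np.array([h.mean(), 1.0])  # start with a positive slope
    for _ in range(n_iter):
        g, H = grad_and_hessian(theta, x, h, sigma)
        theta = theta - np.linalg.solve(H, g)  # Newton step
    _, H = grad_and_hessian(theta, x, h, sigma)
    return theta, np.linalg.inv(H)

# Synthetic data (illustrative only, not the Howell measurements)
rng = np.random.default_rng(0)
w = rng.uniform(30.0, 60.0, size=200)
x = w - w.mean()  # centered weight
h = 154.0 + 0.9 * x + rng.normal(0.0, 5.0, size=200)

theta_map, cov = laplace_fit(x, h)
samples = rng.multivariate_normal(theta_map, cov, size=1000)  # approximate posterior draws
print("mode:", theta_map, "std:", np.sqrt(np.diag(cov)))
```

A library implementation would typically also infer sigma and work in an unconstrained parameter space; this sketch only shows the mode-plus-curvature construction.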
In [7]:
Out[7]:
100%|██████████| 2000/2000 [00:01<00:00, 1498.75it/s, init loss: 40631.5391, avg. loss [1901-2000]: 1078.9297]
In [8]:
Out[8]:
{'a': [DeviceArray(154.366, dtype=float32),
DeviceArray(154.785, dtype=float32),
DeviceArray(154.735, dtype=float32),
DeviceArray(154.538, dtype=float32),
DeviceArray(154.535, dtype=float32)],
'b': [DeviceArray(0.975, dtype=float32),
DeviceArray(0.89, dtype=float32),
DeviceArray(0.819, dtype=float32),
DeviceArray(0.833, dtype=float32),
DeviceArray(1.012, dtype=float32)],
'mu': [DeviceArray(157.129, dtype=float32),
DeviceArray(146.077, dtype=float32),
DeviceArray(141.573, dtype=float32),
DeviceArray(162.213, dtype=float32),
DeviceArray(150.747, dtype=float32)],
'sigma': [DeviceArray(4.976, dtype=float32),
DeviceArray(4.944, dtype=float32),
DeviceArray(5.283, dtype=float32),
DeviceArray(4.878, dtype=float32),
DeviceArray(4.895, dtype=float32)]}
Posterior predictive
In [9]:
Out[9]:
(1000, 46)
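The (1000, 46) shape above is consistent with 1000 posterior draws evaluated on a grid of 46 weights. A minimal NumPy sketch of that computation follows. The posterior draws here are illustrative stand-ins centered near the sample values printed earlier, and the weight grid and centering value `xbar` are assumptions, not values read from the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 1000

# Illustrative stand-ins for posterior draws; a real run would use the
# samples from the fitted approximation instead.
a = rng.normal(154.6, 0.3, size=n_samples)
b = rng.normal(0.9, 0.04, size=n_samples)
sigma = np.abs(rng.normal(5.0, 0.2, size=n_samples))

weight_grid = np.arange(25, 71)  # 46 integer weights (kg), an assumed grid
xbar = 45.0                      # hypothetical mean weight used for centering

# Posterior of the mean height at each weight: shape (n_samples, 46)
mu = a[:, None] + b[:, None] * (weight_grid[None, :] - xbar)
# Posterior predictive heights add observation noise to each mean
h_sim = rng.normal(mu, sigma[:, None])

# Equal-tailed 95% bands for the mean and for new observations
mu_lo, mu_hi = np.percentile(mu, [2.5, 97.5], axis=0)
h_lo, h_hi = np.percentile(h_sim, [2.5, 97.5], axis=0)
print(mu.shape, h_sim.shape)
```

The band for simulated heights is wider than the band for the mean because it includes the observation noise sigma as well as the uncertainty in a and b.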
Polynomial regression
We now consider the full dataset, including children. With children included, the mapping from weight to height is nonlinear.
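One common way to capture such curvature is to regress height on powers of the standardized weight. The sketch below builds those features and compares a linear and a quadratic fit. It uses ordinary least squares on synthetic data as a simple stand-in for the Bayesian fits in this notebook; the function name `poly_features` and the data-generating values are hypothetical.

```python
import numpy as np

def poly_features(weight, degree):
    """Standardize weight, then stack columns [1, z, z**2, ..., z**degree]."""
    z = (weight - weight.mean()) / weight.std()
    return np.stack([z**k for k in range(degree + 1)], axis=1)

# Synthetic curved data (illustrative only, not the Howell measurements)
rng = np.random.default_rng(0)
weight = rng.uniform(5.0, 60.0, size=300)
z_true = (weight - weight.mean()) / weight.std()
height = 146.0 + 21.0 * z_true - 8.0 * z_true**2 + rng.normal(0.0, 5.0, size=300)

resid_sd = {}
for degree in (1, 2):
    X = poly_features(weight, degree)
    coef, *_ = np.linalg.lstsq(X, height, rcond=None)  # least-squares coefficients
    resid_sd[degree] = float(np.std(height - X @ coef))
    print(f"degree {degree}: coef={np.round(coef, 2)}, residual sd={resid_sd[degree]:.2f}")
```

Standardizing the weight before squaring keeps the polynomial terms on comparable scales, which makes priors on the coefficients easier to choose.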
Data
In [10]:
Out[10]:
In [11]:
Out[11]:
Fit model
In [12]:
Out[12]:
Linear
In [13]:
Out[13]:
100%|██████████| 3000/3000 [00:01<00:00, 1904.53it/s, init loss: 49746.6406, avg. loss [2851-3000]: 2001.7004]
         mean   std  median    2.5%   97.5%    n_eff  r_hat
a      138.31  0.41  138.32  137.55  139.15   931.50   1.00
b1      25.95  0.41   25.94   25.19   26.75  1101.82   1.00
sigma    9.36  0.29    9.36    8.84    9.99   949.32   1.00
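The location and spread columns of a summary table like this one can be computed directly from posterior draws. The sketch below is a minimal NumPy version with a hypothetical helper `summarize`. It uses equal-tailed percentiles, which may differ from the interval definition used by the library that produced the table, and it does not compute the n_eff or r_hat diagnostics. The draws are illustrative stand-ins.

```python
import numpy as np

def summarize(samples):
    """Return mean, std, median, and 2.5%/97.5% percentiles per parameter."""
    out = {}
    for name, draws in samples.items():
        draws = np.asarray(draws)
        lo, hi = np.percentile(draws, [2.5, 97.5])
        out[name] = {
            "mean": float(draws.mean()),
            "std": float(draws.std()),
            "median": float(np.median(draws)),
            "2.5%": float(lo),
            "97.5%": float(hi),
        }
    return out

# Illustrative draws (not the notebook's samples)
rng = np.random.default_rng(0)
fake_samples = {
    "a": rng.normal(138.3, 0.4, size=1000),
    "sigma": rng.normal(9.4, 0.3, size=1000),
}

summary = summarize(fake_samples)
for name, stats in summary.items():
    print(name, {k: round(v, 2) for k, v in stats.items()})
```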
Quadratic
In [14]:
Out[14]:
100%|██████████| 3000/3000 [00:01<00:00, 1599.75it/s, init loss: 68267.6406, avg. loss [2851-3000]: 1770.2694]
         mean   std  median    2.5%   97.5%    n_eff  r_hat
a      146.05  0.36  146.03  145.33  146.71  1049.96   1.00
b1      21.75  0.30   21.75   21.18   22.32   886.88   1.00
b2      -7.79  0.28   -7.79   -8.33   -7.26  1083.62   1.00
sigma    5.78  0.17    5.78    5.46    6.14   973.21   1.00
In [15]:
Out[15]:
100%|██████████| 3000/3000 [00:02<00:00, 1150.21it/s, init loss: 66975.6406, avg. loss [2851-3000]: 2008.9690]
         mean   std  median    2.5%   97.5%    n_eff  r_hat
a      138.21  0.40  138.19  137.42  138.93  1049.96   1.00
b1      26.00  0.41   26.00   25.22   26.81   824.23   1.00
b2       0.08  0.04    0.07    0.02    0.16   935.60   1.00
sigma    9.40  0.28    9.40    8.83    9.91   947.68   1.00