Copyright 2018 The TensorFlow Probability Authors.
Licensed under the Apache License, Version 2.0 (the "License");
Linear Mixed-Effect Regression in {TF Probability, R, Stan}
1 Introduction
In this colab we will fit a linear mixed-effect regression model to a popular, toy dataset. We will make this fit thrice, using R's lme4, Stan's mixed-effects package (rstanarm), and TensorFlow Probability (TFP) primitives. We conclude by showing that all three give roughly the same fitted parameters and posterior distributions.

Our main conclusion is that TFP has the general pieces necessary to fit HLM-like models and that it produces results which are consistent with other software packages, i.e., lme4 and rstanarm. This colab is not an accurate reflection of the computational efficiency of any of the packages compared.
2 Hierarchical Linear Model
For our comparison between R, Stan, and TFP, we will fit a Hierarchical Linear Model (HLM) to the Radon dataset made popular in Bayesian Data Analysis by Gelman et al. (page 559, second ed.; page 250, third ed.).
We assume the following generative model:

$$\begin{align*}
\text{for } & c = 1\ldots \text{NumCounties}:\\
&\quad \beta_c \sim \text{Normal}(\text{loc}=0, \text{scale}=\sigma_C)\\
\text{for } & i = 1\ldots \text{NumSamples}:\\
&\quad \eta_i = \underbrace{\omega_0 + \omega_1\,\text{floor}_i}_{\text{fixed effects}} + \underbrace{\beta_{c(i)}\,\log(\text{UraniumPPM}_{c(i)})}_{\text{random effects}}\\
&\quad \log(\text{Radon}_i) \sim \text{Normal}(\text{loc}=\eta_i, \text{scale}=\sigma_N)
\end{align*}$$

where $c(i)$ denotes the county of sample $i$, $\text{floor}_i \in \{0, 1\}$ is the floor on which reading $i$ was taken, and $\text{UraniumPPM}_c$ is the soil uranium reading for county $c$.
In R's lme4 "tilde notation", this model is equivalent to:
log_radon ~ 1 + floor + (0 + log_uranium_ppm | county)
We will find the MLE for $\omega, \sigma_C, \sigma_N$ using the posterior distribution (conditioned on evidence) of $\beta$.
For essentially the same model but with a random intercept, see Appendix A.
For a more general specification of HLMs, see Appendix B.
3 Data Munging
In this section we obtain the radon dataset and do some minimal preprocessing to make it comply with our assumed model.
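As a hedged sketch of such preprocessing (the file names and column names below are assumptions based on Gelman's ARM radon files, srrs2.dat and cty.dat; adjust to wherever you store the data):

```python
import numpy as np
import pandas as pd

# Assumed local copies of the ARM radon files (comma-separated).
srrs = pd.read_csv('srrs2.dat')
cty = pd.read_csv('cty.dat')

# Restrict to Minnesota and build the response: log radon (clipped so
# zero readings don't produce -inf).
srrs = srrs[srrs.state == 'MN'].copy()
srrs['log_radon'] = np.log(srrs.activity.clip(lower=0.1))

# Attach each county's soil uranium reading and take its log.
srrs['fips'] = srrs.stfips * 1000 + srrs.cntyfips
cty['fips'] = cty.stfips * 1000 + cty.ctfips
cty = cty.drop_duplicates(subset='fips')
df = srrs.merge(cty[['fips', 'Uppm']], on='fips')
df['log_uranium_ppm'] = np.log(df.Uppm)

# Integer-code the counties so they can index a weight vector later.
df['county'] = df.county.str.strip()
df['county_code'] = df.county.astype('category').cat.codes
```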
3.1 Know Thy Data
In this section we explore the radon dataset to get a better sense of why the proposed model might be reasonable.
Conclusions:
There's a long tail of 85 counties. (A common occurrence in GLMMs.)
Indeed $\log(\text{Radon})$ is unconstrained. (So linear regression might make sense.)
Readings are mostly made on the $0$-th floor; no reading was made above floor $1$. (So our fixed effects will only have two weights.)
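Rough checks supporting these conclusions, as a sketch assuming the preprocessed DataFrame `df` from above:

```python
# County counts: a few large counties, then a long tail of sparse ones.
print(df.county.value_counts().head())
print(df.county.value_counts().tail())

# Floors: readings only occur on floors 0 and 1.
print(df.floor.value_counts())

# log radon spans negative and positive values, i.e., it is unconstrained.
print(df.log_radon.describe())
```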
4 HLM In R
In this section we use R's lme4 package to fit the probabilistic model described above.
NOTE: To execute this section, you must switch to an R colab runtime.
5 HLM In Stan
In this section we use rstanarm to fit a Stan model using the same formula/syntax as the lme4 model above.

Unlike lme4 and the TF model below, rstanarm is a fully Bayesian model, i.e., all parameters are presumed drawn from a Normal distribution with parameters themselves drawn from a distribution.
NOTE: To execute this section, you must switch to an R colab runtime.
Note: The runtimes are from a single CPU core. (This colab is not intended to be a faithful representation of Stan or TFP runtime.)
Note: Switch back to the Python TF kernel runtime.
Retrieve the point estimates and conditional standard deviations for the group random effects from lme4 for visualization later.
Draw samples for the county weights using the lme4 estimated means and standard deviations.
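A minimal numpy sketch of this step, assuming the lme4 conditional means and standard deviations were exported from the R runtime as arrays `lme4_county_means` and `lme4_county_sds` (hypothetical names):

```python
import numpy as np

num_samples = 4000
# One Normal per county, parameterized by the lme4 point estimates.
lme4_samples = np.random.normal(
    loc=lme4_county_means,   # hypothetical: shape [num_counties]
    scale=lme4_county_sds,   # hypothetical: shape [num_counties]
    size=(num_samples, len(lme4_county_means)))
```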
We also retrieve the posterior samples of the county weights from the Stan fit.
This Stan example shows how one would implement LMER in a style closer to TFP, i.e., by directly specifying the probabilistic model.
6 HLM In TF Probability
In this section we will use low-level TensorFlow Probability primitives (Distributions) to specify our Hierarchical Linear Model as well as fit the unknown parameters.
6.1 Specify Model
In this section we specify the radon linear mixed-effect model using TFP primitives. To do this, we specify two functions which produce two TFP distributions:

make_weights_prior: A multivariate Normal prior for the random weights (which are multiplied by $\log(\text{UraniumPPM}_{c(i)})$ to compute the linear predictor).
make_log_radon_likelihood: A batch of Normal distributions over each observed $\log(\text{Radon}_i)$ dependent variable.
Since we will be fitting the parameters of each of these distributions we must use TF variables (i.e., tf.get_variable). However, since we wish to use unconstrained optimization we must find a way to constrain real values to achieve the necessary semantics, e.g., positives which represent standard deviations.
The following function constructs our prior, where $\beta$ denotes the random-effect weights and $\sigma_C$ the standard deviation.
We use tf.make_template to ensure that the first call to this function instantiates the TF variables it uses and all subsequent calls reuse the variable's current value.
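A rough sketch of such a prior under the TF1-style APIs the text names (the softplus parameterization of the scale is one way to satisfy the positivity constraint mentioned above, not necessarily the notebook's exact choice):

```python
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

tf.disable_v2_behavior()  # this notebook uses TF1 graph-mode APIs
tfd = tfp.distributions

def _make_weights_prior(num_counties, dtype):
  """A batch of `num_counties` iid Normal priors over the random weights."""
  # Unconstrained variable; softplus maps it to a positive scale (sigma_C).
  raw_prior_scale = tf.get_variable(
      name='raw_prior_scale', initializer=np.array(-1., dtype))
  return tfd.Independent(
      tfd.Normal(loc=tf.zeros(num_counties, dtype),
                 scale=tf.nn.softplus(raw_prior_scale)),
      reinterpreted_batch_ndims=1)

# First call creates `raw_prior_scale`; subsequent calls reuse its value.
make_weights_prior = tf.make_template(
    name_='make_weights_prior', func_=_make_weights_prior)
```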
The following function constructs our likelihood, where $y, x$ denote response and evidence, $\omega, \beta$ denote fixed- and random-effect weights, and $\sigma_N$ the standard deviation.
Here again we use tf.make_template to ensure the TF variables are reused across calls.
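Again as a sketch (the gather-based indexing of integer county codes is an assumption consistent with the model in Section 2; `floor` and `county` are the preprocessed arrays from Section 3):

```python
def _make_log_radon_likelihood(random_effect_weights, floor, county,
                               log_county_uranium_ppm):
  """A batch of Normals, one per observed log-radon reading."""
  dtype = log_county_uranium_ppm.dtype
  # Fixed-effect weights: omega_0 (intercept) and omega_1 (floor).
  fixed_effect_weights = tf.get_variable(
      name='fixed_effect_weights', initializer=np.zeros(2, dtype))
  # Softplus again keeps the likelihood scale (sigma_N) positive.
  raw_likelihood_scale = tf.get_variable(
      name='raw_likelihood_scale', initializer=np.array(-1., dtype))
  fixed_effects = fixed_effect_weights[0] + fixed_effect_weights[1] * floor
  # Each sample gets its county's weight times that county's log-uranium.
  random_effects = tf.gather(
      random_effect_weights * log_county_uranium_ppm, county)
  return tfd.Normal(loc=fixed_effects + random_effects,
                    scale=tf.nn.softplus(raw_likelihood_scale))

# As with the prior, the template reuses the TF variables across calls.
make_log_radon_likelihood = tf.make_template(
    name_='make_log_radon_likelihood', func_=_make_log_radon_likelihood)
```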
Finally we use the prior and likelihood generators to construct the joint log-density.
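A sketch of that joint log-density, assuming the data arrays (`log_radon`, `floor`, `county`, `log_county_uranium_ppm`) produced in Section 3:

```python
def joint_log_prob(random_effect_weights, log_radon, floor, county,
                   log_county_uranium_ppm):
  num_counties = int(log_county_uranium_ppm.shape[-1])
  rv_weights = make_weights_prior(num_counties, log_county_uranium_ppm.dtype)
  rv_radon = make_log_radon_likelihood(
      random_effect_weights, floor, county, log_county_uranium_ppm)
  # log p(beta) + sum_i log p(log_radon_i | beta, omega, sigma_N)
  return (rv_weights.log_prob(random_effect_weights) +
          tf.reduce_sum(rv_radon.log_prob(log_radon), axis=-1))
```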
6.2 Training (Stochastic Approximation of Expectation Maximization)
To fit our linear mixed-effect regression model, we will use a stochastic approximation version of the Expectation Maximization algorithm (SAEM). The basic idea is to use samples from the posterior to approximate the expected joint log-density (E-step). Then we find the parameters which maximize this calculation (M-step). Somewhat more concretely, the fixed-point iteration is given by:

$$\theta^{(t+1)} = \underset{\theta}{\text{argmax}}\ \frac{1}{M} \sum_{m=1}^M \log p(x, z^{(t,m)}; \theta), \qquad z^{(t,m)} \sim p(Z \mid x; \theta^{(t)})$$

where $x$ denotes evidence, $Z$ some latent variable which needs to be marginalized out, and $\theta$ possible parameterizations.
For a more thorough explanation, see Convergence of a stochastic approximation version of the EM algorithm by Bernard Delyon, Marc Lavielle, Éric Moulines (Ann. Statist., 1999).
To compute the E-step, we need to sample from the posterior. Since our posterior is not easy to sample from, we use Hamiltonian Monte Carlo (HMC). HMC is a Markov chain Monte Carlo procedure which uses gradients (w.r.t. state, not parameters) of the unnormalized posterior log-density to propose new samples.
Specifying the unnormalized posterior log-density is simple: it is merely the joint log-density "pinned" at whatever we wish to condition on.
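For example, pinning the observed data so that the only free state is the vector of county weights:

```python
# Condition on the observed data; beta is the only free argument.
unnormalized_posterior_log_prob = lambda random_effect_weights: (
    joint_log_prob(random_effect_weights, log_radon, floor, county,
                   log_county_uranium_ppm))
```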
We now complete the E-step setup by creating an HMC transition kernel.
Notes:

We use state_stop_gradient=True to prevent the M-step from backpropping through draws from the MCMC. (Recall, we needn't backprop through because our E-step is intentionally parameterized at the previous best known estimators.)
We use tf.placeholder so that when we eventually execute our TF graph, we can feed the previous iteration's random MCMC sample as the next iteration's chain's value.
We use TFP's adaptive step_size heuristic, tfp.mcmc.hmc_step_size_update_fn.
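A hedged sketch of this setup under the TF1-era TFP API (in the TFP versions I know, the step-size-update hook referenced above is spelled tfp.mcmc.make_simple_step_size_update_policy and the state-gradient flag state_gradients_are_stopped; newer TFP replaces both with the tfp.mcmc.SimpleStepSizeAdaptation wrapper):

```python
num_counties = int(log_county_uranium_ppm.shape[-1])

# Placeholder lets us feed the previous iteration's draw back in (note 2).
init_random_weights = tf.placeholder(tf.float64, shape=[num_counties])

step_size = tf.get_variable(
    name='step_size', initializer=np.float64(0.2), trainable=False)

hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=unnormalized_posterior_log_prob,
    num_leapfrog_steps=2,
    step_size=step_size,
    step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(
        num_adaptation_steps=None),
    state_gradients_are_stopped=True)  # see note 1 above

posterior_random_weights, kernel_results = tfp.mcmc.sample_chain(
    num_results=3, num_burnin_steps=0,
    current_state=init_random_weights, kernel=hmc)
```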
We now set up the M-step. This is essentially the same as an optimization one might do in TF.
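A sketch of such an M-step (the learning-rate schedule is an arbitrary choice, not taken from the text):

```python
# The chain's accepted target_log_prob is the joint log-density evaluated at
# the posterior draws; maximizing its mean over the TF variables is the M-step.
loss = -tf.reduce_mean(kernel_results.accepted_results.target_log_prob)
global_step = tf.train.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(
    learning_rate=0.1, global_step=global_step,
    decay_steps=2, decay_rate=0.99)
train_op = tf.train.AdamOptimizer(learning_rate).minimize(
    loss, global_step=global_step)
```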
We conclude with some housekeeping tasks. We must tell TF that all variables are initialized. We also create handles to our TF variables so we can print their values at each iteration of the procedure.
6.3 Execute
In this section we execute our SAEM TF graph. The main trick here is to feed our last draw from the HMC kernel into the next iteration. This is achieved through our use of feed_dict in the sess.run call.
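A sketch of the loop, assuming the graph pieces above:

```python
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init_op)
  w_ = np.zeros([num_counties])  # arbitrary initial chain state
  for step in range(1500):
    # Feed the last HMC draw in as the next iteration's chain state.
    [w_, _] = sess.run(
        [posterior_random_weights[-1], train_op],
        feed_dict={init_random_weights: w_})
```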
Looks like after ~1500 steps, our estimates of the parameters have stabilized.
6.4 Results
Now that we've fit the parameters, let's generate a large number of posterior samples and study the results.
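For instance, one might run a longer chain from the last training draw (the sample counts here are arbitrary):

```python
# Continuing inside the `with tf.Session() as sess:` block from Section 6.3,
# after the training loop has finished:
more_weights, _ = tfp.mcmc.sample_chain(
    num_results=2000, num_burnin_steps=500,
    current_state=init_random_weights, kernel=hmc)
posterior_w_ = sess.run(more_weights, feed_dict={init_random_weights: w_})
```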
We now construct a box and whisker diagram of the random-effects. We'll order the random-effects by decreasing county frequency.
From this box and whisker diagram, we observe that the variance of the county-level random-effect increases as the county is less represented in the dataset. Intuitively this makes sense: we should be less certain about the impact of a certain county if we have less evidence for it.
7 Side-by-Side-by-Side Comparison
We now compare the results of all three procedures. To do this, we will compute non-parametric estimates of the posterior samples as generated by Stan and TFP. We will also compare against the parametric (approximate) estimates produced by R's lme4 package.
The following plot depicts the posterior distribution of each weight for each county in Minnesota. We show results for Stan (red), TFP (blue), and R's lme4 (orange). We shade the results from Stan and TFP, thus we expect to see purple where the two agree. For simplicity we do not shade results from R. Each subplot represents a single county; subplots are ordered by descending frequency in raster scan order (i.e., from left-to-right then top-to-bottom).
8 Conclusion
In this colab we fit a linear mixed-effect regression model to the radon dataset. We tried three different software packages: R, Stan, and TensorFlow Probability. We concluded by plotting the 85 posterior distributions as computed by the three different software packages.
Appendix A: Alternative Radon HLM (Add Random Intercept)
In this section we describe an alternative HLM which also has a random intercept associated with each county:

$$\begin{align*}
\text{for } & c = 1\ldots \text{NumCounties}:\\
&\quad \beta_c \sim \text{MultivariateNormal}(\text{loc}=0, \text{scale}=\Sigma^{1/2})\\
\text{for } & i = 1\ldots \text{NumSamples}:\\
&\quad \eta_i = \underbrace{\omega_0 + \omega_1\,\text{floor}_i}_{\text{fixed effects}} + \underbrace{\beta_{c(i),0} + \beta_{c(i),1}\,\log(\text{UraniumPPM}_{c(i)})}_{\text{random effects}}\\
&\quad \log(\text{Radon}_i) \sim \text{Normal}(\text{loc}=\eta_i, \text{scale}=\sigma_N)
\end{align*}$$
In R's lme4 "tilde notation", this model is equivalent to:
log_radon ~ 1 + floor + (1 + log_county_uranium_ppm | county)
Appendix B: Generalized Linear Mixed-Effect Models
In this section we give a more general characterization of Hierarchical Linear Models than what is used in the main body. This more general model is known as a generalized linear mixed-effect model (GLMM).
GLMMs are generalizations of generalized linear models (GLMs). GLMMs extend GLMs by incorporating sample-specific random noise into the predicted linear response. This is useful in part because it allows rarely seen features to share information with more commonly seen features.
As a generative process, a Generalized Linear Mixed-effects Model (GLMM) is characterized by:
$$\begin{align}
\text{for } & r = 1\ldots R: &&\text{(for each random-effect group)}\\
&\quad \text{for } c = 1\ldots |C_r|: &&\text{(for each category, or "level", of group } r\text{)}\\
&\qquad \beta_{rc} \sim \text{MultivariateNormal}(\text{loc}=0_{D_r},\ \text{scale}=\Sigma_r^{1/2})\\
\text{for } & i = 1\ldots N: &&\text{(for each sample)}\\
&\quad \eta_i = \underbrace{x_i^\top \omega}_{\text{fixed effects}} + \underbrace{\sum_{r=1}^R z_{r,i}^\top \beta_{r,C_r(i)}}_{\text{random effects}}\\
&\quad Y_i \mid x_i, \omega, \{z_{r,i}, \beta_{rc}\}_{r=1}^R \sim \text{Distribution}(\text{mean}=g^{-1}(\eta_i))
\end{align}$$

where:

$R$ is the number of random-effect groups and $|C_r|$ the number of categories (or "levels") of group $r$;
$N$ is the number of samples;
$x_i \in \mathbb{R}^{D_0}$ are the fixed-effect covariates and $\omega \in \mathbb{R}^{D_0}$ the fixed-effect weights;
$z_{r,i} \in \mathbb{R}^{D_r}$ are the random-effect covariates of sample $i$ for group $r$, and $C_r(i)$ is the category of sample $i$ within group $r$;
$\Sigma_r \in \{S \in \mathbb{R}^{D_r \times D_r} : S \succ 0\}$ is the covariance of group $r$'s random-effect weights;
$g^{-1}$ is the inverse link function.
In words, this says that every category of each group is associated with an iid MVN draw, $\beta_{rc}$. Although the $\beta_{rc}$ draws are always independent, they are only identically distributed within a group $r$; notice there is exactly one $\Sigma_r$ for each $r \in \{1, \ldots, R\}$.

When affinely combined with a sample's group features, $z_{r,i}$, the result is sample-specific noise on the $i$-th predicted linear response (which is otherwise $x_i^\top \omega$).

When we estimate $\{\Sigma_r\}_{r=1}^R$ we're essentially estimating the amount of noise a random-effect group carries which would otherwise drown out the signal present in $x_i^\top \omega$.
There are a variety of options for the $\text{Distribution}$ and inverse link function, $g^{-1}$. Common choices are:

$Y_i \sim \text{Normal}(\text{mean}=\eta_i, \text{scale}=\sigma)$, i.e., linear regression;
$Y_i \sim \text{Binomial}(\text{mean}=n_i \cdot \text{sigmoid}(\eta_i), \text{total\_count}=n_i)$, i.e., binomial (logistic) regression; and,
$Y_i \sim \text{Poisson}(\text{mean}=\exp(\eta_i))$, i.e., Poisson regression.
For more possibilities, see the tfp.glm module.
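For example, a few of the model families tfp.glm provides (a sketch; see the module documentation for the full list):

```python
import tensorflow_probability as tfp

# Each family pairs a Distribution with an inverse link function.
linear = tfp.glm.Normal()        # identity link: linear regression
logistic = tfp.glm.Bernoulli()   # sigmoid link: logistic regression
poisson = tfp.glm.Poisson()      # exp link: Poisson regression
```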