Path: blob/master/site/zh-cn/probability/examples/Generalized_Linear_Models.ipynb
25118 views
Copyright 2018 The TensorFlow Probability Authors.
Licensed under the Apache License, Version 2.0 (the "License");
在此笔记本中,我们将通过一个工作示例来介绍广义线性模型。我们使用两种算法以两种不同的方式解决此示例,以在 TensorFlow Probability 中有效地拟合 GLM:针对密集数据使用 Fisher 得分算法,针对稀疏数据使用坐标近端梯度下降算法。我们将拟合系数与真实系数进行对比,在坐标近端梯度下降算法下则与 R 语言的类似 glmnet
算法的输出进行对比。最后,我们提供了 GLM 一些关键属性的进一步数学细节和推导。
背景
广义线性模型 (GLM) 是一种封装在转换(联系函数)中并配备了指数族的响应分布的线性模型 () 。联系函数和响应分布的选择非常灵活,这为 GLM 赋予了出色的表达性。在下面的“GLM 事实的推导”中可以找到完整的详细信息,包括以明确的表示法对 GLM 构建的所有定义和结果的有序介绍。我们总结如下:
在 GLM 中,响应变量 的预测分布与观察到的预测变量 的向量相关联。分布形式如下:
其中, 是参数(“权重”), 是表示离散度(“方差”)的超参数,、、、 由用户指定模型族表征。
的均值取决于 的线性响应 和(逆)联系函数,即:
其中 是所谓的联系函数。在 TFP 中,联系函数和模型族的选择由 tfp.glm.ExponentialFamily
子类共同指定。示例包括:
tfp.glm.Normal
,又名“线性回归”tfp.glm.Bernoulli
,又名“逻辑回归”tfp.glm.Poisson
,又名“泊松回归”tfp.glm.BernoulliNormalCDF
,又名“概率回归”。
TFP 更喜欢根据 Y
的分布而非联系函数来命名模型族,因为 tfp.Distribution
已经是一等公民。如果 tfp.glm.ExponentialFamily
子类名称包含第二个单词,则表示非正则联系函数。
GLM 具有几项可有效地实现最大似然 estimator 的显著特性。这些特性中最主要的是为对数似然函数 梯度以及 Fisher 信息矩阵提供了简单的公式,它是在相同预测变量下对响应重新采样时负对数似然函数的 Hessian 的期望值。即:
其中 是矩阵,其第 行是第 个数据样本的预测变量向量; 是向量,其第 个坐标是第 个数据样本的观察到的响应。这里(粗略地讲), 和 ,粗体表示这些函数的矢量化。有关这些期望和方差的分布的完整详细信息,请参阅下方的“GLM 事实的推导”。
示例
在本部分中,我们将简要介绍和展示 TensorFlow Probability 中的两种内置 GLM 拟合算法:Fisher 得分 (tfp.glm.fit
) 和坐标近端梯度下降 (tfp.glm.fit_sparse
)。
合成数据集
让我们假装加载一些训练数据集。
注:连接到本地运行时。
在此笔记本中,我们使用本地文件在 Python 和 R 内核之间共享数据。要启用此共享,请在您具备本地文件读写权限的同一台计算机上使用运行时。
不使用 L1 正则化
函数 tfp.glm.fit
实现 Fisher 得分,它采用一些参数:
model_matrix
=response
=model
= 可调用对象,给定参数 ,返回三元组 。
我们建议该 model
为 tfp.glm.ExponentialFamily
类的实例。有几种预制的实现可用,对于大多数常见的 GLM,不需要自定义代码。
数学细节
Fisher 得分法是对牛顿法的修改,用于寻找最大似然估计
普通牛顿法,搜索对数似然函数梯度的零点,将遵循更新规则
$$ \beta^{(t+1)}_{\text{Newton}} := \beta^{(t)}
其中 是用于控制步长的学习率。
在 Fisher 得分法中,我们将 Hessian 替换为负的 Fisher 信息矩阵:
$$ \begin{align*} \beta^{(t+1)} &:= \beta^{(t)}
[注:此处 是随机的,而 仍是观察到的响应的向量。]
通过下文“将 GLM 参数拟合到数据”中的公式,可将其简化为
使用 L1 正则化
tfp.glm.fit_sparse
基于 Yuan, Ho and Lin 2012 中的算法实现了更适合稀疏数据集的 GLM 拟合器。特性包括:
L1 正则化
不使用矩阵求逆
只需少量梯度和 Hessian 评估。
我们首先展示代码的示例用法。算法的细节会在下文“tfp.glm.fit_sparse
的算法细节”中进一步阐述。
请注意,学习的系数与真实系数具有相同的稀疏模式。
对比 R 语言的 glmnet
我们将坐标近端梯度下降算法的输出与使用类似算法的 R 语言的 glmnet
的输出进行对比。
注:要执行此部分,您必须切换到 R colab 运行时。
比较 R、TFP 和真实系数(注:回到 Python 内核)
tfp.glm.fit_sparse
的算法细节
我们将算法依次呈现为对牛顿法的三种修改形式。在每种形式中, 的更新规则都基于向量 和矩阵 ,它们会逼近对数似然函数的梯度和 Hessian。在步骤 中,我们选择坐标 进行更改,并根据更新规则更新 :
$$ \begin{align*} u^{(t)} &:= \frac{ \left( s^{(t)} \right){j^{(t)}} }{ \left( H^{(t)} \right){j^{(t)},, j^{(t)}} } [3mm] \beta^{(t+1)} &:= \beta^{(t)}
此更新是一种类似牛顿法的步骤,学习率为 。除了最后一部分(L1 正则化),下面的修改仅在 和 的更新方式上有所不同。
起点:坐标牛顿法
在坐标牛顿法中,我们将 和 设置为对数似然函数的真实梯度和 Hessian:
只需少量梯度和 Hessian 评估
对数似然函数的梯度和 Hessian 的计算通常十分消耗算力,因此通常值得对其采用逼近算法。我们可以如下处理:
通常,将 Hessian 逼近为局部常值,并使用(逼近)Hessian 将梯度逼近为一阶:
有时,可执行上述“普通”更新步骤,将 设置为对数似然函数的精确梯度并将 设置为其精确 Hessian,在 处评估。
使用负 Fisher 信息矩阵代替 Hessian
为了进一步降低普通更新步骤的算力成本,我们可以将 设置为负 Fisher 信息矩阵(使用下文“将 GLM 参数拟合到数据”中的公式可以有效计算),而非确切的 Hessian:
通过近端梯度下降求解 L1 正则化
为包含 L1 正则化,我们将更新规则
$$ \beta^{(t+1)} := \beta^{(t)}
替换为更通用的更新规则
其中 ParseError: KaTeX parse error: Extra } at position 9: r_{\text}̲ > 0 是提供的常值(L1 正则化系数), 是软阈值算子,定义为
此更新规则具有以下两项令人欣喜的性质,解释如下:
在极限情况 ParseError: KaTeX parse error: Extra } at position 9: r_{\text}̲ \to 0(即不使用 L1 正则化)下,此更新规则与原始更新规则相同。
此更新规则可以解释为应用邻近算子,其不动点是 L1 正则化最小化问题的解
其中 ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 15: r_{\text{l0}} &̲gt; 0 是提供的常值(L1 正则化系数), 是软阈值算子,定义为
此更新规则具有以下两项令人欣喜的性质,解释如下:
在极限情况 (即不使用 L1 正则化)下,此更新规则与原始更新规则相同。
此更新规则可以解释为应用邻近算子,其不动点是 L1 正则化最小化问题的解
$$ \underset{\beta - \beta^{(t)} \in \text{span}{ \text{onehot}(j^{(t)}) }}{\text{arg min}} \left( -\ell(\beta ,;, \mathbf{x}, \mathbf{y})
退化情况 可恢复原始更新规则
要查看 (1),请注意如果 则 ,因此
因此
不动点为正则化最大似然估计的邻近算子
要查看 (2),首先要注意(参见 Wikipedia)对于任何 ParseError: KaTeX parse error: Expected 'EOF', got '&' at position 8: \gamma &̲gt; 0,更新规则
均满足 (2),其中 是邻近算子(参见 Yu,其中该算子表示为 )。上述方程的右半部分在此处计算:
$$ \left(\beta_{\text{exact-prox}, \gamma}^{(t+1)}\right)_{j^{(t)}}
特别地,设置 ParseError: KaTeX parse error: Extra } at position 48: …alpha, r_{\text}̲}{\left(H^{(t)}…(注:只要负对数似然函数是凸函数,),我们得到更新规则
$$ \left(\beta_{\text{exact-prox}, \gamma^{(t)}}^{(t+1)}\right)_{j^{(t)}}
特别地,设置 (注:只要负对数似然函数是凸函数,),我们得到更新规则
$$ \left(\beta_{\text{exact-prox}, \gamma^{(t)}}^{(t+1)}\right)_{j^{(t)}}
然后,我们将精确梯度 替换为其近似值 ,得到
因此
GLM 事实的推导
在本部分中,我们将详细说明并推导出在之前几部分中使用的 GLM 相关结果。然后,我们将使用 TensorFlow 的 gradients
对导出的对数似然函数和 Fisher 信息的梯度公式进行数值验证。
得分和 Fisher 信息
考虑由参数向量 参数化的概率分布族,其概率密度为 ParseError: KaTeX parse error: Expected '}', got '\right' at position 24: …\cdot | \theta)\̲r̲i̲g̲h̲t̲}_{\theta \in \…。参数向量 处的结果 的得分定义为 的对数似然函数的梯度(在 处评估),即:
声明:得分的期望值为零
在非极端正则条件(允许我们传递积分符号内取微分)下,
证明
已知
其中我们使用了:(1) 微分连锁律、(2) 期望的定义、(3) 传递积分符号内取微分(使用正则条件)、(4) 概率密度的积分为 1。
声明(Fisher 信息):得分方差等于对数似然函数的 Hessian 负期望值
在非极端正则条件(允许我们传递积分符号内取微分)下,
$$ \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \text{score}(Y, \theta_0) \text{score}(Y, \theta_0)^\top \right]
其中 表示 Hessian 矩阵,其 项为 。
此方程的左半部分称为参数向量 处的族 ParseError: KaTeX parse error: Expected '}', got '\right' at position 24: …\cdot | \theta)\̲r̲i̲g̲h̲t̲}_{\theta \in \… 的 Fisher 信息。
声明证明
已知
$$ \begin{align*} \mathbb{E}{Y \sim p(\cdot | \theta=\theta_0)}\left[ \left(\nabla\theta^2 \log p(Y | \theta)\right){\theta=\theta_0} \right] &\stackrel{\text{(1)}}{=} \mathbb{E}{Y \sim p(\cdot | \theta=\theta_0)}\left[ \left(\nabla_\theta^\top \frac{ \nabla_\theta p(Y | \theta) }{ p(Y|\theta) }\right){\theta=\theta_0} \right] \ &\stackrel{\text{(2)}}{=} \mathbb{E}{Y \sim p(\cdot | \theta=\theta_0)}\left[ \frac{ \left(\nabla^2_\theta p(Y | \theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) }
其中我们使用了 (1) 微分链式法则、(2) 微分商法则、(3)再次反向使用链式法则。
要完成证明,只需证明
为此,我们传递积分符号内取微分两次:
其中我们使用了 (1) 微分链式法则、(2) 微分商法则、(3)再次反向使用链式法则。
要完成证明,只需证明
为此,我们传递积分符号内取微分两次:
对数配分函数的导数相关引理
如果 、 和 是标量值函数,则 二次可微,使分布族 ParseError: KaTeX parse error: Expected '}', got '\right' at position 24: …\cdot | \theta)\̲r̲i̲g̲h̲t̲}_{\theta \in \… 定义为
满足非极端正则条件,允许传递在对 的积分符号内取对 的微分,然后
and
(这里 表示微分,所以 和 是 的一阶导数和二阶导数。)
证明
对于此分布族,已知 。然后第一个方程遵循以下事实 。接下来,已知
过度离散指数族
过度离散指数族(标量)是一种分布族,其密度为
其中 和 是已知的标量值函数, 和 是标量参数。
[注: 是超定的:对于任何 ,函数 完全由此约束定义:对所有 ,均满足 \int p_{\text{OEF}(m, T)}(y\ |\ \theta, \phi=\phi_0), dy = 1$。由不同的 值求得的 必须全部相同,这对 和 函数施加了约束。]
充分统计量的均值和方差
在与“对数配分函数的导数相关引理”部分的相同条件下,已知
$$ \mathbb{E}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ T(Y) \right]
and
$$ \text{Var}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ T(Y) \right]
证明
根据“对数配分函数的导数相关引理”,已知
$$ \mathbb{E}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ \frac{T(Y)}{\phi} \right]
and
$$ \text{Var}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ \frac{T(Y)}{\phi} \right]
结果满足期望为线性 () 并且方差为二次齐次式 ()。
广义线性模型
在广义线性模型中,响应变量 的预测分布与观察到的预测变量 的向量相关联。该分布是过度离散指数族的成员,参数 被替换为 ,其中 是已知函数, 是所谓的线性响应, 是要学习的参数(回归系数)的向量。通常,也可以学习离散参数 ,但在我们的设置中,我们将 视为已知。因此我们设置如下
其中模型结构的特征在于分布 和将线性响应转换为参数的函数 。
传统上,从线性响应 到均值 的映射表示为
此映射需为一对一映射,它的反函数 被称为此 GLM 的联系函数。通常,人们通过命名其联系函数及其分布族来描述 GLM,例如,“具有伯努利分布和 logit 联系函数的 GLM”(也称为逻辑回归模型)。为了完全表征 GLM,还必须指定函数 。如果 为恒等函数,则称 是正则联系函数。
声明:用充分统计量表达
定义
and
然后,已知
证明
根据“充分统计量的均值和方差”,已知
Differentiating with the chain rule, we obtain
根据“充分统计量的均值和方差”
结论如下。
将 GLM 参数拟合到数据
上面推导出的属性非常适合将 GLM 参数 拟合到数据集。诸如 Fisher 得分法之类的拟牛顿法依赖于对数似然函数的梯度和 Fisher 信息,我们现在将展示对于 GLM 可以特别有效地计算这些信息。
假设我们已经观察到预测变量向量 和相关的标量响应 。在矩阵形式中,我们会说我们观察到了预测变量 和响应 ,其中 是第 行为 的矩阵, 是第 个元素为 的向量。参数 的对数似然函数为
对于单个数据样本
为了简化表示法,让我们首先考虑单个数据点 时的情况;然后我们将通过可加性扩展到一般情况。
梯度
已知
因此,根据链式法则,
另外,根据充分统计量的均值和方差”,已知 。因此,根据“声明:用充分统计量表达 ”,可得
Hessian
由乘积法则二次求导,得到
Fisher 信息
根据“充分统计量的均值和方差”,已知
因此
对于多个数据样本
我们现在将 情况扩展到一般情况。让 表示第 i 个数据样本的线性响应的向量。让 (resp. , resp. ) 表示对每个坐标应用标量值函数 (resp. , resp. ) 的广播(矢量化)函数。然后可得
and
其中分数表示逐元素相除。
以数值方式验证公式
我们现在使用 tf.gradients
以数值方式验证上述对数似然函数的梯度的公式,并使用 tf.hessians
通过蒙特卡洛估计验证 Fisher 信息的公式:
参考文献
[1]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for L1-regularized Logistic Regression. Journal of Machine Learning Research, 13, 2012. http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf
[2]: skd. Derivation of Soft Thresholding Operator. 2018. https://math.stackexchange.com/q/511106
[3]: Wikipedia Contributors. Proximal gradient methods for learning. Wikipedia, The Free Encyclopedia, 2018. https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning
[4]: Yao-Liang Yu. The Proximity Operator. https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf