Variational Inference

近似推断

在许多机器学习问题中,一个核心任务是给定观测数据 x,估计隐变量 z 的后验分布 p(z|x),或者估计某函数 f(z) 关于后验分布的期望 Ep(z|x)[f(z)]. 这样的估计过程就是所谓的推断 (inference) 问题。例如,在 EM 算法中我们就需要估计完整数据的对数似然关于后验分布的期望。

然而,许多问题的后验分布是不可解 (intractable) 的,这是因为根据贝叶斯公式,后验分布为: p(z|x)=p(x|z)p(z)p(x)=p(x|z)p(z)p(x|z)p(z)dz 其分母涉及对全体隐变量的积分,只有隐空间取值离散且维度较小时才有可能精确计算(比如高斯混合模型)。因此,人们提出了许多近似推断 (approximate inference) 方法来解决这个问题。

近似推断方法可以大体分为两类——确定性方法和随机性方法。确定性方法的典型代表就是变分推断 (variational inference),而随机性方法的典型代表是马尔可夫蒙特卡洛 (MCMC). 本文主要探究变分推断方法。

变分推断

设观测数据为 x,隐变量为 z,同 EM 算法的推导一样,为 z 引入一个分布 q(z),则有: logp(x)=q(z)logp(x)dz=q(z)logp(x,z)p(z|x)dz=q(z)log(p(x,z)q(z)q(z)p(z|x))dz=q(z)logp(x,z)q(z)dz+q(z)logq(z)p(z|x)dz=L(q)+KL(q(z)p(z|x)) 值得注意的是,这里没有将模型参数 θ 显式写出来。从频率派的角度而言,可以认为我们目前并不关心 θ,所以省略了;从贝叶斯的角度而言,可以认为参数被吸收进了隐变量 z 之中,因而不再单独列出来。

通过上面的推导,我们得到了对数似然的下界 L(q),于是可以通过最大化 L(q) 来最大化对数似然 logp(x). 显然,这个下界应该越紧越好,也就是说我们希望 KL 项越小越好。如果不限制 q(z) 的形式,那么当 q(z)=p(z|x) 时 KL 项达到最小值 0,事实上这就是 EM 算法做的事情。但是现在问题的基本假设是 p(z|x) 是不可解的,因此直接令 q(z)=p(z|x) 就行不通了。为此,我们考虑为 q(z) 引入一些假设以使得问题可解。值得注意的是,人为引入假设意味着 q(z) 的形式被限制了,因此这样求出的解并不是真正的最优解,这就是为什么变分推断属于近似推断而非精确推断。

根据引入的假设的不同,我们就得到了不同的变分推断方法,例如:

  • 平均场变分推断:假设 q(z)​ 可分解为各分量密度函数之乘积,则可采用坐标上升法优化之;
  • 随机梯度变分推断:假设 是以 为参数的分布族,则可采用随机梯度下降优化之。

顺便补充一点,由于 的自变量是概率密度函数 ,所以 是一个泛函。求泛函极值的方法被称作变分法,这就是变分推断这个名称的由来,也因此我们常称 为变分下界。

平均场变分推断

设隐变量 ,并且假设 可分解为各分量密度函数之乘积:

注: 也可以是一些分量形成的一个组,但本质一样的,不影响推导。

由于这种假设来源于统计力学中的平均场理论 (mean-field theory),因此称该假设下的变分推断为平均场变分推断

将上式代入 得: 为了最大化 ,可以逐个优化 并不断迭代。为此,固定 不动,将 视为变量,则:

其中 表示 ,而 表示对 求期望。注意到 是关于 的函数,可以将其视作能量函数并构建玻尔兹曼分布 于是: 因此最优解 为: 或写作: 正如上文所言,整个优化过程是一个轮转迭代的过程——首先初始化所有的 ,然后根据上式循环更新各个分量——即坐标上升法。可以证明迭代过程能够收敛。

随机梯度变分推断

对于泛函优化问题,一个常用的方法是将作为自变量的那个函数参数化,这样优化对象就从函数变成了参数,问题从而转化成了一般的函数优化问题。在变分推断的语境中,就是将 限制为一个以 为参数的分布族 ,那么此时 就变成了关于 的函数 ,于是使用随机梯度下降即可求解。这就是随机梯度变分推断 (SGVI)随机梯度贝叶斯方法 (SGVB)

具体而言,将 代入 得: 计算参数 的梯度: 其中第二项: 于是只剩下第一项。再利用 ,可以将第一项写作期望的形式: 训练时用蒙特卡洛采样估计期望,这样就估计出了梯度。然而,用这种方式估计的梯度的方差很大,导致训练不稳定,因此并不实用。

一种常见的解决方案是重参数化技巧,相关内容在之前的文章中有详细介绍。对于特定的分布(例如高斯分布或离散类别分布),我们可以构造函数 使得 ,满足 并且 ,其中 是一个简单的分布。于是有: 那么: 用蒙特卡洛采样估计期望即可估计出梯度。

References

  1. Bishop, Christopher. Pattern Recognition and Machine Learning. ↩︎
  2. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. ↩︎
  3. Kingma, Diederik P., and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114 (2013). ↩︎
  4. 【机器学习-白板推导系列(十二)-变分推断(Variational Inference)】 https://www.bilibili.com/video/BV1DW41167vr/?p=4&share_source=copy_web&vd_source=a43b4442e295a96065c7ae919b4866d3 ↩︎

Variational Inference
https://xyfjason.github.io/blog-main/2024/03/05/Variational-Inference/
作者
xyfJASON
发布于
2024年3月5日
许可协议