Variational Autoencoder

核心思想

VAE 是一种基于隐变量的生成模型,它将隐变量 \(z\in\mathbb R^d\)(一般采自正态分布)映射到 \(x\in\mathbb R^D\),并要求 \(x\) 的分布尽可能接近真实数据的分布。这个映射不是确定性的,因此可以写作概率分布的形式 \(p_\theta(x\vert z)\),其中 \(\theta\) 是模型参数。那么,模型能够生成的所有样本的分布为:

\[ p_\theta(x)=\int p_\theta(x\vert z)p(z)\mathrm dz \]

为了使得 VAE 生成真实样本的概率变大,考虑用极大似然法,即采样真实样本 \(x\) ,并最大化其对数似然: \[ L(\theta)=\log p_\theta(x)=\log\left(\int p_\theta(x\vert z)p(z)\mathrm dz\right) \] 我们发现 \(\log\)​ 里积分的形式阻碍了我们继续求解,这是否让你想起了什么——没错,EM 算法!

EM 算法回顾:为隐变量 \(z\) 引入分布 \(q(z)\),则对数似然可分解为 ELBO 和 KL 两项: \[\begin{align}L(\theta)&=\log p_\theta(x)\\&=\int_z q(z)\log p_\theta(x)\mathrm dz\\&=\int_z q(z)\log \left(\frac{p_\theta(x\vert z)p(z)}{p_\theta(z\vert x)}\cdot\frac{q(z)}{q(z)}\right)\mathrm dz\\&=\int_z q(z)\left[\log\frac{p(z)}{q(z)}+\log p_\theta(x\vert z)+\log\frac{q(z)}{p_\theta(z\vert x)}\right]\mathrm dz\\&=\underbrace{\mathbb E_{z\sim q(z)}[\log p_\theta(x\vert z)]-\mathrm {KL}(q(z)\|p(z))}_{\mathrm{ELBO}(\theta,q)}+\underbrace{\mathrm{KL}(q(z)\|p_\theta(z\vert x))}_{\mathrm {KL}}\end{align}\] 优化过程是迭代执行 E-step 和 M-step:

  • E-step:固定 \(\theta\),取 \(q(z)=p_\theta(z\vert x)\),即使得 \(\mathrm{KL}(q(z)\Vert p_\theta(z\vert x))=0\),也即让 ELBO 增大到与 \(L(\theta)\) 相等。
  • M-step:固定 \(q\),最大化 ELBO,从而达到优化 \(L(\theta)\)​ 的目的。

非常可惜的是,EM 算法无法直接应用于此,这是因为 E-step 要求我们能够表达出后验分布 \(p_\theta(z\vert x)\),但根据贝叶斯公式,后验分布为: \[ p_\theta(z\vert x)=\frac{p_\theta(x\vert z)p(z)}{p_\theta(x)}=\frac{p_\theta(x\vert z)p(z)}{\int p_\theta(x\vert z)p(z)\mathrm dz} \] 分母部分的积分是无法计算的,因此后验分布是不可解 (intractable) 的。换句话说,我们无法给出 E-step 的解析解。既然如此,我们只好用优化算法去求 E-step 的数值解。于是,原本的 EM 算法变成了:

  • E-step:固定 \(\theta\),最小化 KL 项;由于此时 \(L(\theta)\) 是定值,所以等价于最大化 ELBO,即 \(\max_q \text{ELBO}(\theta,q)\).
  • M-step:固定 \(q\),最大化 ELBO,即 \(\max_\theta\text{ELBO}(\theta,q)\).

需要注意的是,E-step 的优化变量是一个概率分布函数 \(q\),并不好直接优化(用相关术语来讲,ELBO 是关于函数 \(q\)泛函,优化泛函的方法统称为变分法)。为了解决这个问题,我们将 \(q(z)\) 限制为以 \(\phi\) 为参数的某分布族 \(q_\phi(z\vert x)\),这样优化变量就从函数 \(q\) 变成了参数 \(\phi\). 不过,由于我们限制了 \(q\) 的形式,所以即便能求出最优的参数 \(\phi\),也大概率不是 E-step 的最优解。理论上,为了尽可能逼近最优解,我们应该让选取的分布族越复杂越好;但是分布越复杂,优化也越难进行,因此这里存在一个 trade-off.

这里有一个小问题——为什么 \(q(z)\) 参数化后写作 \(q_\phi(z\vert x)\) 而不是 \(q_\phi(z)\)?首先,\(q\) 本来就是我们人为引入的,它是否以 \(x\) 为条件完全是我们的设计,并不与推导过程冲突;其次,EM 算法中找到的最优 \(q^\ast(z)=p_\theta(z\vert x)\),本就是依赖于 \(x\) 的,即不同的数据的最优 \(q(z)\) 是不一样的,只是没在记号中体现出来而已。

在下文我们将看到,VAE 的 \(p_\theta(x\vert z)\)\(q_\phi(z\vert x)\) 都由神经网络表示,因此我们只能用梯度下降来求解 E-step 和 M-step. 既然都是梯度下降,那就没有必要交替迭代了,直接两步合一步最大化 ELBO 即可:

\[ \max_{\theta,\phi}\quad\mathrm{ELBO}(\theta,q)=\mathbb E_{z\sim q_\phi(z\vert x)}[\log p_\theta(x\vert z)]-\mathrm{KL}(q_\phi(z\vert x)\|p(z)) \] 取个负号就是 VAE 的损失函数: \[ \mathcal L=\mathbb E_{z\sim q_\phi(z\vert x)}[-\log p_\theta(x\vert z)]+\mathrm{KL}(q_\phi(z\vert x)\|p(z)) \]

我们看到,VAE 的损失函数由两部分构成:

  1. \(\mathbb E_{z\sim q_\phi(z\vert x)}[-\log p_\theta(x\vert z)]\)重构项,最大化 \(x\) 被重构的似然;
  2. \(\mathrm{KL}(q_\phi(z\vert x)\Vert p(z))\) 可以视作正则项,让估计的后验分布逼近先验分布。

怎么理解呢?假设只有重构项,可以想见为了更好的重构,网络会尽可能地减小不确定性——一方面让分布 \(q_\phi(z\vert x)\) 的方差很小,基本集中在一个点上;另一方面对不同的 \(x\) 让分布 \(q_\phi(z\vert x)\) 均值差异很大,以便更好地区分不同 \(x\) 编码出来的 \(z\)(如下左图所示)。如此一来,VAE 就退化成一般的 autoencoder 了;而正则项强制让 \(q_\phi(z\vert x)\) 逼近 \(p(z)\),一个我们预先设定的分布,就可以约束上述两点的发生(如下右图所示),所以两项存在一种“对抗”的感觉。

source: https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73

实例化

我们现在得到了 VAE 的损失函数,但其中的 \(p(z)\)\(p_\theta(x\vert z)\)\(q_\phi(z\vert x)\) 具体是什么并没有说明。要真正落地,还需要把它们“实例化”。

Encoder network

首先考虑损失函数中 \(\mathrm{KL}(q_\phi(z\vert x)\Vert p(z))\) 一项。由于正态分布的 KL 散度相对来说好算一些,我们希望 \(p(z)\)\(q_\phi(z\vert x)\) 都是正态分布:

  • \(p(z)\):简便起见,直接取为 \(\mathcal N(0, I)\) 标准正态分布;
  • \(q_\phi(z\vert x)\):考虑到它依赖于 \(x\),所以应该是 \(\mathcal N(\mu_\phi(x),\Sigma_\phi(x))\) 的形式。可是 \(\mu_\phi(x),\Sigma_\phi(x)\) 用怎样的函数才好呢?在深度学习的时代,这种开放性问题就无脑上神经网络呗!这就是 VAE 中的 encoder network.

\(p(z)\)\(q_\phi(z\vert x)\) 代入,得到: \[ \begin{align} \mathrm{KL}(q_\phi(z\vert x)\|p(z))&=\mathrm{KL}(\mathcal N(\mu_\phi(x),\Sigma_\phi(x))\|\mathcal N(0,I))\\ &=\frac{1}{2}\left[\mathrm{tr}(\Sigma_\phi(x))+\mu_\phi(x)^T\mu_\phi(x)-d-\log\det(\Sigma_\phi(x)) \right] \end{align} \] 实操时,我们一般会做简化:取 \(\Sigma_\phi(x)=\mathrm{diag}(\sigma_\phi^2(x))\),即各分量独立,协方差矩阵只有对角线有值,那么: \[ \begin{align} \mathrm{KL}(q_\phi(z\vert x)\|p(z))&=\mathrm{KL}(\mathcal N(\mu_\phi(x),\mathrm{diag}(\sigma_\phi^2(x)))\| \mathcal N(0,I))\\ &=\frac{1}{2}\sum_{i=1}^d\left(\mu_\phi^2(x)_i+\sigma_\phi^2(x)_i-\log \sigma_\phi^2(x)_i-1\right) \end{align} \] 一个小细节,由于方差一定非负,我们可以视 encoder 的输出为 \(\log \sigma_\phi^2(x)\) 而非 \(\sigma_\phi^2(x)\).

Decoder network

接下来考虑损失函数中 \(\mathbb E_{z\sim q_\phi(z\vert x)}[-\log p_\theta(x\vert z)]\) 一项,也就是生成模型 \(p_\theta(x\vert z)\) 的形式,常见的有两种选择:

伯努利分布:输出只有 0/1,所以适用于生成二值数据(比如黑白图像)。设伯努利分布的参数为 \(\rho_\theta(z)\in\mathbb [0,1]^D\),那么: \[ p_\theta(x_i\vert z)=\begin{cases}\rho_\theta(z)_i,&x_i=1\\1-\rho_\theta(z)_i,&x_i=0\end{cases} \] 于是: \[ -\log p_\theta(x\vert z)=-\sum_{i=1}^D\log p_\theta(x_i\vert z)=\sum_{i=1}^D\left[-x_i\log \rho_\theta(z)_i-(1-x_i)\log(1-\rho_\theta(z)_i)\right] \]BCELoss.

正态分布:设参数为 \(\mu_\theta(z)\in\mathbb R^D,\Sigma_\theta(z)\in\mathbb R^{D\times D}\),那么: \[ p_\theta(x\vert z)=\frac{1}{(2\pi)^{D/2}(\det\Sigma_\theta(z))^{1/2}}\exp\left(-\frac{1}{2}(x-\mu_\theta(z))^T\Sigma_\theta^{-1}(z)(x-\mu_\theta(z))\right) \] 于是: \[ -\log p_\theta(x\vert z)=\frac{D}{2}\log(2\pi)+\frac{1}{2}\log\det\Sigma_\theta(z)+\frac{1}{2}(x-\mu_\theta(z))^T\Sigma_\theta^{-1}(z)(x-\mu_\theta(z)) \] 实操时,我们一般会取 \(\Sigma_\theta(z)=\sigma^2I\),即各分量独立且方差固定为某常数。那么: \[ -\log p_\theta(x\vert z)=\frac{D}{2}\log(2\pi)+\frac{D}{2}\log\sigma^2+\frac{1}{2\sigma^2}\|x-\mu_\theta(z)\|^2 \] 前两项是定值,与优化无关,所以优化目标就是: \[ \frac{1}{2\sigma^2}\|x-\mu_\theta(z)\|^2 \]MSELoss.

注意:上式中 \(\Vert\cdot\Vert^2\) 是欧氏距离,如果实现时直接用 nn.MSELoss 会对 CHW 维也取平均(假设在图像上训练),结果是实际欧氏距离的 \(1/CHW\),导致重构项和 KL 项权重失衡。所以实现时要么只对 mini-batch 取平均、CHW 维求和,要么全取平均,但是 KL 项加个系数缩小。

与 encoder network 同理,\(\rho_\theta(z)\) 或者 \(\mu_\theta(z)\) 直接由一个 decoder network 得到。

至此我们算出了 \(-\log p_\theta(x\vert z)\). 注意损失函数里还要对 \(z\) 取期望,所以理论上我们应该多次采样做蒙特卡洛估计,但实践中只需要采一个 \(z\) 就足够了。

Loss 权重

考虑实践中最常用的设置:

  • \(p(z)\)\(\mathcal N(0,I)\)
  • \(q_\phi(z\vert x)\)\(\mathcal N(\mu_\phi(x),\Sigma_\phi(x))\),且 \(\Sigma_\phi(x)=\mathrm{diag}(\sigma_\phi^2(x))\)
  • \(p_\theta(x\vert z)\)\(\mathcal N(\mu_\theta(z),\Sigma_\theta(z))\),且 \(\Sigma_\theta(z)=\sigma^2I\),其中 \(\sigma^2\)事先取定的一个超参数

那么根据前两小节的推导,损失函数是: \[ \mathcal L=\underbrace{\frac{1}{2\sigma^2}\|x-\mu_\theta(z)\|^2}_\text{Reconstruction}+\underbrace{\frac{1}{2}\sum_{i=1}^d\left(\mu_\phi^2(x)_i+\sigma_\phi^2(x)_i-\log \sigma_\phi^2(x)_i-1\right)}_\text{KL Regularization},\quad z\sim\mathcal N(\mu_\phi(x),\text{diag}(\sigma_\phi^2(x))) \] 可以看到,重构项和 KL 正则项由超参数 \(\sigma^2\) 加权。\(\sigma^2\) 越小,重构项权重越大,意味着结果更真实,但泛化性下降。一般直接取 \(\sigma^2=1\) 即可。

重参数化技巧

重参数化技巧在之前的文章中已经介绍过了,所以这里不再赘述。简单说来,就是现在 \(z\) 是从 \(q_\phi(z\vert x)\sim\mathcal N(\mu_\phi(x),\mathrm{diag}(\sigma_\phi^2(x)))\) 中采样的,但梯度无法经过采样传播到参数 \(\phi\)。解决方法很简单,先从 \(\mathcal N(0,I)\) 中采样 \(z'\),再计算 \(z=\mu_\phi(x)+z'\cdot\sigma_\phi(x)\) 即可。

代码实现

Github repo: https://github.com/xyfJASON/vaes-pytorch

放个结果:

参考资料

  1. 苏剑林. (Mar. 18, 2018). 《变分自编码器(一):原来是这么一回事 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/5253 ↩︎
  2. 苏剑林. (Mar. 28, 2018). 《变分自编码器(二):从贝叶斯观点出发 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/5343 ↩︎
  3. 原来VAE是这么回事(从EM到VAE) - 市井小民的文章 - 知乎 https://zhuanlan.zhihu.com/p/368959795 ↩︎
  4. EM的升级打怪之路:EM-变分EM-VAE(part1) - Young Zicon的文章 - 知乎 https://zhuanlan.zhihu.com/p/418203971 ↩︎
  5. VAE 的前世今生:从最大似然估计到 EM 再到 VAE - AI科技评论的文章 - 知乎 https://zhuanlan.zhihu.com/p/443540253 ↩︎
  6. Weng, Lilian. From Autoencoder to Beta-VAE. https://lilianweng.github.io/posts/2018-08-12-vae/ ↩︎
  7. Doersch, Carl. Tutorial on variational autoencoders. arXiv preprint arXiv:1606.05908 (2016). ↩︎
  8. Joseph Rocca. Understanding Variational Autoencoders (VAEs). https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73 ↩︎

Variational Autoencoder
https://xyfjason.github.io/blog-main/2022/09/17/Variational-Autoencoder/
作者
xyfJASON
发布于
2022年9月17日
许可协议