从VAE到DDPM

\[ \newcommand{\E}{\mathbb E} \newcommand{\KL}{\mathrm{KL}} \newcommand{\calN}{\mathcal N} \newcommand{\x}{\mathbf x} \newcommand{\z}{\mathbf z} \newcommand{\coloneqq}{\mathrel{\mathrel{\vcenter{:}}=}} \]

VAE 回顾

之前的文章中,我们详细地梳理了一遍 VAE,这里做一个简单回顾。

source:Angus Turner. Diffusion Models as a kind of VAE

在 VAE 中,为了最大化对数似然: \[ L(\theta)=\log p_\theta(x)=\log\left(\int_zp_\theta(x\vert z)p(z)\mathrm dz\right) \] 我们引入变分后验 \(q_\phi(z\vert x)\)\[ \begin{align} L(\theta)&=\log p_\theta(x)\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log p_\theta(x)\right]\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x,z)}{p_\theta(z\vert x)}\right]\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log\left(\frac{p_\theta(x,z)}{p_\theta(z\vert x)}\cdot\frac{q_\phi(z\vert x)}{q_\phi(z\vert x)}\right)\right]\\ &=\underbrace{\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x,z)}{q_\phi(z\vert x)}\right]}_\text{ELBO}+\underbrace{\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{q_\phi(z\vert x)}{p_\theta(z\vert x)}\right]}_{\KL(q_\phi(z\vert x)\| p_\theta(z\vert x))}\\ &\geq \text{ELBO} \end{align}\tag{1}\label{vae} \] 得到证据下界 ELBO,通过最大化 ELBO 来最大化对数似然。进一步地,ELBO 还可以拆写成重构项和 KL 正则项: \[ \begin{align} \text{ELBO}&=\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x,z)}{q_\phi(z\vert x)}\right]\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x\vert z)p(z)}{q_\phi(z\vert x)}\right]\\ &=\underbrace{\E_{z\sim q_\phi(z\vert x)}[\log p_\theta(x\vert z)]}_{\text{reconstruction}}-\underbrace{\KL(q_\phi(z\vert x)\| p(z))}_{\text{regularization}}\\ \end{align}\label{elbo-vae}\tag{2} \] 为了计算上的方便,实践中常将 \(p(z),q_\phi(z\vert x),p_\theta(x\vert z)\) 都取为正态分布,具体而言,它们分别是: \[ \begin{align} &p(z)\coloneqq \calN(z;0,I)\\ &q_\phi(z\vert x)\coloneqq\calN\left(z;\mu_\phi(x),\mathrm{diag}(\sigma_\phi^2(x))\right)\\ &p_\theta(x\vert z)\coloneqq \calN(\mu_\theta(z),\sigma^2 I)&&\sigma\text{ is a constant}\\ \end{align} \] 代入 \(\eqref{elbo-vae}\) 式即可得到损失函数: \[ \mathcal L=\underbrace{\frac{1}{2\sigma^2}\|x-\mu_\theta(z)\|^2}_\text{reconstruction}+\underbrace{\frac{1}{2}\sum_{i=1}^d\left(\mu_\phi^2(x)_i+\sigma_\phi^2(x)_i-\log \sigma_\phi^2(x)_i-1\right)}_\text{KL regularization} \] 正态分布的假设虽然便于计算,但限制了变分后验的形式,使之不能很好地近似真实后验分布,从而限制了 VAE 的能力。因此,一个自然的改进思路就是用更强的方式去建模变分后验——比如再套一个 VAE?

双层 VAE

把 VAE 中的单个隐变量 \(z\) 换成两个隐变量 \(z_1,z_2\),形成如下马尔可夫链:

source:Angus Turner. Diffusion Models as a kind of VAE

虽然有两个隐变量,但如果把它们视为一个整体,那证据下界 ELBO 的推导过程与 \(\eqref{vae}\) 式没有什么本质不同,因此我们只需将 \(\eqref{elbo-vae}\) 式中的 \(z\) 换做 \(z_1,z_2\) 就得到了双层 VAE 的 ELBO: \[ \text{ELBO}_\text{2-layers}=\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2\vert x)}\right]=\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(x\vert z_1)p_\theta(z_1\vert z_2)p(z_2)}{q_\phi(z_2\vert z_1)q_\phi(z_1\vert x)}\right]\label{elbo-2}\tag{3} \]

我们依旧希望从 ELBO 中拆解出重构项和正则项。重构项比较简单,只需要把 \(p_\theta(x\vert z_1)\) 拆出来就是了,问题在于剩下的一坨应该如何处理。破局的关键在于利用马尔可夫性质把分母上的 \(q_\phi(z_2\vert z_1)\) 改写作 \(q_\phi(z_2\vert z_1,x)\),然后用贝叶斯公式: \[ q_\phi(z_2\vert z_1)=q_\phi(z_2\vert z_1,x)=\frac{q_\phi(z_1\vert z_2,x)q_\phi(z_2\vert x)}{q_\phi(z_1\vert x)} \] 代回 \(\eqref{elbo-2}\) 式得: \[ \begin{align} \text{ELBO}_\text{2-layers}&=\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(x\vert z_1)p_\theta(z_1\vert z_2)p(z_2)}{q_\phi(z_2\vert z_1)q_\phi(z_1\vert x)}\right]\\ &=\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]+\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(z_1\vert z_2)p(z_2)}{\frac{q_\phi(z_1\vert z_2,x)q_\phi(z_2\vert x)}{q_\phi(z_1\vert x)}q_\phi(z_1\vert x)}\right]\\ &=\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]+\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(z_1\vert z_2)p(z_2)}{q_\phi(z_1\vert z_2,x)q_\phi(z_2\vert x)}\right]\\ &=\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]+\E_{z_2\sim q_\phi(z_2\vert x)}\left[\log\frac{p(z_2)}{q_\phi(z_2\vert x)}\right]+\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(z_1\vert z_2)}{q_\phi(z_1\vert z_2,x)}\right]\\ &=\underbrace{\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]}_{\text{reconstruction}}-\underbrace{\KL(q_\phi(z_2\vert x)\| p(z_2))}_{\text{regularization}}-\underbrace{\E_{z_2\sim q_\phi(z_2\vert x)}[\KL(q_\phi(z_1\vert z_2,x)\|p_\theta(z_1\vert z_2))]}_{\text{matching}}\\ \end{align} \] 可以看见,双层 VAE 的 ELBO 由三部分组成:重构项、正则项以及匹配项。对比 VAE 的 ELBO,我们可以认为 \(z_1,z_2\) 分担了原本一个隐变量的工作。具体而言,在 VAE 中,隐变量既要重构,又要逼近先验分布,且这两个任务是有点矛盾的;而现在,重构依靠 \(z_1\) 完成,逼近先验依靠 \(z_2\) 完成,二者由匹配项联系起来,从而增加了模型的灵活性。

DDPM

在双层 VAE 的基础上,我们能再多加几层吗?

source:Angus Turner. Diffusion Models as a kind of VAE

如上图所示,为方便叙述,我们引入两个称呼:

  • 称从 \(\x_0\)\(\x_T\) 的马尔可夫链为前向过程 (forward process)扩散过程 (diffusion process)
  • 称从 \(\x_T\)\(\x_0\) 的马尔可夫链为逆向过程 (reverse process)去噪过程 (denoising process).

用概率图模型的术语来说,前向过程对应 inference model,逆向过程对应 generative model. 另外,为书写上的方便,下文将 \(\x_l,\cdots,\x_r\) 简写为 \(\x_{l:r}\).

同双层 VAE 一样的道理,把 \(\x_{1:T}\) 整体看作 VAE 中的隐变量,代入 \(\eqref{elbo-vae}\) 式就可以得到 DDPM 的 ELBO: \[ \text{ELBO}=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p_\theta(\x_{0:T})}{q(\x_{1:T}\vert \x_0)}\right]=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p(\x_T)\prod_{t=1}^Tp_\theta(\x_{t-1}\vert\x_t)}{\prod_{t=1}^T q(\x_t\vert \x_{t-1})}\right]\tag{4}\label{elbo-ddpm} \] 接下来的推导技巧和双层 VAE 如出一辙,即将分母中的 \(q(\x_t\vert\x_{t-1})\) 写作 \(q(\x_t\vert\x_{t-1},\x_0)\),然后使用贝叶斯公式,即可进行大量的消元\[ \begin{align} \prod_{t=1}^T q(\x_t\vert \x_{t-1})&=q(\x_1\vert\x_0)\prod_{t=2}^T q(\x_t\vert \x_{t-1})=q(\x_1\vert\x_0)\prod_{t=2}^T q(\x_t\vert \x_{t-1},\x_0)\\ &=q(\x_1\vert\x_0)\prod_{t=2}^T \frac{q(\x_t\vert\x_0) q(\x_{t-1}\vert\x_t,\x_0)}{q(\x_{t-1}\vert\x_0)}=q(\x_T\vert\x_0)\prod_{t=2}^T q(\x_{t-1}\vert\x_t,\x_0) \end{align} \] 代回 \(\eqref{elbo-ddpm}\) 式得: \[ \begin{align} \text{ELBO}&=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p(\x_T)\prod_{t=1}^Tp_\theta(\x_{t-1}\vert\x_t)}{\prod_{t=1}^T q(\x_t\vert \x_{t-1})}\right]\\ &=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p(\x_T)p_\theta(\x_0\vert\x_1)\prod_{t=2}^Tp_\theta(\x_{t-1}\vert\x_t)}{q(\x_T\vert\x_0)\prod_{t=2}^T q(\x_{t-1}\vert\x_t,\x_0)}\right]\\ &=\E_{\x_1\sim q(\x_1\vert\x_0)}[\log p_\theta(\x_0\vert\x_1)]+\E_{\x_T\sim q(\x_T\vert \x_0)}\left[\log\frac{p(\x_T)}{q(\x_T\vert\x_0)}\right]+\sum_{t=2}^T\E_{\x_{t-1},\x_t\sim q(\x_{t-1},\x_t\vert \x_0)}\left[\log\frac{p_\theta(\x_{t-1}\vert\x_t)}{q(\x_{t-1}\vert\x_t,\x_0)}\right]\\ &= \underbrace{\E_{\x_1\sim q(\x_1\vert\x_0)}[\log p_\theta(\x_0\vert\x_1)]}_{\text{reconstruction}}- \underbrace{\KL(q(\x_T\vert \x_0)\|p(\x_T))}_{\text{regularization}}- \sum_{t=2}^T\underbrace{\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\KL(q(\x_{t-1}\vert\x_t,\x_0)\|p_\theta(\x_{t-1}\vert\x_t))\right]}_{\text{matching}} \end{align}\tag{5}\label{obj} \] 同样出现了重构项、正则项和匹配项。重构项要求 \(\x_1\) 能够重构 \(\x_0\),正则项要求 \(\x_T\) 的后验分布逼近先验分布,而匹配项则建立起相邻两项 \(\mathbf x_{t-1},\mathbf x_t\) 之间的联系。

现在,我们只需要为 \(\eqref{obj}\) 式中出现的所有概率分布设计具体的形式,就可以代入计算了。为了让 KL 散度可解,一个自然的想法就是把它们都设计为正态分布的形式。

前向过程

source:https://cvpr2022-tutorial-diffusion-models.github.io/

首先我们关注前向过程,即从 \(\mathbf x_0\)\(\mathbf x_T\) 的马尔可夫链: \[ q(\mathbf x_{0:T})=q(\mathbf x_0)\prod_{t=1}^Tq(\mathbf x_t\vert\mathbf x_{t-1}) \] DDPM 将 \(q(\x_t\vert\x_{t-1})\) 设计为: \[ q(\x_t\vert \x_{t-1})=\calN(\x_t;\sqrt{1-\beta_t}\x_{t-1},\beta_t\mathbf{I})\tag{6}\label{q} \] 其中 \(\beta_t\in(0,1)\) 是事先指定的超参数,代表从 \(\x_{t-1}\)\(\x_t\) 这一步的方差。直观上理解,如果 \(\beta_t\) 比较小,那么 \(q(\x_t\vert\x_{t-1})\) 均值依旧在 \(\x_{t-1}\) 附近,方差也不大,故 \(\x_t\) 看起来就是在 \(\x_{t-1}\) 的基础上加了一些噪声。值得注意的是,\(q\) 不带任何可学习参数,这是 DDPM 与 VAE 不一样的地方。

基于 \(\eqref{q}\) 式,我们可以推导出 \(\eqref{obj}\) 式中需要的 \(q(\x_t\vert\x_0)\)\(q(\x_{t-1}\vert\x_t,\x_0)\). 首先推导 \(q(\x_t\vert\x_0)\). 为了书写上的方便,做一个变量代换: \[ \alpha_t=1-\beta_t,\quad\bar\alpha_t=\prod_{i=1}^t\alpha_i \] 那么 \(\eqref{q}\) 式改写作: \[ q(\x_t\vert \x_{t-1})=\calN(\x_t;\sqrt{\alpha_t}\x_{t-1},(1-\alpha_t)\mathbf{I}) \] 这意味着我们可以用如下方式从 \(\x_{t-1}\) 采样 \(\x_t\)\[ \x_t=\sqrt{\alpha_t}\x_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1},\quad \epsilon_{t-1}\sim\calN(\mathbf 0,\mathbf{I}) \] 类似地,我们可以用如下方式从 \(\x_{t-2}\) 采样 \(\x_{t-1}\)

\[ \x_{t-1}=\sqrt{\alpha_{t-1}}\x_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2},\quad \epsilon_{t-2}\sim\calN(\mathbf 0,\mathbf{I}) \] 合并上面两个式子,从 \(\x_{t-2}\) 直接采样 \(\x_t\) 写作: \[ \x_t=\sqrt{\alpha_t\alpha_{t-1}}\x_{t-2}+{\color{green}\sqrt{\alpha_t(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_t}\epsilon_{t-1}} \] 由于两个正态随机变量之和服从均值方差分别相加的正态分布,即: \[ {\color{green}\sqrt{\alpha_t(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_t}\epsilon_{t-1}}\sim\calN(\mathbf 0,(1-\alpha_t\alpha_{t-1})\mathbf{I}) \] 所以只需采样一个正态随机变量即可: \[ \x_t=\sqrt{\alpha_t\alpha_{t-1}}\x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\epsilon_{t-2},\quad \epsilon_{t-2}\sim\calN(\mathbf 0,\mathbf{I}) \] 以此类推,从 \(\x_0\) 直接采样 \(\x_t\) 写作: \[ \x_t=\sqrt{\bar\alpha_t}\x_0+\sqrt{1-\bar\alpha_t}\epsilon,\quad \epsilon\sim\calN(\mathbf 0,\mathbf{I}) \tag{7}\label{xtx0} \] 也就是: \[ q(\x_t\vert \x_0)=\calN\left(\x_t;\sqrt{\bar\alpha_t}\x_0,(1-\bar\alpha_t)\mathbf{I}\right)\tag{8}\label{qxtx0} \] 这样就推出了 \(q(\x_t\vert\x_0)\). 进一步地,我们希望无论输入什么,前向过程最后得到的分布都趋近于标准正态分布,即 \(q(\x_\infty\vert\x_0)=\calN(\x_\infty;\mathbf 0,\mathbf{I})\),因此要求: \[ \lim_{t\to\infty}\sqrt{\bar\alpha_t}=0,\quad\lim_{t\to\infty}\sqrt{1-\bar\alpha_t}=1 \] 为满足这个要求,只需 \(\alpha_1>\alpha_2>\cdots>\alpha_T\),也即 \(\beta_1<\beta_2<\cdots<\beta_T\) 即可。直观来看,这意味着初期加噪较弱,后期加噪变强。在 DDPM 中,作者取 \(\beta_1,\ldots,\beta_T\) 为从 \(0.0001\)\(0.02\) 的线性递增序列。


接下来推导 \(q(\x_{t-1}\vert\x_t,\x_0)\),根据贝叶斯公式有: \[ \begin{align} q(\x_{t-1}\vert \x_t,\x_0)&=\frac{q(\x_t\vert \x_{t-1},\x_0)q(\x_{t-1}\vert \x_0)}{q(\x_t\vert\x_0)} =\frac{q(\x_t\vert \x_{t-1})q(\x_{t-1}\vert \x_0)}{q(\x_t\vert \x_0)}\\ &\propto\exp\left(-\frac{1}{2}\left(\frac{\Vert\x_t-\sqrt{\alpha_t}\x_{t-1}\Vert^2}{\beta_t}+\frac{\Vert\x_{t-1}-\sqrt{\bar\alpha_{t-1}}\x_0\Vert^2}{1-\bar\alpha_{t-1}}-\frac{\Vert\x_t-\sqrt{\alpha_t}\x_0\Vert^2}{1-\bar\alpha_t}\right)\right)\\ &=\exp\left(-\frac{1}{2}\left(\underbrace{\frac{1-\bar\alpha_t}{\beta_t(1-\bar\alpha_{t-1})}}_A\Vert\x_{t-1}\Vert^2+\underbrace{\left(-\frac{2\sqrt{\alpha_t}\x_t}{\beta_t}-\frac{2\sqrt{\bar\alpha_{t-1}}\x_{0}}{1-\bar\alpha_{t-1}}\right)}_B\cdot\x_{t-1}+C(\x_t,\x_0)\right)\right) \end{align} \] 这意味着 \(q(\x_{t-1}\vert\x_t,\x_0)\) 也是一个正态分布: \[ \begin{align} &q(\x_{t-1}\vert \x_t,\x_0)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_0),\tilde\beta_t\mathbf I\right)\\ \text{where}\quad&\mu_t(\x_t,\x_0)=\frac{-B}{2A}=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_0\\ &\tilde\beta_t=\frac{1}{A}=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t \end{align}\tag{9}\label{q-reverse} \]

注意\(t>1\) 时上面的推导没问题,但需要特别考虑 \(t=1\) 的情形。当 \(t=1\) 时,\(q(\x_{t-1}\vert \x_t,\x_0)=q(\x_0\vert\x_1,\x_0)\) 其实是一个确定性的分布,即一定取 \(\x_0\);如果我们合理地补充定义 \(\bar\alpha_0=1\)(因为 \(\bar\alpha_i\) 代表累乘,第零项设置为 \(1\) 很合理),会发现 \(\mu_1(\x_1,\x_0)=\x_0,\,\tilde\beta_1=0\),正好符合预期,所以上面 \(\mu_t(\x_t,\x_0)\)\(\tilde\beta_t\) 的表达式对 \(t\geq 1\) 都适用。

Tip:推导时不要每一项都打开老老实实地算,直接提取 \(\x_{t-1}\) 的二次项系数和一次项系数即可。

看到这里,不知读者心中是否有疑惑——为什么人为设置后验分布(即前向过程)是合理的?VAE 中 \(q\) 不是要去拟合真实后验分布吗,现在人为设置好了怎么去拟合啊?私以为,这个问题揭示了 VAE 和 DDPM 出发点的不同。VAE 先定义生成模型 \(p_\theta(x\vert z)\),在这个定义下,存在所谓的“真实”后验分布 \(p_\theta(z\vert x)\),但是它不可解,所以用 \(q_\phi(z\vert x)\) 去近似。DDPM 则是反过来,先定义后验分布(即前向过程),然后根据后验去学习生成模型(即逆向过程)。

逆向过程

source:https://cvpr2022-tutorial-diffusion-models.github.io/

现在我们来关注逆向过程,即从 \(\mathbf x_T\)\(\mathbf x_0\) 的马尔可夫链: \[ p_\theta(\x_{0:T})=p(\x_T)\prod_{t=1}^Tp_\theta(\x_{t-1}\vert \x_t) \] 其中 \(p(\x_T)\) 很容易设计,直接取标准正态分布即可,这也与我们之前设计的 \(q(\x_T\vert\x_0)\) 是匹配的。

对于 \(p_\theta(\x_{t-1}\vert\x_t)\),考虑到 \(\eqref{obj}\) 式中要最小化它与 \(q(\x_{t-1}\vert\x_t,\x_0)\) 之间的 KL 散度,所以为了计算方便,设计为与之相同的形式,即: \[ \begin{align} &p_\theta(\x_{t-1}\vert\x_t)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_\theta(\x_t,t)),\sigma_t^2\mathbf{I}\right)\\ \text{where}\quad&\mu_t(\x_t,\x_\theta(\x_t,t))=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_\theta(\x_t,t)\\ &\sigma_t^2=\tilde\beta_t\text{ or }\beta_t \end{align}\tag{10}\label{p-reverse} \] 其中 \(\x_\theta(\x_t,t)\) 表示以 \(\theta\) 为参数的模型。为了看得更清楚,列表如下:

\(q(\x_{t-1}\vert\x_t,\x_0)\) \(p_\theta(\x_{t-1}\vert\x_t)\mathrel{\mathrel{\vcenter{:}}=}q(\x_{t-1}\vert\x_t,\x_\theta(\x_t,t))\)
表达式 \(\calN(\x_{t-1};\mu_t(\x_t,\x_0),\tilde\beta_t\mathbf I)\) \(\calN\left(\x_{t-1};\ \mu_t(\x_t,\x_\theta(\x_t,t)),{\sigma_t^2}\mathbf{I}\right)\)
均值 \(\mu_t(\x_t,\x_0)\) \(\mu_t(\x_t,\x_\theta(\x_t,t))\)
方差 \(\tilde\beta_t\) \(\sigma_t^2=\tilde\beta_t\text{ or }\beta_t\)

可以看到,\(p_\theta(\x_{t-1}\vert\x_t)\) 的均值沿用了 \(q(\x_{t-1}\vert\x_t,\x_0)\) 的形式,只不过用模型 \(\x_\theta(\x_t,t)\) 代替了生成过程中我们并不知道的 \(\x_0\). 对于一个给定的 \(\x_0\),用 \(p_\theta(\x_{t-1}\vert\x_t)\) 去近似 \(q(\x_{t-1}\vert\x_t,\x_0)\),本质上就是在用 \(\x_\theta(\x_t,t)\) 去近似 \(\x_0\),在下一节中我们将优化目标显式写出后可以看得更清楚。

至于方差,DDPM 给出了两个选择 \(\tilde\beta_t\)\(\beta_t\). 前者不难理解,就是沿用了 \(q(\x_{t-1}\vert\x_t,\x_0)\) 的方差,但是后者是出自什么考虑呢?这个问题我们暂时放一放,在之后的文章中详细说明。

损失函数

至此,我们已经确定下 \(\eqref{obj}\) 式中出现的所有概率分布的形式,因而可以代入计算了。为避免读者上下翻阅,把所有公式总结于此: \[ \begin{align} &\text{ELBO}= \underbrace{\E_{\x_1\sim q(\x_1\vert\x_0)}[\log p_\theta(\x_0\vert\x_1)]}_{\text{reconstruction}}- \underbrace{\KL(q(\x_T\vert \x_0)\|p(\x_T))}_{\text{regularization}}- \sum_{t=2}^T\underbrace{\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\KL(q(\x_{t-1}\vert\x_t,\x_0)\|p_\theta(\x_{t-1}\vert\x_t))\right]}_{\text{matching}}&&\eqref{obj}\\ &q(\x_t\vert \x_0)=\calN\left(\x_t;\sqrt{\bar\alpha_t}\x_0,(1-\bar\alpha_t)\mathbf{I}\right)&&\eqref{qxtx0}\\ &q(\x_{t-1}\vert \x_t,\x_0)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_0),\tilde\beta_t\mathbf I\right)&&\eqref{q-reverse}\\ &p_\theta(\x_{t-1}\vert\x_t)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_\theta(\x_t,t)),\sigma_t^2\mathbf{I}\right)&&\eqref{p-reverse} \end{align} \] 首先看正则项,由于我们设计 \(q\) 时要求 \(\lim\limits_{T\to\infty}q(\x_T\vert\x_0)=\calN(\x_\infty;\mathbf0,\mathbf I)\),在 \(T\) 较大时趋近于标准正态分布,并且 \(p(\x_T)\) 也设置为标准正态分布,所以正则项可以忽略。

然后看重构项,代入表达式得: \[ \log p_\theta(\x_0\vert\x_1)=\text{constant}-\frac{1}{2\sigma_1^2}\Vert\x_0-\mu_t(\x_1,\x_\theta(\x_1,1))\Vert^2 \] 最后看匹配项,根据两个正态分布的 KL 散度计算公式,当取 \(\sigma_t^2=\tilde\beta_t\) 时,有: \[ \mathrm{KL}(q(\x_{t-1}\vert \x_t,\x_0) \Vert p_\theta(\x_{t-1}\vert\x_t))=\frac{1}{2\sigma_t^2}\Vert\mu_t(\x_t,\x_0)-\mu_t(\x_t,\x_\theta(\x_t,t))\Vert^2 \] 鉴于 \(\x_0\) 可以写作 \(\mu_1(\x_1,\x_0)\),因此重构项与匹配项的格式可以统一起来。综上,总的损失函数为: \[ \begin{align} \mathcal L_{\x_0}(\theta)&=\sum_{t=1}^T\frac{1}{2\sigma_t^2}\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\Vert\mu_t(\x_t,\x_0)-\mu_t(\x_t,\x_\theta(\x_t,t))\Vert^2\right]\\ &=\sum_{t=1}^T\frac{1}{2\sigma_t^2}\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\left\Vert\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_0-\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_\theta(\x_t,t)\right\Vert^2\right]\\ &=\sum_{t=1}^T\frac{\bar\alpha_{t-1}\beta_t^2}{2\sigma_t^2(1-\bar\alpha_t)^2}\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\Vert\x_0-\x_\theta(\x_t,t)\Vert^2\right]\tag{11}\label{loss-x} \end{align} \] 也就是 L2 损失或 MSE 损失。这验证了前文我们提到的,对于一个给定的 \(\x_0\),模型 \(\x_\theta(\x_t,t)\) 的作用是去近似 \(\x_0\) 的说法。

进一步地,DDPM 对 \(\x_\theta(\x_t,t)\) 做了重参数化。根据 \(\eqref{xtx0}\) 式有: \[ \x_0=\frac{1}{\sqrt{\bar\alpha_t}}\left(\x_t-\sqrt{1-\bar\alpha_t}\epsilon\right),\quad\epsilon\sim\calN(\mathbf0,\mathbf I) \] 因此可以把 \(\x_\theta(\x_t,t)\) 写作相同的形式: \[ \x_\theta(\x_t,t)=\frac{1}{\sqrt{\bar\alpha_t}}\left(\x_t-\sqrt{1-\bar\alpha_t}\epsilon_\theta(\x_t,t)\right) \] 代入 \(\eqref{loss-x}\) 式得: \[ \mathcal L_{\x_0}(\theta)=\sum_{t=1}^T\frac{\beta_t^2}{2\sigma_t^2{\alpha_t}(1-\bar\alpha_t)}\E_{\epsilon\sim\calN(\mathbf0,\mathbf I)}\left[\Vert{\epsilon}-\epsilon_\theta(\x_t,t)\Vert^2\right] \] 其中 \(\x_t\)\(\x_0\) 与采样出的 \(\epsilon\) 根据 \(\eqref{xtx0}\) 式计算,为简便起见没有显式地代入上式。这里的 \(\epsilon_\theta(\x_t,t)\) 就是所谓的“噪声预测模型”,用于近似当前采样出的噪声 \(\epsilon\). 通过实验探索,DDPM 作者发现将模型参数化为 \(\epsilon_\theta(\x_t,t)\) 的效果比 \(\x_\theta(\x_t,t)\) 更好,并且把前面的系数丢掉效果更好。另外,对 \(t\) 求和可以改作对 \(t\) 均匀采样,因此损失函数简化为: \[ \mathcal L_{\x_0,\text{simple}}(\theta)=\E_{t,\epsilon}\left[\Vert{\epsilon}-\epsilon_\theta(\x_t,t)\Vert^2\right] \] 最后,注意本文至此的所有推导都建立在给定一个 \(\x_0\) 的基础上,实际训练时 \(\x_0\) 是从训练集中采样的,因此最终的损失函数为: \[ \mathcal L_\text{simple}(\theta)=\E_{t,\x_0,\epsilon}\left[\Vert\epsilon-\epsilon_\theta(\x_t,t)\Vert^2\right] \] 相应算法流程如下:

可见 DDPM 虽然推导有些复杂,但最后得到的算法流程却异常简单,效果也很好,难怪迅速成为了研究的热点。

一些注解

直观上 DDPM 干的事情可以总结为——前向过程对输入图像一步步加噪,使之变成高斯噪声;逆向过程使用模型来预测原图(或预测添加的噪声),进而把带噪图像一步步转换回真实图像。这里容易产生一个误解:既然每一步 \(\x_\theta(\x_t,t)\) 都是去近似 \(\x_0\),那么岂不是直接一步生成 \(\x_\theta(\x_T,T)\approx \x_0\) 就可以了?并不是这样的。注意 \(\mathbf x_0\) 是不断从数据集中采样出来的,由于不同的 \(\mathbf x_0\) 都有可能得到相同的 \(\mathbf x_t\),所以随着训练的进行,\(\mathbf x_\theta(\mathbf x_t,t)\) 拟合的是这些 \(\mathbf x_0\) 的平均值,是一个模糊的图像。其实通过简单的推导就可以知道:

\[ \min_\theta\mathbb E_{\x_0\sim q(\x_0),\x_t\sim q(\x_t\vert\x_0)}\left[\Vert\mathbf x_\theta(\mathbf x_t,t)-\mathbf x_0\Vert_2^2\right]\implies \mathbf x_{\theta^\ast}(\mathbf x_t,t)=\mathbb E_{q(\mathbf x_0\vert\mathbf x_t)}[\mathbf x_0]=\mathbb E[\mathbf x_0\vert\mathbf x_t],\quad\forall \x_t\sim q(\x_t) \] 因此模型 \(\mathbf x_\theta(\mathbf x_t,t)\) 拟合的真值是 \(\mathbb E[\mathbf x_0\vert\mathbf x_t]\),即 \(\mathbf x_0\) 关于 \(q(\mathbf x_0\vert\mathbf x_t)\) 的加权平均。同理,\(\epsilon_\theta(\x_t,t)\) 拟合的真值是 \(\mathbb E[\epsilon\vert\x_t]\).

那么,纵观整个生成过程,我们可以把 \(\x_\theta(\x_t,t)\approx\E[\x_0\vert\x_t]\) 理解为大方向。每次我们朝着大方向走一小步,然后重新看看大方向在哪里,再走下一小步。打个比方,我想从成都走到深圳,我知道大致要朝东南 45° 方向走,但是“差之毫厘,谬以千里”,直接走可能一不小心就登陆台湾了,所以我先走一小步到重庆;然后再看地图,大方向变成了东南 50°,于是又走一小步,但是拐弯过猛到了贵阳;没关系,再看地图,大方向变成了东南 30°……这样每走一小步都对大方向做一点修正,最后就能平稳地到达目的地了。

代码实现

Github repo: https://github.com/xyfJASON/Diffusion-Models-Implementations

结果展示

更多内容请查看代码仓库。

关于 clipping

在官方代码[16]和若干其他实现中,我发现大家普遍喜欢使用 clipping,即对于逆向过程的每一步,在预测 \(\epsilon_\theta(\x_t,t)\) 之后,先算 \(\x_\theta(\x_t,t)\)然后 clip 到 \([-1,1]\) 之间,再算 \(\mu_t(\x_t,\x_\theta(\x_t,t))\). 这样做为什么合理呢?这是因为 clipping 本质上是对模型误差的人工修正——\(\x_\theta(\x_t,t)\) 是用来估计 \(\x_0\) 的,本就应该在 \([-1,1]\) 之间,只是出于模型误差而跳脱了这个范围,所以强行把它 clip 回来并不违背理论;另外,clipping 只影响逆向过程,并不需要重新训练模型。

色调偏移问题

早期的实现版本在 MNIST 上 work 得很好,但是在 CelebA-HQ 上训练时出现了色调偏移(color shifting)问题。具体而言,我发现各个 epoch 之间的图片色调会发生明显偏移,比如前一个 epoch 图片都偏红,后一个 epoch 图片都偏蓝,有时候甚至亮/暗得根本看不清人脸,如下图所示:

本以为是模型还没收敛,但是 300 多个 epochs 之后仍然是这样,这就不得不重视起来。一番排查后,发现是我偷懒没有实现 EMA 导致的,特别是原作者把 decay rate 设置为 0.9999,意味着参数更新其实是很慢的。EMA 的本质是对历史权重做了加权平均,可以看作若干历史模型的集成。从这个角度来说,那些色调发生不同偏移的模型互相“抵消”,从而缓解了色调偏移问题。(注意只是缓解,并没有消除!)

后来我读到其实宋飏在论文[3]里面就提到了这一现象,这也是他引入 EMA 的原因。说到底,色调偏移就是模型还没有收敛到真实分布的一个表现,只不过视觉上给人的冲击比较强烈罢了。

[update 2022.11.27] 虽然 EMA 的 decay rate 设置为 0.9999,但 tensorflow 的官方实现其实是这样的: \[ \text{decay}=\min\left(\text{decay}_\max,\frac{1+\text{num}\_\text{updates}}{10+\text{num}\_\text{updates}}\right) \] 随着 num_updates 增加,对应的 decay 序列是 \(0.1818,0.2500,0.3077,0.3571,0.4000,\ldots\),一直到 90000 步左右 decay 才会固定在 0.9999. 这样做能减小初始化的随机权重对整体权重的影响,模型见效更快。

References

  1. Ho, Jonathan, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems 33 (2020): 6840-6851. ↩︎
  2. Sohl-Dickstein, Jascha, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In International Conference on Machine Learning, pp. 2256-2265. PMLR, 2015. ↩︎
  3. Song, Yang, and Stefano Ermon. Improved techniques for training score-based generative models. Advances in neural information processing systems 33 (2020): 12438-12448. ↩︎
  4. Luo, Calvin. Understanding diffusion models: A unified perspective. arXiv preprint arXiv:2208.11970 (2022). ↩︎
  5. Lilian Weng. What are Diffusion Models?. https://lilianweng.github.io/posts/2021-07-11-diffusion-models ↩︎
  6. Angus Turner. Diffusion Models as a kind of VAE. https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html ↩︎
  7. Denoising Diffusion-based Generative Modeling: Foundations and Applications. https://cvpr2022-tutorial-diffusion-models.github.io ↩︎
  8. 苏剑林. (Jul. 06, 2022). 《生成扩散模型漫谈(二):DDPM = 自回归式VAE 》[Blog post]. Retrieved from https://kexue.fm/archives/9152 ↩︎
  9. 苏剑林. (Jul. 19, 2022). 《生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪 》[Blog post]. Retrieved from https://kexue.fm/archives/9164 ↩︎
  10. 由浅入深了解Diffusion Model - ewrfcas的文章 - 知乎 https://zhuanlan.zhihu.com/p/525106459 ↩︎
  11. 扩散模型之DDPM - 小小将的文章 - 知乎 https://zhuanlan.zhihu.com/p/563661713 ↩︎
  12. Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读. https://www.bilibili.com/video/BV1b541197HX ↩︎
  13. Diffusion Model:比“GAN”还要牛逼的图像生成模型!https://www.bilibili.com/video/BV1pD4y1179T ↩︎
  14. 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - Nicolas的文章 - 知乎 https://zhuanlan.zhihu.com/p/68748778 ↩︎
  15. https://huggingface.co/blog/annotated-diffusion ↩︎
  16. https://github.com/lucidrains/denoising-diffusion-pytorch ↩︎
  17. https://github.com/hojonathanho/diffusion ↩︎
  18. https://github.com/openai/improved-diffusion ↩︎
  19. https://github.com/lucidrains/imagen-pytorch ↩︎
  20. https://github.com/tqch/ddpm-torch ↩︎
  21. https://github.com/abarankab/DDPM ↩︎
  22. https://github.com/w86763777/pytorch-ddpm ↩︎

从VAE到DDPM
https://xyfjason.github.io/blog-main/2022/09/29/从VAE到DDPM/
作者
xyfJASON
发布于
2022年9月29日
许可协议