从VAE到DDPM
\[ \newcommand{\E}{\mathbb E} \newcommand{\KL}{\mathrm{KL}} \newcommand{\calN}{\mathcal N} \newcommand{\x}{\mathbf x} \newcommand{\z}{\mathbf z} \newcommand{\coloneqq}{\mathrel{\mathrel{\vcenter{:}}=}} \]
VAE 回顾
在之前的文章中,我们详细地梳理了一遍 VAE,这里做一个简单回顾。
在 VAE 中,为了最大化对数似然: \[ L(\theta)=\log p_\theta(x)=\log\left(\int_zp_\theta(x\vert z)p(z)\mathrm dz\right) \] 我们引入变分后验 \(q_\phi(z\vert x)\): \[ \begin{align} L(\theta)&=\log p_\theta(x)\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log p_\theta(x)\right]\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x,z)}{p_\theta(z\vert x)}\right]\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log\left(\frac{p_\theta(x,z)}{p_\theta(z\vert x)}\cdot\frac{q_\phi(z\vert x)}{q_\phi(z\vert x)}\right)\right]\\ &=\underbrace{\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x,z)}{q_\phi(z\vert x)}\right]}_\text{ELBO}+\underbrace{\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{q_\phi(z\vert x)}{p_\theta(z\vert x)}\right]}_{\KL(q_\phi(z\vert x)\| p_\theta(z\vert x))}\\ &\geq \text{ELBO} \end{align}\tag{1}\label{vae} \] 得到证据下界 ELBO,通过最大化 ELBO 来最大化对数似然。进一步地,ELBO 还可以拆写成重构项和 KL 正则项: \[ \begin{align} \text{ELBO}&=\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x,z)}{q_\phi(z\vert x)}\right]\\ &=\E_{z\sim q_\phi(z\vert x)}\left[\log\frac{p_\theta(x\vert z)p(z)}{q_\phi(z\vert x)}\right]\\ &=\underbrace{\E_{z\sim q_\phi(z\vert x)}[\log p_\theta(x\vert z)]}_{\text{reconstruction}}-\underbrace{\KL(q_\phi(z\vert x)\| p(z))}_{\text{regularization}}\\ \end{align}\label{elbo-vae}\tag{2} \] 为了计算上的方便,实践中常将 \(p(z),q_\phi(z\vert x),p_\theta(x\vert z)\) 都取为正态分布,具体而言,它们分别是: \[ \begin{align} &p(z)\coloneqq \calN(z;0,I)\\ &q_\phi(z\vert x)\coloneqq\calN\left(z;\mu_\phi(x),\mathrm{diag}(\sigma_\phi^2(x))\right)\\ &p_\theta(x\vert z)\coloneqq \calN(\mu_\theta(z),\sigma^2 I)&&\sigma\text{ is a constant}\\ \end{align} \] 代入 \(\eqref{elbo-vae}\) 式即可得到损失函数: \[ \mathcal L=\underbrace{\frac{1}{2\sigma^2}\|x-\mu_\theta(z)\|^2}_\text{reconstruction}+\underbrace{\frac{1}{2}\sum_{i=1}^d\left(\mu_\phi^2(x)_i+\sigma_\phi^2(x)_i-\log \sigma_\phi^2(x)_i-1\right)}_\text{KL regularization} \] 正态分布的假设虽然便于计算,但限制了变分后验的形式,使之不能很好地近似真实后验分布,从而限制了 VAE 的能力。因此,一个自然的改进思路就是用更强的方式去建模变分后验——比如再套一个 VAE?
双层 VAE
把 VAE 中的单个隐变量 \(z\) 换成两个隐变量 \(z_1,z_2\),形成如下马尔可夫链:
虽然有两个隐变量,但如果把它们视为一个整体,那证据下界 ELBO 的推导过程与 \(\eqref{vae}\) 式没有什么本质不同,因此我们只需将 \(\eqref{elbo-vae}\) 式中的 \(z\) 换做 \(z_1,z_2\) 就得到了双层 VAE 的 ELBO: \[ \text{ELBO}_\text{2-layers}=\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2\vert x)}\right]=\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(x\vert z_1)p_\theta(z_1\vert z_2)p(z_2)}{q_\phi(z_2\vert z_1)q_\phi(z_1\vert x)}\right]\label{elbo-2}\tag{3} \]
我们依旧希望从 ELBO 中拆解出重构项和正则项。重构项比较简单,只需要把 \(p_\theta(x\vert z_1)\) 拆出来就是了,问题在于剩下的一坨应该如何处理。破局的关键在于利用马尔可夫性质把分母上的 \(q_\phi(z_2\vert z_1)\) 改写作 \(q_\phi(z_2\vert z_1,x)\),然后用贝叶斯公式: \[ q_\phi(z_2\vert z_1)=q_\phi(z_2\vert z_1,x)=\frac{q_\phi(z_1\vert z_2,x)q_\phi(z_2\vert x)}{q_\phi(z_1\vert x)} \] 代回 \(\eqref{elbo-2}\) 式得: \[ \begin{align} \text{ELBO}_\text{2-layers}&=\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(x\vert z_1)p_\theta(z_1\vert z_2)p(z_2)}{q_\phi(z_2\vert z_1)q_\phi(z_1\vert x)}\right]\\ &=\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]+\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(z_1\vert z_2)p(z_2)}{\frac{q_\phi(z_1\vert z_2,x)q_\phi(z_2\vert x)}{q_\phi(z_1\vert x)}q_\phi(z_1\vert x)}\right]\\ &=\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]+\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(z_1\vert z_2)p(z_2)}{q_\phi(z_1\vert z_2,x)q_\phi(z_2\vert x)}\right]\\ &=\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]+\E_{z_2\sim q_\phi(z_2\vert x)}\left[\log\frac{p(z_2)}{q_\phi(z_2\vert x)}\right]+\E_{z_1,z_2\sim q_\phi(z_1,z_2\vert x)}\left[\log\frac{p_\theta(z_1\vert z_2)}{q_\phi(z_1\vert z_2,x)}\right]\\ &=\underbrace{\E_{z_1\sim q_\phi(z_1\vert x)}[p_\theta(x\vert z_1)]}_{\text{reconstruction}}-\underbrace{\KL(q_\phi(z_2\vert x)\| p(z_2))}_{\text{regularization}}-\underbrace{\E_{z_2\sim q_\phi(z_2\vert x)}[\KL(q_\phi(z_1\vert z_2,x)\|p_\theta(z_1\vert z_2))]}_{\text{matching}}\\ \end{align} \] 可以看见,双层 VAE 的 ELBO 由三部分组成:重构项、正则项以及匹配项。对比 VAE 的 ELBO,我们可以认为 \(z_1,z_2\) 分担了原本一个隐变量的工作。具体而言,在 VAE 中,隐变量既要重构,又要逼近先验分布,且这两个任务是有点矛盾的;而现在,重构依靠 \(z_1\) 完成,逼近先验依靠 \(z_2\) 完成,二者由匹配项联系起来,从而增加了模型的灵活性。
DDPM
在双层 VAE 的基础上,我们能再多加几层吗?
如上图所示,为方便叙述,我们引入两个称呼:
- 称从 \(\x_0\) 到 \(\x_T\) 的马尔可夫链为前向过程 (forward process) 或扩散过程 (diffusion process);
- 称从 \(\x_T\) 到 \(\x_0\) 的马尔可夫链为逆向过程 (reverse process) 或去噪过程 (denoising process).
用概率图模型的术语来说,前向过程对应 inference model,逆向过程对应 generative model. 另外,为书写上的方便,下文将 \(\x_l,\cdots,\x_r\) 简写为 \(\x_{l:r}\).
同双层 VAE 一样的道理,把 \(\x_{1:T}\) 整体看作 VAE 中的隐变量,代入 \(\eqref{elbo-vae}\) 式就可以得到 DDPM 的 ELBO: \[ \text{ELBO}=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p_\theta(\x_{0:T})}{q(\x_{1:T}\vert \x_0)}\right]=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p(\x_T)\prod_{t=1}^Tp_\theta(\x_{t-1}\vert\x_t)}{\prod_{t=1}^T q(\x_t\vert \x_{t-1})}\right]\tag{4}\label{elbo-ddpm} \] 接下来的推导技巧和双层 VAE 如出一辙,即将分母中的 \(q(\x_t\vert\x_{t-1})\) 写作 \(q(\x_t\vert\x_{t-1},\x_0)\),然后使用贝叶斯公式,即可进行大量的消元: \[ \begin{align} \prod_{t=1}^T q(\x_t\vert \x_{t-1})&=q(\x_1\vert\x_0)\prod_{t=2}^T q(\x_t\vert \x_{t-1})=q(\x_1\vert\x_0)\prod_{t=2}^T q(\x_t\vert \x_{t-1},\x_0)\\ &=q(\x_1\vert\x_0)\prod_{t=2}^T \frac{q(\x_t\vert\x_0) q(\x_{t-1}\vert\x_t,\x_0)}{q(\x_{t-1}\vert\x_0)}=q(\x_T\vert\x_0)\prod_{t=2}^T q(\x_{t-1}\vert\x_t,\x_0) \end{align} \] 代回 \(\eqref{elbo-ddpm}\) 式得: \[ \begin{align} \text{ELBO}&=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p(\x_T)\prod_{t=1}^Tp_\theta(\x_{t-1}\vert\x_t)}{\prod_{t=1}^T q(\x_t\vert \x_{t-1})}\right]\\ &=\E_{\x_{1:T}\sim q(\x_{1:T}\vert \x_0)}\left[\log\frac{p(\x_T)p_\theta(\x_0\vert\x_1)\prod_{t=2}^Tp_\theta(\x_{t-1}\vert\x_t)}{q(\x_T\vert\x_0)\prod_{t=2}^T q(\x_{t-1}\vert\x_t,\x_0)}\right]\\ &=\E_{\x_1\sim q(\x_1\vert\x_0)}[\log p_\theta(\x_0\vert\x_1)]+\E_{\x_T\sim q(\x_T\vert \x_0)}\left[\log\frac{p(\x_T)}{q(\x_T\vert\x_0)}\right]+\sum_{t=2}^T\E_{\x_{t-1},\x_t\sim q(\x_{t-1},\x_t\vert \x_0)}\left[\log\frac{p_\theta(\x_{t-1}\vert\x_t)}{q(\x_{t-1}\vert\x_t,\x_0)}\right]\\ &= \underbrace{\E_{\x_1\sim q(\x_1\vert\x_0)}[\log p_\theta(\x_0\vert\x_1)]}_{\text{reconstruction}}- \underbrace{\KL(q(\x_T\vert \x_0)\|p(\x_T))}_{\text{regularization}}- \sum_{t=2}^T\underbrace{\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\KL(q(\x_{t-1}\vert\x_t,\x_0)\|p_\theta(\x_{t-1}\vert\x_t))\right]}_{\text{matching}} \end{align}\tag{5}\label{obj} \] 同样出现了重构项、正则项和匹配项。重构项要求 \(\x_1\) 能够重构 \(\x_0\),正则项要求 \(\x_T\) 的后验分布逼近先验分布,而匹配项则建立起相邻两项 \(\mathbf x_{t-1},\mathbf x_t\) 之间的联系。
现在,我们只需要为 \(\eqref{obj}\) 式中出现的所有概率分布设计具体的形式,就可以代入计算了。为了让 KL 散度可解,一个自然的想法就是把它们都设计为正态分布的形式。
前向过程
首先我们关注前向过程,即从 \(\mathbf x_0\) 到 \(\mathbf x_T\) 的马尔可夫链: \[ q(\mathbf x_{0:T})=q(\mathbf x_0)\prod_{t=1}^Tq(\mathbf x_t\vert\mathbf x_{t-1}) \] DDPM 将 \(q(\x_t\vert\x_{t-1})\) 设计为: \[ q(\x_t\vert \x_{t-1})=\calN(\x_t;\sqrt{1-\beta_t}\x_{t-1},\beta_t\mathbf{I})\tag{6}\label{q} \] 其中 \(\beta_t\in(0,1)\) 是事先指定的超参数,代表从 \(\x_{t-1}\) 到 \(\x_t\) 这一步的方差。直观上理解,如果 \(\beta_t\) 比较小,那么 \(q(\x_t\vert\x_{t-1})\) 均值依旧在 \(\x_{t-1}\) 附近,方差也不大,故 \(\x_t\) 看起来就是在 \(\x_{t-1}\) 的基础上加了一些噪声。值得注意的是,\(q\) 不带任何可学习参数,这是 DDPM 与 VAE 不一样的地方。
基于 \(\eqref{q}\) 式,我们可以推导出 \(\eqref{obj}\) 式中需要的 \(q(\x_t\vert\x_0)\) 和 \(q(\x_{t-1}\vert\x_t,\x_0)\). 首先推导 \(q(\x_t\vert\x_0)\). 为了书写上的方便,做一个变量代换: \[ \alpha_t=1-\beta_t,\quad\bar\alpha_t=\prod_{i=1}^t\alpha_i \] 那么 \(\eqref{q}\) 式改写作: \[ q(\x_t\vert \x_{t-1})=\calN(\x_t;\sqrt{\alpha_t}\x_{t-1},(1-\alpha_t)\mathbf{I}) \] 这意味着我们可以用如下方式从 \(\x_{t-1}\) 采样 \(\x_t\): \[ \x_t=\sqrt{\alpha_t}\x_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1},\quad \epsilon_{t-1}\sim\calN(\mathbf 0,\mathbf{I}) \] 类似地,我们可以用如下方式从 \(\x_{t-2}\) 采样 \(\x_{t-1}\):
\[ \x_{t-1}=\sqrt{\alpha_{t-1}}\x_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2},\quad \epsilon_{t-2}\sim\calN(\mathbf 0,\mathbf{I}) \] 合并上面两个式子,从 \(\x_{t-2}\) 直接采样 \(\x_t\) 写作: \[ \x_t=\sqrt{\alpha_t\alpha_{t-1}}\x_{t-2}+{\color{green}\sqrt{\alpha_t(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_t}\epsilon_{t-1}} \] 由于两个正态随机变量之和服从均值方差分别相加的正态分布,即: \[ {\color{green}\sqrt{\alpha_t(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_t}\epsilon_{t-1}}\sim\calN(\mathbf 0,(1-\alpha_t\alpha_{t-1})\mathbf{I}) \] 所以只需采样一个正态随机变量即可: \[ \x_t=\sqrt{\alpha_t\alpha_{t-1}}\x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\epsilon_{t-2},\quad \epsilon_{t-2}\sim\calN(\mathbf 0,\mathbf{I}) \] 以此类推,从 \(\x_0\) 直接采样 \(\x_t\) 写作: \[ \x_t=\sqrt{\bar\alpha_t}\x_0+\sqrt{1-\bar\alpha_t}\epsilon,\quad \epsilon\sim\calN(\mathbf 0,\mathbf{I}) \tag{7}\label{xtx0} \] 也就是: \[ q(\x_t\vert \x_0)=\calN\left(\x_t;\sqrt{\bar\alpha_t}\x_0,(1-\bar\alpha_t)\mathbf{I}\right)\tag{8}\label{qxtx0} \] 这样就推出了 \(q(\x_t\vert\x_0)\). 进一步地,我们希望无论输入什么,前向过程最后得到的分布都趋近于标准正态分布,即 \(q(\x_\infty\vert\x_0)=\calN(\x_\infty;\mathbf 0,\mathbf{I})\),因此要求: \[ \lim_{t\to\infty}\sqrt{\bar\alpha_t}=0,\quad\lim_{t\to\infty}\sqrt{1-\bar\alpha_t}=1 \] 为满足这个要求,只需 \(\alpha_1>\alpha_2>\cdots>\alpha_T\),也即 \(\beta_1<\beta_2<\cdots<\beta_T\) 即可。直观来看,这意味着初期加噪较弱,后期加噪变强。在 DDPM 中,作者取 \(\beta_1,\ldots,\beta_T\) 为从 \(0.0001\) 到 \(0.02\) 的线性递增序列。
接下来推导 \(q(\x_{t-1}\vert\x_t,\x_0)\),根据贝叶斯公式有: \[ \begin{align} q(\x_{t-1}\vert \x_t,\x_0)&=\frac{q(\x_t\vert \x_{t-1},\x_0)q(\x_{t-1}\vert \x_0)}{q(\x_t\vert\x_0)} =\frac{q(\x_t\vert \x_{t-1})q(\x_{t-1}\vert \x_0)}{q(\x_t\vert \x_0)}\\ &\propto\exp\left(-\frac{1}{2}\left(\frac{\Vert\x_t-\sqrt{\alpha_t}\x_{t-1}\Vert^2}{\beta_t}+\frac{\Vert\x_{t-1}-\sqrt{\bar\alpha_{t-1}}\x_0\Vert^2}{1-\bar\alpha_{t-1}}-\frac{\Vert\x_t-\sqrt{\alpha_t}\x_0\Vert^2}{1-\bar\alpha_t}\right)\right)\\ &=\exp\left(-\frac{1}{2}\left(\underbrace{\frac{1-\bar\alpha_t}{\beta_t(1-\bar\alpha_{t-1})}}_A\Vert\x_{t-1}\Vert^2+\underbrace{\left(-\frac{2\sqrt{\alpha_t}\x_t}{\beta_t}-\frac{2\sqrt{\bar\alpha_{t-1}}\x_{0}}{1-\bar\alpha_{t-1}}\right)}_B\cdot\x_{t-1}+C(\x_t,\x_0)\right)\right) \end{align} \] 这意味着 \(q(\x_{t-1}\vert\x_t,\x_0)\) 也是一个正态分布: \[ \begin{align} &q(\x_{t-1}\vert \x_t,\x_0)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_0),\tilde\beta_t\mathbf I\right)\\ \text{where}\quad&\mu_t(\x_t,\x_0)=\frac{-B}{2A}=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_0\\ &\tilde\beta_t=\frac{1}{A}=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t \end{align}\tag{9}\label{q-reverse} \]
注意:\(t>1\) 时上面的推导没问题,但需要特别考虑 \(t=1\) 的情形。当 \(t=1\) 时,\(q(\x_{t-1}\vert \x_t,\x_0)=q(\x_0\vert\x_1,\x_0)\) 其实是一个确定性的分布,即一定取 \(\x_0\);如果我们合理地补充定义 \(\bar\alpha_0=1\)(因为 \(\bar\alpha_i\) 代表累乘,第零项设置为 \(1\) 很合理),会发现 \(\mu_1(\x_1,\x_0)=\x_0,\,\tilde\beta_1=0\),正好符合预期,所以上面 \(\mu_t(\x_t,\x_0)\) 和 \(\tilde\beta_t\) 的表达式对 \(t\geq 1\) 都适用。
Tip:推导时不要每一项都打开老老实实地算,直接提取 \(\x_{t-1}\) 的二次项系数和一次项系数即可。
看到这里,不知读者心中是否有疑惑——为什么人为设置后验分布(即前向过程)是合理的?VAE 中 \(q\) 不是要去拟合真实后验分布吗,现在人为设置好了怎么去拟合啊?私以为,这个问题揭示了 VAE 和 DDPM 出发点的不同。VAE 先定义生成模型 \(p_\theta(x\vert z)\),在这个定义下,存在所谓的“真实”后验分布 \(p_\theta(z\vert x)\),但是它不可解,所以用 \(q_\phi(z\vert x)\) 去近似。DDPM 则是反过来,先定义后验分布(即前向过程),然后根据后验去学习生成模型(即逆向过程)。
逆向过程
现在我们来关注逆向过程,即从 \(\mathbf x_T\) 到 \(\mathbf x_0\) 的马尔可夫链: \[ p_\theta(\x_{0:T})=p(\x_T)\prod_{t=1}^Tp_\theta(\x_{t-1}\vert \x_t) \] 其中 \(p(\x_T)\) 很容易设计,直接取标准正态分布即可,这也与我们之前设计的 \(q(\x_T\vert\x_0)\) 是匹配的。
对于 \(p_\theta(\x_{t-1}\vert\x_t)\),考虑到 \(\eqref{obj}\) 式中要最小化它与 \(q(\x_{t-1}\vert\x_t,\x_0)\) 之间的 KL 散度,所以为了计算方便,设计为与之相同的形式,即: \[ \begin{align} &p_\theta(\x_{t-1}\vert\x_t)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_\theta(\x_t,t)),\sigma_t^2\mathbf{I}\right)\\ \text{where}\quad&\mu_t(\x_t,\x_\theta(\x_t,t))=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_\theta(\x_t,t)\\ &\sigma_t^2=\tilde\beta_t\text{ or }\beta_t \end{align}\tag{10}\label{p-reverse} \] 其中 \(\x_\theta(\x_t,t)\) 表示以 \(\theta\) 为参数的模型。为了看得更清楚,列表如下:
\(q(\x_{t-1}\vert\x_t,\x_0)\) | \(p_\theta(\x_{t-1}\vert\x_t)\mathrel{\mathrel{\vcenter{:}}=}q(\x_{t-1}\vert\x_t,\x_\theta(\x_t,t))\) | |
---|---|---|
表达式 | \(\calN(\x_{t-1};\mu_t(\x_t,\x_0),\tilde\beta_t\mathbf I)\) | \(\calN\left(\x_{t-1};\ \mu_t(\x_t,\x_\theta(\x_t,t)),{\sigma_t^2}\mathbf{I}\right)\) |
均值 | \(\mu_t(\x_t,\x_0)\) | \(\mu_t(\x_t,\x_\theta(\x_t,t))\) |
方差 | \(\tilde\beta_t\) | \(\sigma_t^2=\tilde\beta_t\text{ or }\beta_t\) |
可以看到,\(p_\theta(\x_{t-1}\vert\x_t)\) 的均值沿用了 \(q(\x_{t-1}\vert\x_t,\x_0)\) 的形式,只不过用模型 \(\x_\theta(\x_t,t)\) 代替了生成过程中我们并不知道的 \(\x_0\). 对于一个给定的 \(\x_0\),用 \(p_\theta(\x_{t-1}\vert\x_t)\) 去近似 \(q(\x_{t-1}\vert\x_t,\x_0)\),本质上就是在用 \(\x_\theta(\x_t,t)\) 去近似 \(\x_0\),在下一节中我们将优化目标显式写出后可以看得更清楚。
至于方差,DDPM 给出了两个选择 \(\tilde\beta_t\) 与 \(\beta_t\). 前者不难理解,就是沿用了 \(q(\x_{t-1}\vert\x_t,\x_0)\) 的方差,但是后者是出自什么考虑呢?这个问题我们暂时放一放,在之后的文章中详细说明。
损失函数
至此,我们已经确定下 \(\eqref{obj}\) 式中出现的所有概率分布的形式,因而可以代入计算了。为避免读者上下翻阅,把所有公式总结于此: \[ \begin{align} &\text{ELBO}= \underbrace{\E_{\x_1\sim q(\x_1\vert\x_0)}[\log p_\theta(\x_0\vert\x_1)]}_{\text{reconstruction}}- \underbrace{\KL(q(\x_T\vert \x_0)\|p(\x_T))}_{\text{regularization}}- \sum_{t=2}^T\underbrace{\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\KL(q(\x_{t-1}\vert\x_t,\x_0)\|p_\theta(\x_{t-1}\vert\x_t))\right]}_{\text{matching}}&&\eqref{obj}\\ &q(\x_t\vert \x_0)=\calN\left(\x_t;\sqrt{\bar\alpha_t}\x_0,(1-\bar\alpha_t)\mathbf{I}\right)&&\eqref{qxtx0}\\ &q(\x_{t-1}\vert \x_t,\x_0)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_0),\tilde\beta_t\mathbf I\right)&&\eqref{q-reverse}\\ &p_\theta(\x_{t-1}\vert\x_t)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_\theta(\x_t,t)),\sigma_t^2\mathbf{I}\right)&&\eqref{p-reverse} \end{align} \] 首先看正则项,由于我们设计 \(q\) 时要求 \(\lim\limits_{T\to\infty}q(\x_T\vert\x_0)=\calN(\x_\infty;\mathbf0,\mathbf I)\),在 \(T\) 较大时趋近于标准正态分布,并且 \(p(\x_T)\) 也设置为标准正态分布,所以正则项可以忽略。
然后看重构项,代入表达式得: \[ \log p_\theta(\x_0\vert\x_1)=\text{constant}-\frac{1}{2\sigma_1^2}\Vert\x_0-\mu_t(\x_1,\x_\theta(\x_1,1))\Vert^2 \] 最后看匹配项,根据两个正态分布的 KL 散度计算公式,当取 \(\sigma_t^2=\tilde\beta_t\) 时,有: \[ \mathrm{KL}(q(\x_{t-1}\vert \x_t,\x_0) \Vert p_\theta(\x_{t-1}\vert\x_t))=\frac{1}{2\sigma_t^2}\Vert\mu_t(\x_t,\x_0)-\mu_t(\x_t,\x_\theta(\x_t,t))\Vert^2 \] 鉴于 \(\x_0\) 可以写作 \(\mu_1(\x_1,\x_0)\),因此重构项与匹配项的格式可以统一起来。综上,总的损失函数为: \[ \begin{align} \mathcal L_{\x_0}(\theta)&=\sum_{t=1}^T\frac{1}{2\sigma_t^2}\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\Vert\mu_t(\x_t,\x_0)-\mu_t(\x_t,\x_\theta(\x_t,t))\Vert^2\right]\\ &=\sum_{t=1}^T\frac{1}{2\sigma_t^2}\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\left\Vert\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_0-\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\x_\theta(\x_t,t)\right\Vert^2\right]\\ &=\sum_{t=1}^T\frac{\bar\alpha_{t-1}\beta_t^2}{2\sigma_t^2(1-\bar\alpha_t)^2}\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\Vert\x_0-\x_\theta(\x_t,t)\Vert^2\right]\tag{11}\label{loss-x} \end{align} \] 也就是 L2 损失或 MSE 损失。这验证了前文我们提到的,对于一个给定的 \(\x_0\),模型 \(\x_\theta(\x_t,t)\) 的作用是去近似 \(\x_0\) 的说法。
进一步地,DDPM 对 \(\x_\theta(\x_t,t)\) 做了重参数化。根据 \(\eqref{xtx0}\) 式有: \[ \x_0=\frac{1}{\sqrt{\bar\alpha_t}}\left(\x_t-\sqrt{1-\bar\alpha_t}\epsilon\right),\quad\epsilon\sim\calN(\mathbf0,\mathbf I) \] 因此可以把 \(\x_\theta(\x_t,t)\) 写作相同的形式: \[ \x_\theta(\x_t,t)=\frac{1}{\sqrt{\bar\alpha_t}}\left(\x_t-\sqrt{1-\bar\alpha_t}\epsilon_\theta(\x_t,t)\right) \] 代入 \(\eqref{loss-x}\) 式得: \[ \mathcal L_{\x_0}(\theta)=\sum_{t=1}^T\frac{\beta_t^2}{2\sigma_t^2{\alpha_t}(1-\bar\alpha_t)}\E_{\epsilon\sim\calN(\mathbf0,\mathbf I)}\left[\Vert{\epsilon}-\epsilon_\theta(\x_t,t)\Vert^2\right] \] 其中 \(\x_t\) 由 \(\x_0\) 与采样出的 \(\epsilon\) 根据 \(\eqref{xtx0}\) 式计算,为简便起见没有显式地代入上式。这里的 \(\epsilon_\theta(\x_t,t)\) 就是所谓的“噪声预测模型”,用于近似当前采样出的噪声 \(\epsilon\). 通过实验探索,DDPM 作者发现将模型参数化为 \(\epsilon_\theta(\x_t,t)\) 的效果比 \(\x_\theta(\x_t,t)\) 更好,并且把前面的系数丢掉效果更好。另外,对 \(t\) 求和可以改作对 \(t\) 均匀采样,因此损失函数简化为: \[ \mathcal L_{\x_0,\text{simple}}(\theta)=\E_{t,\epsilon}\left[\Vert{\epsilon}-\epsilon_\theta(\x_t,t)\Vert^2\right] \] 最后,注意本文至此的所有推导都建立在给定一个 \(\x_0\) 的基础上,实际训练时 \(\x_0\) 是从训练集中采样的,因此最终的损失函数为: \[ \mathcal L_\text{simple}(\theta)=\E_{t,\x_0,\epsilon}\left[\Vert\epsilon-\epsilon_\theta(\x_t,t)\Vert^2\right] \] 相应算法流程如下:
可见 DDPM 虽然推导有些复杂,但最后得到的算法流程却异常简单,效果也很好,难怪迅速成为了研究的热点。
一些注解
直观上 DDPM 干的事情可以总结为——前向过程对输入图像一步步加噪,使之变成高斯噪声;逆向过程使用模型来预测原图(或预测添加的噪声),进而把带噪图像一步步转换回真实图像。这里容易产生一个误解:既然每一步 \(\x_\theta(\x_t,t)\) 都是去近似 \(\x_0\),那么岂不是直接一步生成 \(\x_\theta(\x_T,T)\approx \x_0\) 就可以了?并不是这样的。注意 \(\mathbf x_0\) 是不断从数据集中采样出来的,由于不同的 \(\mathbf x_0\) 都有可能得到相同的 \(\mathbf x_t\),所以随着训练的进行,\(\mathbf x_\theta(\mathbf x_t,t)\) 拟合的是这些 \(\mathbf x_0\) 的平均值,是一个模糊的图像。其实通过简单的推导就可以知道:
\[ \min_\theta\mathbb E_{\x_0\sim q(\x_0),\x_t\sim q(\x_t\vert\x_0)}\left[\Vert\mathbf x_\theta(\mathbf x_t,t)-\mathbf x_0\Vert_2^2\right]\implies \mathbf x_{\theta^\ast}(\mathbf x_t,t)=\mathbb E_{q(\mathbf x_0\vert\mathbf x_t)}[\mathbf x_0]=\mathbb E[\mathbf x_0\vert\mathbf x_t],\quad\forall \x_t\sim q(\x_t) \] 因此模型 \(\mathbf x_\theta(\mathbf x_t,t)\) 拟合的真值是 \(\mathbb E[\mathbf x_0\vert\mathbf x_t]\),即 \(\mathbf x_0\) 关于 \(q(\mathbf x_0\vert\mathbf x_t)\) 的加权平均。同理,\(\epsilon_\theta(\x_t,t)\) 拟合的真值是 \(\mathbb E[\epsilon\vert\x_t]\).
那么,纵观整个生成过程,我们可以把 \(\x_\theta(\x_t,t)\approx\E[\x_0\vert\x_t]\) 理解为大方向。每次我们朝着大方向走一小步,然后重新看看大方向在哪里,再走下一小步。打个比方,我想从成都走到深圳,我知道大致要朝东南 45° 方向走,但是“差之毫厘,谬以千里”,直接走可能一不小心就登陆台湾了,所以我先走一小步到重庆;然后再看地图,大方向变成了东南 50°,于是又走一小步,但是拐弯过猛到了贵阳;没关系,再看地图,大方向变成了东南 30°……这样每走一小步都对大方向做一点修正,最后就能平稳地到达目的地了。
代码实现
Github repo: https://github.com/xyfJASON/Diffusion-Models-Implementations
结果展示
更多内容请查看代码仓库。
关于 clipping
在官方代码[16]和若干其他实现中,我发现大家普遍喜欢使用 clipping,即对于逆向过程的每一步,在预测 \(\epsilon_\theta(\x_t,t)\) 之后,先算 \(\x_\theta(\x_t,t)\),然后 clip 到 \([-1,1]\) 之间,再算 \(\mu_t(\x_t,\x_\theta(\x_t,t))\). 这样做为什么合理呢?这是因为 clipping 本质上是对模型误差的人工修正——\(\x_\theta(\x_t,t)\) 是用来估计 \(\x_0\) 的,本就应该在 \([-1,1]\) 之间,只是出于模型误差而跳脱了这个范围,所以强行把它 clip 回来并不违背理论;另外,clipping 只影响逆向过程,并不需要重新训练模型。
色调偏移问题
早期的实现版本在 MNIST 上 work 得很好,但是在 CelebA-HQ 上训练时出现了色调偏移(color shifting)问题。具体而言,我发现各个 epoch 之间的图片色调会发生明显偏移,比如前一个 epoch 图片都偏红,后一个 epoch 图片都偏蓝,有时候甚至亮/暗得根本看不清人脸,如下图所示:
本以为是模型还没收敛,但是 300 多个 epochs 之后仍然是这样,这就不得不重视起来。一番排查后,发现是我偷懒没有实现 EMA 导致的,特别是原作者把 decay rate 设置为 0.9999,意味着参数更新其实是很慢的。EMA 的本质是对历史权重做了加权平均,可以看作若干历史模型的集成。从这个角度来说,那些色调发生不同偏移的模型互相“抵消”,从而缓解了色调偏移问题。(注意只是缓解,并没有消除!)
后来我读到其实宋飏在论文[3]里面就提到了这一现象,这也是他引入 EMA 的原因。说到底,色调偏移就是模型还没有收敛到真实分布的一个表现,只不过视觉上给人的冲击比较强烈罢了。
[update 2022.11.27] 虽然 EMA 的 decay rate 设置为 0.9999,但 tensorflow 的官方实现其实是这样的: \[ \text{decay}=\min\left(\text{decay}_\max,\frac{1+\text{num}\_\text{updates}}{10+\text{num}\_\text{updates}}\right) \] 随着 num_updates 增加,对应的 decay 序列是 \(0.1818,0.2500,0.3077,0.3571,0.4000,\ldots\),一直到 90000 步左右 decay 才会固定在 0.9999. 这样做能减小初始化的随机权重对整体权重的影响,模型见效更快。
References
- Ho, Jonathan, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems 33 (2020): 6840-6851. ↩︎
- Sohl-Dickstein, Jascha, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In International Conference on Machine Learning, pp. 2256-2265. PMLR, 2015. ↩︎
- Song, Yang, and Stefano Ermon. Improved techniques for training score-based generative models. Advances in neural information processing systems 33 (2020): 12438-12448. ↩︎
- Luo, Calvin. Understanding diffusion models: A unified perspective. arXiv preprint arXiv:2208.11970 (2022). ↩︎
- Lilian Weng. What are Diffusion Models?. https://lilianweng.github.io/posts/2021-07-11-diffusion-models ↩︎
- Angus Turner. Diffusion Models as a kind of VAE. https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html ↩︎
- Denoising Diffusion-based Generative Modeling: Foundations and Applications. https://cvpr2022-tutorial-diffusion-models.github.io ↩︎
- 苏剑林. (Jul. 06, 2022). 《生成扩散模型漫谈(二):DDPM = 自回归式VAE 》[Blog post]. Retrieved from https://kexue.fm/archives/9152 ↩︎
- 苏剑林. (Jul. 19, 2022). 《生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪 》[Blog post]. Retrieved from https://kexue.fm/archives/9164 ↩︎
- 由浅入深了解Diffusion Model - ewrfcas的文章 - 知乎 https://zhuanlan.zhihu.com/p/525106459 ↩︎
- 扩散模型之DDPM - 小小将的文章 - 知乎 https://zhuanlan.zhihu.com/p/563661713 ↩︎
- Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读. https://www.bilibili.com/video/BV1b541197HX ↩︎
- Diffusion Model:比“GAN”还要牛逼的图像生成模型!https://www.bilibili.com/video/BV1pD4y1179T ↩︎
- 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - Nicolas的文章 - 知乎 https://zhuanlan.zhihu.com/p/68748778 ↩︎
- https://huggingface.co/blog/annotated-diffusion ↩︎
- https://github.com/lucidrains/denoising-diffusion-pytorch ↩︎
- https://github.com/hojonathanho/diffusion ↩︎
- https://github.com/openai/improved-diffusion ↩︎
- https://github.com/lucidrains/imagen-pytorch ↩︎
- https://github.com/tqch/ddpm-torch ↩︎
- https://github.com/abarankab/DDPM ↩︎
- https://github.com/w86763777/pytorch-ddpm ↩︎