Choosing the Reverse Variance in Diffusion Models

\[ \newcommand{\x}{\mathbf x} \newcommand{\calN}{\mathcal N} \newcommand{\E}{\mathbb E} \]

DDPM: Manually Chosen Variance

Let us begin with a brief review. DDPM defines the forward Markov process as:

\[ \begin{align} &q(\x_{1:T}\vert\x_0)=\prod_{t=1}^T q(\x_t\vert\x_{t-1})\\ &q(\x_t\vert\x_{t-1})=\mathcal N\left(\x_t;\sqrt{1-\beta_t}\x_{t-1},\beta_t\mathbf I\right)\\ &q(\x_t\vert\x_0)=\mathcal N\left(\x_t;\sqrt{\bar\alpha_t}\x_0,(1-\bar\alpha_t)\mathbf I\right)\\ \end{align} \] A direct computation gives: \[ q(\x_{t-1}\vert\x_t,\x_0)=\mathcal N\left(\x_{t-1};\mu_t(\x_t,\x_0),\tilde\beta_t\mathbf I\right) \] Mirroring the form of \(q(\x_{t-1}\vert\x_t,\x_0)\), we define \(p_\theta(\x_{t-1}\vert\x_t)\) as: \[ p_\theta(\x_{t-1}\vert\x_t)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_\theta(\x_t,t)),\sigma_t^2\mathbf I\right) \] The model is trained by maximizing the ELBO: \[ \text{ELBO}=\E_{\x_1\sim q(\x_1\vert\x_0)}[\log p_\theta(\x_0\vert\x_1)]- \text{KL}(q(\x_T\vert \x_0)\|p(\x_T))- \sum_{t=2}^T\E_{\x_t\sim q(\x_t\vert\x_0)}\left[\text{KL}(q(\x_{t-1}\vert\x_t,\x_0)\|p_\theta(\x_{t-1}\vert\x_t))\right]\tag{1}\label{elbo} \] Note that the two distributions inside \(\text{KL}(q(\x_{t-1}\vert\x_t,\x_0)\Vert p_\theta(\x_{t-1}\vert\x_t))\) are both Gaussian, the former with variance \(\tilde\beta_t\) and the latter with variance \(\sigma_t^2\) by definition. Since the KL divergence between two Gaussians with the same variance is proportional to the squared L2 distance between their means, taking \(\sigma_t^2=\tilde\beta_t\) turns the loss into a very simple L2 loss. At the same time, however, the DDPM paper says that \(\sigma_t^2=\beta_t\) works too, and this is the question we left open earlier.
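To make the schedule concrete, here is a small numerical sketch. It assumes the linear schedule \(\beta_1=10^{-4},\beta_T=0.02,T=1000\) from the DDPM paper and the standard identity \(\tilde\beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t\); it shows that the two candidate variances nearly coincide for large \(t\):

```python
import numpy as np

# Linear beta schedule (assumed hyperparameters from the DDPM paper).
T = 1000
betas = np.linspace(1e-4, 0.02, T)        # beta_1, ..., beta_T
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)           # \bar\alpha_t

# Posterior variance \tilde\beta_t = (1 - \bar\alpha_{t-1}) / (1 - \bar\alpha_t) * beta_t,
# with the convention \bar\alpha_0 = 1 (hence \tilde\beta_1 = 0).
alpha_bars_prev = np.append(1.0, alpha_bars[:-1])
beta_tildes = (1.0 - alpha_bars_prev) / (1.0 - alpha_bars) * betas

# \tilde\beta_t <= \beta_t everywhere, and the two nearly coincide for large t.
print(beta_tildes[-1] / betas[-1])
```

The ratio approaches 1 as \(t\to T\), which is one intuition for why both choices behave reasonably in practice.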

To understand this, note that Eq. \(\eqref{elbo}\) is derived for a single sample \(\x_0\); the true optimization objective wraps another expectation over \(\x_0\) around it: \[ \begin{align} \mathbb E_{\x_0\sim q(\x_0)}[\text{ELBO}]&=\mathbb E_{\x_0\sim q(\x_0)}\left[\mathbb E_{q(\x_{1:T}\vert\x_0)}\left[\log\frac{p_\theta(\x_{0:T})}{q(\x_{1:T}\vert\x_0)}\right]\right]\\ &=\mathbb E_{\x_{0:T}\sim q(\x_{0:T})}\left[\log\frac{p_\theta(\x_{0:T})}{q(\x_{0:T})}+\log q(\x_0)\right]\\ &=-\text{KL}(q(\x_{0:T})\Vert p_\theta(\x_{0:T}))-H(q(\x_0))\\ &=-\text{KL}\left(q(\x_T)\prod_{t=1}^Tq(\x_{t-1}\vert\x_t)\Bigg\Vert p_\theta(\x_T)\prod_{t=1}^Tp_\theta(\x_{t-1}\vert\x_t)\right)-H(q(\x_0))\\ &=-\text{KL}(q(\x_T)\Vert p_\theta(\x_T))-\sum_{t=1}^T\text{KL}\left(q(\x_{t-1}\vert\x_t)\Vert p_\theta(\x_{t-1}\vert\x_t)\right)-H(q(\x_0)) \end{align}\tag{2}\label{elbo-e} \] where \(q(\x_0)\) denotes the data distribution. The second-to-last step can factor \(q\) this way because the reversal of a Markov process is itself Markov: \[ q(\x_{t-1}\vert\x_{t:T})=\frac{q(\x_{t-1:T})}{q(\x_{t:T})}=\frac{q(\x_{t-1})\prod_{i=t}^Tq(\x_i\vert \x_{i-1})}{q(\x_t)\prod_{i=t+1}^Tq(\x_i\vert \x_{i-1})}=\frac{q(\x_{t-1})q(\x_t\vert\x_{t-1})}{q(\x_t)}=q(\x_{t-1}\vert\x_t) \] From Eq. \(\eqref{elbo-e}\) we see that the target \(p_\theta(\x_{t-1}\vert\x_t)\) must approximate is not \(q(\x_{t-1}\vert\x_t,\x_0)\) but \(q(\x_{t-1}\vert\x_t)\), and the two are related by: \[ q(\x_{t-1}\vert\x_t)=\int q(\x_{t-1}\vert\x_t,\x_0)q(\x_0\vert\x_t)\mathrm d\x_0=\mathbb E_{q(\x_0\vert\x_t)}[q(\x_{t-1}\vert\x_t,\x_0)] \] Hence \(q(\x_{t-1}\vert\x_t)=q(\x_{t-1}\vert\x_t,\x_0)\) only when \(q(\mathbf x_0\vert\mathbf x_t)\) is a Dirac delta, and only then is \(\sigma_t^2=\tilde\beta_t\) optimal; otherwise, setting \(\sigma_t^2=\tilde\beta_t\) directly is not optimal. However, \(q(\x_0\vert\x_t)\) itself is intractable, so to obtain something computable the DDPM authors considered two extreme cases:

  • Case 1: \(\x_0\sim\mathcal N(\mathbf 0,\mathbf I)\). Then \(q(\x_t)=\calN(\x_t;\mathbf 0,\mathbf{I})\) for every \(t\), and Bayes' rule gives \(q(\x_{t-1}\vert \x_t)=\calN(\x_{t-1};\sqrt{1-\beta_t}\x_t,\beta_t\mathbf{I})\), so one should take \(\sigma_t^2=\beta_t\).
  • Case 2: \(\x_0\sim \delta(\mathbf x_0)\), i.e., the dataset consists of a single sample \(\x_0=\mathbf 0\). Then \(q(\mathbf x_0\vert\mathbf x_t)\) is a Dirac delta, so one should take \(\sigma_t^2=\tilde\beta_t\).
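The first case can also be checked numerically. Below is a minimal 1-D Monte Carlo sketch for a single step with a hypothetical \(\beta_t=0.1\): when the data are standard Gaussian, the true reverse conditional variance comes out as \(\beta_t\), not \(\tilde\beta_t\):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2_000_000
beta = 0.1                    # hypothetical one-step noise level
alpha = 1.0 - beta

# Case 1: x_{t-1} ~ N(0, 1), forward kernel x_t = sqrt(alpha) x_{t-1} + sqrt(beta) eps.
# Then Var(x_t) = alpha + beta = 1, and Gaussian conditioning gives
# Var(x_{t-1} | x_t) = 1 - alpha = beta.
x_prev = rng.standard_normal(n)
x_t = np.sqrt(alpha) * x_prev + np.sqrt(beta) * rng.standard_normal(n)

# Estimate Var(x_{t-1} | x_t ~ 0) by conditioning on a thin slab around 0.
mask = np.abs(x_t) < 0.01
var_rev = x_prev[mask].var()
print(var_rev)                # close to beta = 0.1
```

Case 2 needs no simulation: with a Dirac \(q(\x_0\vert\x_t)\), the reverse conditional is exactly \(q(\x_{t-1}\vert\x_t,\x_0)\) with variance \(\tilde\beta_t\).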

Of course, neither setting is realistic: no one's data is pure Gaussian noise or a single sample. They are merely exploratory assumptions for probing what the variance could be, and they plant the seed for later work: what form does the optimal variance actually take?

iDDPM: Learnable Variance

To address this issue, Improved DDPM[1] treats the variance as a learnable quantity to be optimized rather than chosen by hand. Specifically, since \(\tilde\beta_t\) and \(\beta_t\) are the two extremes, the authors parameterize the variance as their interpolation in the log domain: \[ \sigma^2_\theta(\x_t,t)=\exp(v\log\beta_t+(1-v)\log\tilde\beta_t) \] where \(v\) is predicted by the network. The loss function is: \[ \mathcal L_\text{vlb}=-\E_{\x_0\sim q(\x_0)}[\text{ELBO}]=\sum_{t=1}^T\E_{\x_0\sim q(\x_0),\x_t\sim q(\x_t\vert\x_0)}\left[\text{KL}(q(\x_{t-1}\vert\x_t,\x_0)\|p_\theta(\x_{t-1}\vert\x_t))\right] \] Recall that DDPM simplified this objective to: \[ \mathcal L_\text{simple}\propto \E_{t,\x_0,\x_t}[\Vert\mu_t(\x_t,\x_0)-\mu_t(\x_t,\x_\theta(\x_t,t))\Vert^2] \] Although the objective is no longer \(\mathcal L_\text{simple}\) once the variance becomes learnable, Improved DDPM still uses \(\mathcal L_\text{simple}\) to train the mean, while \(\mathcal L_\text{vlb}\), with its gradient on the mean stopped, trains only the variance. Altogether, Improved DDPM adopts the hybrid loss: \[ \mathcal L_\text{hybrid}=\mathcal L_\text{simple}+\lambda\mathcal L_\text{vlb} \] Experiments confirm that training with \(\mathcal L_\text{vlb}\) alone yields poorer image quality, a reminder that theory and practice can diverge and that the best configuration still has to be found empirically.
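A scalar sketch of the log-domain interpolation, with hypothetical per-step values for \(\beta_t\) and \(\tilde\beta_t\); \(v=0\) and \(v=1\) recover the two endpoints:

```python
import numpy as np

beta, beta_tilde = 0.02, 0.015     # hypothetical per-step variance extremes

def sigma2(v):
    """Log-domain interpolation: exp(v log beta + (1 - v) log beta_tilde)."""
    return np.exp(v * np.log(beta) + (1.0 - v) * np.log(beta_tilde))

# v = 1 and v = 0 recover the two endpoints; intermediate v stays in between,
# which keeps the learned variance inside the [beta_tilde, beta] bracket.
print(sigma2(0.0), sigma2(0.5), sigma2(1.0))
```

Keeping the output bracketed between the two extremes is exactly why the interpolation is done in the log domain rather than predicting \(\sigma^2\) freely.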

Appendix: KL Divergence Between Two Gaussians

Let \(x,y\) be two \(d\)-dimensional random vectors with means \(\mu_x,\mu_y\in\mathbb R^d\) and covariance matrices \(\Sigma_x,\Sigma_y\in\mathbb R^{d\times d}\); then: \[\mathrm{KL}\big(\calN(\mu_x,\Sigma_x)\ \Vert\ \calN(\mu_y,\Sigma_y)\big)=\frac{1}{2}\left[\log\frac{|\Sigma_y|}{|\Sigma_x|}-d+\mathrm{tr}\left(\Sigma_y^{-1}\Sigma_x\right)+(\mu_y-\mu_x)^\mathsf T\Sigma_y^{-1}(\mu_y-\mu_x) \right]\]
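A direct NumPy implementation of this formula (the function name is mine, not from any library), with two sanity checks: the KL between identical Gaussians is zero, and in 1-D \(\mathrm{KL}(\calN(0,1)\Vert\calN(1,1))=1/2\):

```python
import numpy as np

def gaussian_kl(mu_x, cov_x, mu_y, cov_y):
    """KL( N(mu_x, cov_x) || N(mu_y, cov_y) ) for full covariance matrices."""
    d = mu_x.shape[0]
    cov_y_inv = np.linalg.inv(cov_y)
    diff = mu_y - mu_x
    _, logdet_x = np.linalg.slogdet(cov_x)
    _, logdet_y = np.linalg.slogdet(cov_y)
    return 0.5 * (logdet_y - logdet_x - d
                  + np.trace(cov_y_inv @ cov_x)
                  + diff @ cov_y_inv @ diff)

# Sanity checks: zero between identical Gaussians; 1-D KL(N(0,1) || N(1,1)) = 1/2.
print(gaussian_kl(np.zeros(3), np.eye(3), np.zeros(3), np.eye(3)))
print(gaussian_kl(np.zeros(1), np.eye(1), np.ones(1), np.eye(1)))
```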

Analytic-DPM: Analytic Optimal Variance

Improved DDPM's learnable variance alleviates the arbitrariness of DDPM's variance choice, but it also makes training harder. Remarkably, the concurrent work Analytic-DPM[2] showed that \(\sigma_t^2\) in fact has a closed-form optimal solution, so there is no need to train it at all!

Before we begin, recall that the DDIM paper generalizes DDPM's forward process to a non-Markovian one: \[ \begin{align} &q(\x_{1:T}\vert\x_0)=q(\x_T\vert\x_0)\prod_{t=2}^T q(\x_{t-1}\vert\x_t,\x_0)\\ &q(\x_t\vert\x_0)=\calN\left(\x_t;\sqrt{\bar\alpha_t}\x_0,(1-\bar\alpha_t)\mathbf I\right)\\ &q(\x_{t-1}\vert\x_t,\x_0)=\calN\left(\x_{t-1};\mu_t(\x_t,\x_0),\lambda_t^2\mathbf I \right)\\ &\mu_t(\x_t,\x_0)=\sqrt{\bar\alpha_{t-1}}\x_0+\sqrt{1-\bar\alpha_{t-1}-\lambda_t^2}\cdot\frac{\x_t-\sqrt{\bar\alpha_t}\x_0}{\sqrt{1-\bar\alpha_t}} \end{align} \]

In particular, taking \(\lambda_t^2=\tilde\beta_t\) recovers the original DDPM, so everything below is stated in terms of this more general diffusion process.
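As a sanity check of this claim, the following sketch (with hypothetical schedule values \(\bar\alpha_t=0.5,\ \bar\alpha_{t-1}=0.6\)) verifies that plugging \(\lambda_t^2=\tilde\beta_t\) into the generalized posterior mean reproduces the DDPM posterior-mean coefficients:

```python
import numpy as np

# One step of a hypothetical schedule.
alpha_bar_t, alpha_bar_prev = 0.5, 0.6
alpha_t = alpha_bar_t / alpha_bar_prev
beta_t = 1.0 - alpha_t

def mean_coeffs(lam2):
    """Coefficients A (of x_0) and B (of x_t) in mu_t(x_t, x_0)."""
    B = np.sqrt(1.0 - alpha_bar_prev - lam2) / np.sqrt(1.0 - alpha_bar_t)
    A = np.sqrt(alpha_bar_prev) - np.sqrt(alpha_bar_t) * B
    return A, B

# Plugging lambda_t^2 = beta_tilde_t should reproduce the DDPM posterior mean:
#   A = sqrt(alpha_bar_prev) beta_t / (1 - alpha_bar_t),
#   B = sqrt(alpha_t) (1 - alpha_bar_prev) / (1 - alpha_bar_t).
beta_tilde = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t
A, B = mean_coeffs(beta_tilde)
A_ddpm = np.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
B_ddpm = np.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
print(A - A_ddpm, B - B_ddpm)      # both ~ 0
```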

As shown earlier, the true optimization objective of the problem is: \[ \min_{\{\mu_t,\sigma_t^2\}_{t=1}^T}\mathcal L_\text{vlb}\iff\min_{\{\mu_t,\sigma_t^2\}_{t=1}^T}\text{KL}(q(\x_{0:T})\Vert p_\theta(\x_{0:T})) \] Since: \[ \begin{align} \text{KL}(q(\x_{0:T})\Vert p_\theta(\x_{0:T}))&=\mathbb E_{q(\x_{0:T})}\left[\log\frac{q(\x_{0:T})}{p_\theta(\x_{0:T})}\right]=-\mathbb E_{q(\x_{0:T})}[\log p_\theta(\x_{0:T})]-H(q(\x_{0:T}))\\ &=-\sum_{t=1}^T\mathbb E_{q(\x_{0:T})}[\log p_\theta(\x_{t-1}\vert\x_t)]-\mathbb E_{q(\x_{0:T})}[\log p(\x_T)]-H(q(\x_{0:T}))\\ &=\sum_{t=1}^T\mathbb E_{q(\x_{0:T})}\left[\log\frac{q(\x_{t-1}\vert\x_t)}{p_\theta(\x_{t-1}\vert\x_t)}\right]-\sum_{t=1}^T\mathbb E_{q(\x_{0:T})}[\log q(\x_{t-1}\vert\x_t)]-\mathbb E_{q(\x_{0:T})}[\log p(\x_T)]-H(q(\x_{0:T}))\\ &=\sum_{t=1}^T\text{KL}(q(\x_{t-1}\vert\x_t)\Vert p_\theta(\x_{t-1}\vert\x_t))+\sum_{t=1}^T H(q(\x_{t-1}\vert\x_t))-H(q(\x_{0:T}))-\mathbb E_{q(\x_{0:T})}[\log p(\x_T)] \end{align} \] and the last three terms do not depend on \(\theta\), we have: \[ \min_{\{\mu_t,\sigma_t^2\}_{t=1}^T}\text{KL}(q(\x_{0:T})\Vert p_\theta(\x_{0:T}))\iff \min_{\mu_t,\sigma_t^2}\text{KL}(q(\x_{t-1}\vert\x_t)\Vert p_\theta(\x_{t-1}\vert\x_t)) \] By moment matching (fitting an arbitrary distribution with a Gaussian under the KL divergence is equivalent to matching their means and variances), it suffices to find the mean and variance of \(q(\x_{t-1}\vert\x_t)\).
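The moment-matching fact can be demonstrated numerically: for a hypothetical 1-D Gaussian mixture \(q\), a brute-force search over \((\mu,\sigma^2)\) minimizing \(\text{KL}(q\Vert\calN(\mu,\sigma^2))\) lands on the mean and variance of \(q\):

```python
import numpy as np

# q: a hypothetical 1-D two-component Gaussian mixture on a dense grid.
xs = np.linspace(-10.0, 10.0, 4001)
dx = xs[1] - xs[0]

def normal_pdf(x, mu, var):
    return np.exp(-(x - mu) ** 2 / (2.0 * var)) / np.sqrt(2.0 * np.pi * var)

q = 0.3 * normal_pdf(xs, -2.0, 0.5) + 0.7 * normal_pdf(xs, 1.5, 1.0)
mean_q = np.sum(xs * q) * dx
var_q = np.sum((xs - mean_q) ** 2 * q) * dx

def kl_q_p(mu, var):
    # KL(q || N(mu, var)) by numerical integration on the grid.
    p = normal_pdf(xs, mu, var)
    return np.sum(q * (np.log(q + 1e-300) - np.log(p))) * dx

# Brute-force search over (mu, var): the minimizer matches (mean_q, var_q).
mus = np.linspace(-1.0, 2.0, 61)
vars_ = np.linspace(1.0, 6.0, 101)
kls = np.array([[kl_q_p(m, v) for v in vars_] for m in mus])
i, j = np.unravel_index(np.argmin(kls), kls.shape)
print(mus[i], vars_[j], mean_q, var_q)
```

Note the direction matters: it is \(\text{KL}(q\Vert p)\) with the Gaussian in the second slot, exactly the direction that appears in the objective above, that yields moment matching.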

First, the mean: \[ \begin{align} \mathbb E_{q(\x_{t-1}\vert\x_t)}[\x_{t-1}]&=\int \x_{t-1}q(\x_{t-1}\vert\x_t)\mathrm d\x_{t-1}\\ &=\int\x_{t-1}\left[\int q(\x_{t-1}\vert\x_t,\x_0)q(\x_0\vert\x_t)\mathrm d\x_0\right]\mathrm d\x_{t-1}\\ &=\int q(\x_0\vert\x_t)\left[\int\x_{t-1}q(\x_{t-1}\vert\x_t,\x_0)\mathrm d\x_{t-1}\right]\mathrm d\x_0\\ &=\int q(\x_0\vert\x_t)\mu_t(\x_t,\x_0)\mathrm d\x_0\\ &=\int q(\x_0\vert\x_t)(A\x_0+B\x_t)\mathrm d\x_0\\ &=A\int\x_0 q(\x_0\vert\x_t)\mathrm d\x_0+B\int\x_tq(\x_0\vert\x_t)\mathrm d\x_0\\ &=A\mathbb E[\x_0\vert\x_t]+B\x_t\\ &=\mu_t(\x_t,\mathbb E[\x_0\vert\x_t]) \end{align} \] where \(A\) and \(B\) are the coefficients of \(\x_0\) and \(\x_t\) in \(\mu_t(\x_t,\x_0)\), namely: \[ A=\sqrt{\bar\alpha_{t-1}}-\sqrt{\bar\alpha_t}\cdot\frac{\sqrt{1-\bar\alpha_{t-1}-\lambda_t^2}}{\sqrt{1-\bar\alpha_t}},\quad B=\frac{\sqrt{1-\bar\alpha_{t-1}-\lambda_t^2}}{\sqrt{1-\bar\alpha_t}} \] Note that \(\x_\theta(\x_t,t)\) in DDPM is precisely an approximation of \(\mathbb E[\x_0\vert\x_t]\), so the mean used by DDPM is indeed the optimal mean estimate.
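Here is a Monte Carlo check of this result in a hypothetical 1-D setting where \(\x_0=\pm1\) with equal probability, so \(\E[x_0\vert x_t]=\tanh\big(\sqrt{\bar\alpha_t}\,x_t/(1-\bar\alpha_t)\big)\) is available in closed form (all schedule values below are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical 1-D data distribution: x_0 = +1 or -1 with probability 1/2 each.
alpha_bar_t, alpha_bar_prev = 0.5, 0.6
alpha_t = alpha_bar_t / alpha_bar_prev
beta_t = 1.0 - alpha_t
lam2 = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t   # DDPM choice

B = np.sqrt(1.0 - alpha_bar_prev - lam2) / np.sqrt(1.0 - alpha_bar_t)
A = np.sqrt(alpha_bar_prev) - np.sqrt(alpha_bar_t) * B

# Closed form for E[x_0 | x_t] with two-point data.
def e_x0_given_xt(xt):
    return np.tanh(np.sqrt(alpha_bar_t) * xt / (1.0 - alpha_bar_t))

# Monte Carlo over the Markov forward process (valid since lam2 = beta_tilde),
# conditioning on a thin slab x_t near c.
n = 4_000_000
x0 = rng.choice([-1.0, 1.0], size=n)
x_prev = np.sqrt(alpha_bar_prev) * x0 + np.sqrt(1.0 - alpha_bar_prev) * rng.standard_normal(n)
x_t = np.sqrt(alpha_t) * x_prev + np.sqrt(beta_t) * rng.standard_normal(n)
c = 0.7
mask = np.abs(x_t - c) < 0.01
mc_mean = x_prev[mask].mean()
pred_mean = A * e_x0_given_xt(c) + B * c
print(mc_mean, pred_mean)      # agree up to Monte Carlo error
```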

Next, the second moment: \[ \begin{align} \mathbb E_{q(\mathbf x_{t-1}\vert\x_t)}\left[\x_{t-1}\x_{t-1}^T\right] &=\int\x_{t-1}\x_{t-1}^Tq(\x_{t-1}\vert\x_t)\mathrm d\x_{t-1}\\ &=\int\x_{t-1}\x_{t-1}^T\left[\int q(\x_{t-1}\vert\x_t,\x_0)q(\x_0\vert\x_t)\mathrm d\x_0\right]\mathrm d\x_{t-1}\\ &=\int q(\x_0\vert\x_t)\left[\int\x_{t-1}\x_{t-1}^Tq(\x_{t-1}\vert\x_t,\x_0)\mathrm d\x_{t-1}\right]\mathrm d\x_0\\ &=\int q(\x_0\vert\x_t)\Big[\mu_t(\x_t,\x_0)\mu_t(\x_t,\x_0)^T+\lambda_t^2\mathbf I\Big]\mathrm d\x_0\\ &=\int q(\x_0\vert\x_t)\Big[\mu_t(\x_t,\x_0)\mu_t(\x_t,\x_0)^T\Big]\mathrm d\x_0+\lambda_t^2\mathbf I\\ &=\mathbb E_{q(\x_0\vert\x_t)}\Big[\mu_t(\x_t,\x_0)\mu_t(\x_t,\x_0)^T\Big]+\lambda_t^2\mathbf I \end{align} \] so the covariance matrix is: \[ \text{Cov}_{q(\x_{t-1}\vert\x_t)}(\x_{t-1})=\lambda_t^2\mathbf I+\mathbb E_{q(\x_0\vert\x_t)}\Big[\mu_t(\x_t,\x_0)\mu_t(\x_t,\x_0)^T\Big]-\mu_t(\x_t,\mathbb E[\x_0\vert\x_t])\mu_t(\x_t,\mathbb E[\x_0\vert\x_t])^T \] Since the covariance of \(p_\theta(\x_{t-1}\vert\x_t)\) does not depend on \(\x_t\), we take the expectation of the above over \(\x_t\): \[ \mathbb E_{q(\x_t)}[\text{Cov}_{q(\x_{t-1}\vert\x_t)}(\x_{t-1})]=\lambda_t^2\mathbf I+\mathbb E_{q(\x_t)}\left[\mathbb E_{q(\x_0\vert\x_t)}\left[\mu_t(\x_t,\x_0)\mu_t(\x_t,\x_0)^T\right]-\mu_t(\x_t,\mathbb E[\x_0\vert\x_t])\mu_t(\x_t,\mathbb E[\x_0\vert\x_t])^T\right] \] Moreover, since the covariance of \(p_\theta(\x_{t-1}\vert\x_t)\) is an isotropic diagonal matrix (a multiple of the identity), we only need the average of the diagonal entries of this matrix, i.e., the trace divided by \(d\): \[ \begin{align} \frac{1}{d}\text{tr}\left(\mathbb E_{q(\x_t)}[\text{Cov}_{q(\x_{t-1}\vert\x_t)}(\x_{t-1})]\right)&=\lambda_t^2+\frac{1}{d}\mathbb E_{q(\x_t)}\left[\mathbb E_{q(\x_0\vert\x_t)}\left[\Vert\mu_t(\x_t,\x_0)\Vert_2^2\right]-\Vert\mu_t(\x_t,\mathbb E[\x_0\vert\x_t])\Vert_2^2\right]\\ &=\lambda_t^2+\frac{1}{d}\mathbb E_{q(\x_t)}\left[\mathbb E_{q(\x_0\vert\x_t)}\left[\Vert A\x_0+B\x_t\Vert_2^2\right]-\Vert A\E[\x_0\vert\x_t]+B\x_t\Vert_2^2\right]\\ &=\lambda_t^2+\frac{1}{d}A^2\mathbb E_{q(\x_t)}\left[\mathbb E_{q(\x_0\vert\x_t)}\left[\Vert\x_0\Vert_2^2\right]-\Vert\mathbb E[\x_0\vert\x_t]\Vert_2^2\right]\\ &=\lambda_t^2+\frac{1}{d}A^2\mathbb E_{q(\x_t)}\left[\text{tr}\left(\text{Cov}_{q(\x_0\vert\x_t)}(\x_0)\right)\right]\\ \end{align} \] Furthermore, one can show that \(\E_{q(\x_t)}[\text{Cov}_{q(\x_0\vert\x_t)}(\x_0)]\) relates to the score function as: \[ \mathbb E_{q(\x_t)}[\text{Cov}_{q(\x_0\vert\x_t)}(\x_0)]=\frac{1-\bar\alpha_t}{\bar\alpha_t}\left(\mathbf I-(1-\bar\alpha_t)\mathbb E_{q(\x_t)}\left[\nabla_{\x_t}\log q(\x_t)\nabla_{\x_t}\log q(\x_t)^T\right]\right) \]

Substituting back, the optimal variance is: \[ {\sigma_t^\ast}^2=\lambda_t^2+\left(\sqrt{\frac{\bar\beta_t}{\alpha_t}}-\sqrt{\bar\beta_{t-1}-\lambda_t^2}\right)^2\left(1-\bar\beta_t\ \mathbb E_{q(\x_t)}\frac{\Vert\nabla_{\x_t}\log q_t(\x_t)\Vert^2}{d}\right) \] where \(\bar\beta_t:=1-\bar\alpha_t\). As we can see, the optimal variance adds a "correction term" on top of \(\lambda_t^2\). The correction involves the score function, so the trained model can be plugged in; and since it contains an expectation, it must be estimated by Monte Carlo before sampling begins.
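In the same hypothetical two-point 1-D setting as above, we can verify the key identity behind the optimal variance, \(\E_{q(\x_t)}[\text{Var}(x_{t-1}\vert x_t)]=\lambda_t^2+A^2\,\E_{q(\x_t)}[\text{Var}(x_0\vert x_t)]\), by estimating both sides with Monte Carlo (a sketch; all schedule values are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
# Same hypothetical 1-D setting: x_0 = +1 or -1 with probability 1/2 each.
alpha_bar_t, alpha_bar_prev = 0.5, 0.6
alpha_t = alpha_bar_t / alpha_bar_prev
beta_t = 1.0 - alpha_t
lam2 = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t   # DDPM choice
B = np.sqrt(1.0 - alpha_bar_prev - lam2) / np.sqrt(1.0 - alpha_bar_t)
A = np.sqrt(alpha_bar_prev) - np.sqrt(alpha_bar_t) * B
k = np.sqrt(alpha_bar_t) / (1.0 - alpha_bar_t)

n = 2_000_000
x0 = rng.choice([-1.0, 1.0], size=n)
x_t = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * rng.standard_normal(n)
cond_mean_x0 = np.tanh(k * x_t)          # E[x_0 | x_t]

# Left side: E_{q(x_t)}[Var(x_{t-1} | x_t)] via the law of total variance,
# using Var(x_{t-1}) = 1 and E[x_{t-1} | x_t] = A E[x_0 | x_t] + B x_t.
lhs = 1.0 - np.mean((A * cond_mean_x0 + B * x_t) ** 2)

# Right side: lambda_t^2 + A^2 E[Var(x_0 | x_t)], with Var(x_0 | x_t) = 1 - tanh^2.
rhs = lam2 + A ** 2 * np.mean(1.0 - cond_mean_x0 ** 2)
print(lhs, rhs)                          # agree up to Monte Carlo error
```

The positive gap between `rhs` and `lam2` is exactly the correction term: whenever the data distribution is not a single point, \(\lambda_t^2\) alone underestimates the true reverse variance.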

References

  1. Nichol, Alexander Quinn, and Prafulla Dhariwal. Improved Denoising Diffusion Probabilistic Models. In International Conference on Machine Learning, pp. 8162-8171. PMLR, 2021.
  2. Bao, Fan, Chongxuan Li, Jun Zhu, and Bo Zhang. Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models. In International Conference on Learning Representations, 2022.
  3. Su, Jianlin. (Aug. 12, 2022). 生成扩散模型漫谈(七):最优扩散方差估计(上) [Blog post]. Retrieved from https://kexue.fm/archives/9245

Author: xyfJASON. Published January 14, 2023. https://xyfjason.github.io/blog-main/2023/01/14/扩散模型逆向方差的选取/