扩散模型逆向方差的选取

DDPM: 人为选取方差

首先做一个简要回顾。DDPM 定义前向马尔可夫过程为:

q(x1:T|x0)=q(x0)t=1Tq(xt|xt1)q(xt|xt1)=N(xt;1βtx0,βtI)q(xt|x0)=N(xt;α¯tx0,(1α¯t)I) 计算可得: q(xt1|xt,x0)=N(xt1;μt(xt,x0),β~tI) 仿照 q(xt1|xt,x0) 的形式,将 pθ(xt1|xt)​ 定义为: pθ(xt1|xt)=N(xt1;μt(xt,xθ(xt,t)),σt2I) 模型通过最大化 ELBO 训练: (1)ELBO=Ex1q(x1|x0)[logpθ(x0|x1)]KL(q(xT|x0)p(xT))t=2TExtq(xt|x0)[KL(q(xt1|xt,x0)pθ(xt1|xt))] 注意 KL(q(xt1|xt,x0)pθ(xt1|xt)) 里面的两个分布都是正态分布,前者方差为 β~t,后者方差定义为 σt2;考虑到两个方差相同的正态分布的 KL 散度正比于二者均值的 L2 距离,因此如果直接取 σt2=β~t,那么损失函数就是一个非常简单的 L2 Loss. 但同时,DDPM 也称取 σt2=βt 也是可以的,这就是我们之前遗留的问题。

为了理解这个问题,需要注意到 (1) 式仅仅是针对一个样本 x0 的推导,实际上真正的优化目标还要套一层对 x0 的期望: (2)Ex0q(x0)[ELBO]=Ex0q(x0)[Eq(x1:T|x0)[logpθ(x0:T)q(x1:T|x0)]]=Ex0:Tq(x0:T)[logpθ(x0:T)q(x0:T)+logq(x0)]=KL(q(x0:T)pθ(x0:T))H(q(x0))=KL(q(xT)t=1Tq(xt1|xt)pθ(xT)t=1Tpθ(xt1|xt))H(q(x0))=KL(q(xT)pθ(xT))t=1TKL(q(xt1|xt)pθ(xt1|xt))H(q(x0)) 其中 q(x0) 表示数据分布。其中倒数第二步能对 q 展开是因为马尔可夫过程的逆向过程依旧是马尔可夫的: q(xt1|xt:T)=q(xt1:T)q(xt:T)=q(xt1)i=tTq(xi|xi1)q(xt)i=t+1Tq(xi|xi1)=q(xt1)q(xt|xt1)q(xt)=q(xt1|xt) (2) 式可以看见,pθ(xt1|xt) 要近似的目标并不是 q(xt1|xt,x0),而是 q(xt1|xt),它们二者之间的关系为: q(xt1|xt)=q(xt1|xt,x0)q(x0|xt)dx0=Eq(x0|xt)[q(xt1|xt,x0)] 因此,仅当 q(x0|xt) 是 Dirac delta 函数时才有 q(xt1|xt)=q(xt1|xt,x0),此时取 σt2=β~t 才是最优的;否则,直接取 σt2=β~t 不是最优的。然而,q(x0|xt) 本身是 intractable 的,所以为了可以计算,DDPM 的作者考虑了两种极端情形:

  • 情形一:x0N(0,I),可知 q(xt)=N(xt;0,I),于是 q(xt1|xt)=q(xt|xt1)=N(xt;1βtxt1,βtI),所以应该取 σt2=βt.
  • 情形二:x0δ(x0),即数据集只有一个样本 x0=0,此时 q(x0|xt) 是 Dirac delta 函数,所以应该取 σt2=β~t.

当然,这些情形的设定本身就不合理——没有人的数据是一堆高斯噪声或者一个样本,只是为了寻找方差可能的选择而做的一些试验性假设罢了。这也为后续的工作埋下了一个伏笔——最优的方差到底应该是什么形式的呢?

iDDPM: 可学习方差

为了解决上述问题,Improved DDPM[1] 将方差作为可学习参数进行优化而非人为选取。具体而言,考虑到 β~tβt 是两个极端情况,作者把方差参数化为二者在 log domain 的插值: σθ2(xt,t)=exp(vlogβt+(1v)logβ~t) 损失函数为: Lvlb=Ex0q(x0)[ELBO]=t=1TEx0q(x0),xtq(xt|x0)[KL(q(xt1|xt,x0)pθ(xt1|xt))] 回忆在 DDPM 中,作者将上式简化为: LsimpleEt,x0,xt[μt(xt,x0)μt(xt,xθ(xt,t))2] 虽然 Improved DDPM 引入可学习方差后,损失函数已经不是 Lsimple 了,但它依旧沿用了 Lsimple 来训练均值。同时,用 Lvlb 来训练方差,并且 Lvlb 对均值做梯度截断,只用来训练方差。综上,Improved DDPM 采用下面这个混合损失函数: Lhybrid=Lsimple+λLvlb 实验也证明,如果只用 Lvlb 训练,图像生成效果欠佳。这告诉我们理论与实践还是有差距的,最好的设置还是要靠实验探索。

附:两个正态分布的 KL 散度

是两个 维随机向量,,则:

Analytic-DPM: 解析最优方差

Improved DDPM 引入可学习方差,缓解了 DDPM 方差设置不合理的问题,但也使得训练更加困难。然而同期工作 Analytic-DPM[2]却发现, 其实有解析形式的最优解,所以我们压根不需要训练它!

在开始之前,回忆 DDIM 论文将 DDPM 的前向过程扩展为非马尔可夫过程:

特别地,取 就得到了原始 DDPM. 因此下面我们都基于这个更一般的扩散过程叙述。

前文提到,问题实际的优化目标是: 由于: 后三项都与 无关,因此: 根据 moment matching 理论——在 KL 散度下用高斯分布去拟合任何分布,等价于让二者的均值和方差相等——我们只需要求解 ​ 的均值和方差即可。

首先计算均值: 其中 分别是 前面的系数,即: 注意到 DDPM 中 正是用于近似 ​ 的,所以 DDPM 采用均值的确就是最优的均值估计

接下来计算二阶矩: 所以协方差矩阵为: 考虑到 的协方差矩阵并不依赖于 ,所以在上式的基础上对 取期望: 又考虑到 的协方差矩阵是一个各向同性对角阵(单位阵的倍数),我们只需要计算上述矩阵的对角线元素的平均,即求迹再除以 进一步地,可以推导出 与 score function 有如下关系:

代入可得最优方差为: 可以看见,最优方差是在 ​ 的基础上,增加了一个“修正项”。修正项与 score function 有关,可以代入训练好的模型。不过,由于修正项包含期望,所以需要在采样开始之前做蒙特卡洛估计

References

  1. Nichol, Alexander Quinn, and Prafulla Dhariwal. Improved denoising diffusion probabilistic models. In International Conference on Machine Learning, pp. 8162-8171. PMLR, 2021. ↩︎
  2. Bao, Fan, Chongxuan Li, Jun Zhu, and Bo Zhang. Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models. In International Conference on Learning Representations. 2021. ↩︎
  3. 苏剑林. (Aug. 12, 2022). 《生成扩散模型漫谈(七):最优扩散方差估计(上) 》[Blog post]. Retrieved from https://kexue.fm/archives/9245 ↩︎

扩散模型逆向方差的选取
https://xyfjason.github.io/blog-main/2023/01/14/扩散模型逆向方差的选取/
作者
xyfJASON
发布于
2023年1月14日
许可协议