DPM-Solver

DPM-Solver

从扩散 ODE 入手

相比扩散 SDE,由于扩散 ODE 没有随机性,更适合采用大步长以加速采样,因此本文作者只考虑扩散 ODE: 代入“噪声预测模型” 那么从 解该 ODE 即是生成过程。具体而言,设已知 的值,那么从 积分得: 在之前的工作中,人们使用各种黑盒 ODE 数值求解器求解上述微分方程,相当于在用不同方式近似上面的积分(例如 Euler 法就是用矩形近似积分,improved Euler 就是用梯形近似积分),但在步数较少时它们效果并不好。因此,本文作者着重考虑了扩散 ODE 的特殊结构,提出了 DPM-Solver.

观察 1:半线性结构

作者首先注意到, 式可以拆解为两部分—— 是关于 的线性部分, 是关于 的非线性部分,其中非线性来自于神经网络。作者将这样的特殊结构称作扩散 ODE 的半线性 (semi-linear) 结构。

半线性结构带来的好处是——我们可以在最终的解中分离出一个可以解析计算的部分,这部分不必用数值法求解。具体而言,由积分因子法可得: 式中第一部分就是可以直接解析计算的部分。

推导过程:将 式中线性部分移项到左边: 两边乘以积分因子 选择特殊的积分因子使得左边可积: 两边从 积分得: 解得: 现在我们只需要确定下积分因子即可。积分因子的构造需要满足: 也即: 这是一个一阶线性齐次常微分方程,分离变量得: 两边从 积分得: 故: 注意 式中出现 的地方都是两个相除的形式: 代入 式得: 这就推出了 式。

观察 2:变量代换

对于扩散模型来说,系数 的含义不如 noise schedule 的含义明确,因此我们考虑将 替换为 . 回忆 来自扰动核的定义: 进一步地,记 表示半对数信噪比: 那么可以推出 之间有如下关系:

推导过程:对一个马尔可夫过程,考虑时刻 到时刻 的转移概率。由扰动核可解得: 写作采样形式: 写出差分格式: ,得: 注意 是单调递减的,因此根号下是正数。对比扩散 SDE: 可知: 这就推出了 的关系。

将该关系代入 式得: 果然看着简洁了一些。更进一步,由于信噪比是单调的,所以时间步 可以与 之间建立起一一映射,因此把积分变量从时间步换成对数信噪比得: 这里用 表示以半对数信噪比为参数的、对应 的模型。现在,我们只需要通过数值方法计算 一项即可,作者将其称作 的指数加权积分。

观察 3:解析计算系数

虽然我们可以直接用数值方法去近似指数加权积分一项,但是本着能求解析解就尽可能求解析解的原则,作者对该积分项做了进一步处理。考虑将 处做 阶泰勒展开: 代入 式得: 其中 . 作者指出,上式中的系数 是可以解析计算的。具体而言,实施一次分部积分: 发现系数存在递推关系,且最后一项 为: 据此可以计算出 实际应用中,我们只考虑 ,因此不必计算更多的 . 取 相应得到的求解器称作 DPM-Solver-1, DPM-Solver-2, DPM-Solver-3,它们分别是 1,2,3 阶的 ODE 求解器。

例如,当 时,代入 即得到 DPM-Solver-1 的更新式:

可以发现这个式子与 DDIM 的更新式是一模一样的,因此 DDIM 就是一阶的 DPM-Solver.

观察 4:数值估计导数项

在计算更高阶的 DPM-Solver 时,我们需要计算 式中关于神经网络的高阶导数项 . 这一项可以用数值方法估计。例如,对于一阶导,取 ,那么: 一般可以取 . 据此我们推导 DPM-Solver-2 的更新式。在 式中取 ​​,得: 其中第三个等号是因为: 用类似的方式可以推出 DPM-Solver-3,当然推导过程会更麻烦。最终 DPM-Solver-1, DPM-Solver-2, DPM-Solver-3 的算法流程如下:

总结一下,DPM-Solver 的思想就是尽可能求解析解,从而减小离散化误差:


DPM-Solver
https://xyfjason.github.io/blog-main/2024/05/14/DPM-Solver/
作者
xyfJASON
发布于
2024年5月14日
许可协议