Shortcut Models
背景知识
扩散模型最明显的短板就是迭代式生成的效率太低,为此人们做出了非常多的努力,大致可以概括如下:
- 改进采样器:提出 DDIM, Euler, Heun, DPM-Solver 等一阶或高阶 ODE 求解器。这类方法的优点在于无需训练,但是很难在 10 步采样以内保持图像质量。
- 模型蒸馏:预训练扩散模型作为老师,蒸馏一个采样步数更少的学生模型,代表工作包括 Progressive Distillation, Consistency Distillation, Reflow 等。这类方法可以保证几步甚至一步生成的图像质量,但是训练流程麻烦,不是端到端的。
- 构建新的模型族:基于扩散模型的理论设计新的模型族,代表工作包括 Consistency Training,以及本文介绍的 Shortcut Models. 这类方法可以端到端从头训练,同时支持一步或多步生成,是很有潜力的研究方向。
在正式开始介绍 Shorcut Models 之前,有必要简单回顾一下 Flow Matching. 对训练的每一步,Flow Matching 首先从源分布和目标分布中分别采样
方法介绍
核心思想
Shortcut Models 的核心思想是将采样步长
具体而言,记
极限性质:当步长无限小时,平均速度趋向于瞬时速度,此时 shortcut 就是 Flow Matching 中的速度场:
自一致性:跨一个
步长等价于先跨一个 步长、再跨一个 步长:
接下来我们只需要训练一个神经网络
模型训练
基于 shortcut 的两条性质,我们可以设计如下的损失函数:
Shortcut Models 的训练过程可以解释为一个自蒸馏的过程:如果我们分阶段训练不同步长,先训练小步长,再训练大步长,那不就是 Progressive Distillation 吗!从这个角度上说,Shortcut Models 也可以视作 Progressive Distillation 的端到端训练版本。
虽然理论简单,但是实践中作者引入了不少工程设计:
- 时间离散化:尽管理论上
可以取连续值,但作者只在离散时间步上训练自一致性。具体而言,步长 采样自 ,而时间步 取为 的倍数。 - 批次分比例优化:对一个 batch 的数据,取其中
用于训练 Flow Matching,剩下 用于训练自一致性。由于自一致性的目标 包含额外的前向传播,因此训练代价比 Flow Matching 部分更高,设置比例后可以保证整个训练代价比普通的扩散模型训练仅多出 16% 左右。 - 提前确定 CFG:由于自一致性涉及到生成目标
,因此原本不需要在训练时确定的超参数(如 CFG)需要提前确定下来。 - 使用 EMA 权重:同样的道理,生成目标
时需要用维护的 EMA 模型生成。 - 调节 weight decay:作者发现 weight decay 对训练的稳定性至关重要,这是因为训练早期
无法给出有效的目标,影响收敛,调节合适的 weight decay 可以缓解不稳定性。
训练及采样算法如下图所示: