Shortcut Models

背景知识

扩散模型最明显的短板就是迭代式生成的效率太低,为此人们做出了非常多的努力,大致可以概括如下:

  • 改进采样器:提出 DDIM, Euler, Heun, DPM-Solver 等一阶或高阶 ODE 求解器。这类方法的优点在于无需训练,但是很难在 10 步采样以内保持图像质量。
  • 模型蒸馏:预训练扩散模型作为老师,蒸馏一个采样步数更少的学生模型,代表工作包括 Progressive Distillation, Consistency Distillation, Reflow 等。这类方法可以保证几步甚至一步生成的图像质量,但是训练流程麻烦,不是端到端的。
  • 构建新的模型族:基于扩散模型的理论设计新的模型族,代表工作包括 Consistency Training,以及本文介绍的 Shortcut Models. 这类方法可以端到端从头训练,同时支持一步或多步生成,是很有潜力的研究方向。

在正式开始介绍 Shorcut Models 之前,有必要简单回顾一下 Flow Matching. 对训练的每一步,Flow Matching 首先从源分布和目标分布中分别采样 ,然后做线性插值 ,模型 为输入,预测速度方向 ,即: 考虑某个特定的 ,它在整个训练过程中可能被不同的 对采样出来,每次回归一个速度方向,于是最终收敛到“平均速度方向”: 训练结束后,模型 给出了空间中的一个速度场,从源分布中的任一点出发,沿该速度场运动就完成了一次采样。值得注意的是,由于模型最终学习到的是平均速度,因此采样轨迹是弯曲的,这导致大步长采样存在较大的离散误差,如下图所示。极端情况下,如果我们做一步采样,那么采样出来的其实是数据集的平均点,即 .

方法介绍

核心思想

Shortcut Models 的核心思想是将采样步长 作为条件加入网络,使得网络能够考虑到采样轨迹的曲率,进而给出跨出 步长后准确的位置。直观上,Flow Matching 可以解释为学习瞬时速度,而 Shortcut Models 意在学习 时间内的平均速度。若模型得到了充分的学习,那么当 时,模型就给出了跨越整条采样轨迹的速度方向,从而实现一步采样。

具体而言,记 表示 时刻位于 处且步长为 的 shortcut,定义为: 不难发现它具有两条性质:

  • 极限性质:当步长无限小时,平均速度趋向于瞬时速度,此时 shortcut 就是 Flow Matching 中的速度场:

  • 自一致性:跨一个 步长等价于先跨一个 步长、再跨一个 步长:

接下来我们只需要训练一个神经网络 去近似 即可。

模型训练

基于 shortcut 的两条性质,我们可以设计如下的损失函数: 可以看见损失函数由两部分构成,分别对应极限性质和自一致性。直观上,Flow Matching 部分保证在小步长下模型给出的轨迹贴合真实的采样轨迹,而自一致性部分保证模型在大步长下的轨迹逼近小步长下的轨迹。二者联合训练,实现单阶段、端到端的从头训练。

Shortcut Models 的训练过程可以解释为一个自蒸馏的过程:如果我们分阶段训练不同步长,先训练小步长,再训练大步长,那不就是 Progressive Distillation 吗!从这个角度上说,Shortcut Models 也可以视作 Progressive Distillation 的端到端训练版本。

虽然理论简单,但是实践中作者引入了不少工程设计:

  1. 时间离散化:尽管理论上 可以取连续值,但作者只在离散时间步上训练自一致性。具体而言,步长 采样自 ,而时间步 取为 的倍数。
  2. 批次分比例优化:对一个 batch 的数据,取其中 用于训练 Flow Matching,剩下 用于训练自一致性。由于自一致性的目标 包含额外的前向传播,因此训练代价比 Flow Matching 部分更高,设置比例后可以保证整个训练代价比普通的扩散模型训练仅多出 16% 左右。
  3. 提前确定 CFG:由于自一致性涉及到生成目标 ,因此原本不需要在训练时确定的超参数(如 CFG)需要提前确定下来。
  4. 使用 EMA 权重:同样的道理,生成目标 时需要用维护的 EMA 模型生成。
  5. 调节 weight decay:作者发现 weight decay 对训练的稳定性至关重要,这是因为训练早期 无法给出有效的目标,影响收敛,调节合适的 weight decay 可以缓解不稳定性。

训练及采样算法如下图所示:


Shortcut Models
https://xyfjason.github.io/blog-main/2025/02/28/Shortcut-Models/
作者
xyfJASON
发布于
2025年2月28日
许可协议