One/Few-step Flow Models

扩散/流模型最明显的短板就是迭代式生成的效率太低,为此人们做出了非常多的努力,大致可以分为以下三类:

  • 改进采样器:提出 DDIM, Euler, Heun, DPM-Solver 等一阶或高阶 ODE/SDE 求解器。这类方法的优点在于无需训练,但是很难在 10 步采样以内保持图像质量。
  • 模型蒸馏:以预训练扩散模型作为老师,蒸馏一个采样步数更少的学生模型,代表工作包括 Progressive Distillation, Consistency Distillation, Reflow 等。这类方法可以保证几步甚至一步生成的图像质量,但是训练流程麻烦,不是端到端的。
  • 构建新的模型族:基于扩散/流模型的理论设计新的模型族,代表工作包括 Consistency Models, Shortcut Models, MeanFlow 等。这类方法可以端到端从头训练,同时支持一步或多步生成,是很有潜力的研究方向。本文主要介绍此类方法。

Flow Matching 回顾

首先回顾一下 Flow Matching. 在训练的每一步中,Flow Matching 从源分布和目标分布中分别采样 \(\mathbf x_0,\mathbf x_1\),然后做线性插值 \(\mathbf x_t=(1-t)\mathbf x_0+t\mathbf x_1\),模型 \(\mathbf v_\theta\)\(\mathbf x_t\) 为输入,预测速度方向 \(\mathbf v_t=\mathbf x_1-\mathbf x_0\),即: \[ \mathcal L^\text{FM}(\theta)=\mathbb E_{(\mathbf x_0,\mathbf x_1),t}\left[\Vert\mathbf v_\theta(\mathbf x_t,t)-(\mathbf x_1-\mathbf x_0)\Vert^2\right] \] 考虑某个特定的 \(\mathbf x_t\),它在整个训练过程中可能由不同的 \((\mathbf x_0,\mathbf x_1)\) 插值得到,每次回归一个速度方向,于是最终将收敛到这些速度方向的条件期望: \[ \mathbf v_\theta^\ast(\mathbf x_t,t)=\mathbb E[\mathbf v_t\vert\mathbf x_t] \] 训练结束后,模型 \(\mathbf v^\ast_\theta(\mathbf x_t,t)\) 给出了空间中的一个速度场,从源分布中的任一点出发,沿该速度场运动就完成了一次采样。形式化地说,采样过程就是求解如下 ODE 的过程: \[ \mathrm d\mathbf x_t=\mathbf v_\theta^\ast(\mathbf x_t,t)\mathrm dt \] 值得注意的是,模型最终学习到的采样轨迹是弯曲的,这导致大步长采样存在较大的离散误差,如下图所示。极端情况下,如果我们做一步采样,那么采样出来的其实是数据集的平均点: \[ \text{One-step FM Sampling:}\quad\mathbf x_0+\mathbb E[\mathbf v_0\vert\mathbf x_0]\cdot 1=\mathbb E[\mathbf x_1\vert\mathbf x_0]=\mathbb E[\mathbf x_1] \]

Shortcut Models

Shortcut Models 的核心思想是将采样步长 \(d\) 作为条件加入网络,使得网络能够考虑到采样轨迹的曲率,进而给出跨出 \(d\) 步长后准确的位置。直观上,Flow Matching 可以解释为学习瞬时速度,而 Shortcut Models 意在学习 \(d\) 时间内的平均速度。若模型得到了充分的学习,那么当 \(d=1\) 时,模型就给出了跨越整条采样轨迹的速度方向,从而实现一步采样。

Shortcut

具体而言,记 \(\mathbf s(\mathbf x_t,t,d)\) 表示 \(t\) 时刻位于 \(\mathbf x_t\) 处且步长为 \(d\) 的 shortcut,定义为: \[ \mathbf x_{t+d}=\mathbf x_t+\mathbf s(\mathbf x_t,t,d)d \] 不难发现它具有两条性质:

  • 极限性质:当步长无限小时,平均速度趋向于瞬时速度,此时 shortcut 就是 Flow Matching 中的速度场: \[ \lim_{d\to 0}\mathbf s(\mathbf x_t,t,d)=\lim_{d\to0}\frac{\mathbf x_{t+d}-\mathbf x_t}{d}=\dot{\mathbf x}_t \]

  • 自一致性:跨一个 \(2d\) 步长等价于先跨一个 \(d\) 步长、再跨一个 \(d\) 步长: \[ \mathbf s(\mathbf x_t,t,2d)=\mathbf s(\mathbf x_t,t,d)/2+\mathbf s(\mathbf x_{t+d},t+d,d)/2 \]

我们希望训练一个模型 \(\mathbf s_\theta(\mathbf x_t,t,d)\) 去近似 shortcut: \[ \mathbf s_\theta(\mathbf x_t,t,d)\xrightarrow{\text{approximate}}\mathbf s(\mathbf x_t,t,d) \] 如此即可实现一步/少步采样。

模型训练

基于 shortcut 的两条性质,设计如下的损失函数: \[ \begin{align} &\mathcal L^\text{SM}(\theta)=\mathbb E_{(\mathbf x_0,\mathbf x_1),(t,d)}\Big[ \underbrace{\Vert\mathbf s_\theta(\mathbf x_t,t,0)-(\mathbf x_1-\mathbf x_0)\Vert^2}_{\text{Flow Matching}} + \underbrace{\Vert\mathbf s_\theta(\mathbf x_t,t,2d)-\mathbf s_\text{target}\Vert^2}_{\text{Self-Consistency}} \Big]\\ \text{where}\quad&\mathbf s_\text{target}=\mathbf s_\theta(\mathbf x_t,t,d)/2+s_\theta(~\underbrace{\mathbf x_t+\mathbf s_\theta(\mathbf x_t,t,d)d}_{\mathbf x'_{t+d}}~,t+d,d)/2 \end{align} \] 可以看见损失函数由两部分构成,分别对应极限性质和自一致性。直观上,Flow Matching 部分保证在小步长下模型给出的轨迹贴合真实的采样轨迹,而自一致性部分保证模型在大步长下的轨迹逼近小步长下的轨迹。二者联合训练,实现单阶段、端到端的从头训练。

Shortcut Models 的训练过程可以解释为一个自蒸馏的过程:如果我们分阶段训练不同步长,即先训练小步长,再训练大步长,那不就是 Progressive Distillation 吗!从这个角度上说,Shortcut Models 也可以视作 Progressive Distillation 的端到端训练版本。

虽然理论简单,但是实践中作者引入了不少工程设计:

  1. 时间离散化:尽管理论上 \(d\) 可以取连续值,但作者只在离散时间步上训练自一致性。具体而言,步长 \(d\) 采样自 \(\left\{1/128,1/64,\ldots,1/2,1\right\}\),而时间步 \(t\) 取为 \(d\) 的倍数。
  2. 批次分比例优化:对一个 batch 的数据,取其中 \(3/4\) 用于训练 Flow Matching,剩下 \(1/4\) 用于训练自一致性。由于自一致性的目标 \(s_\text{target}\) 包含额外的前向传播,因此训练代价比 Flow Matching 部分更高,设置比例后可以保证整个训练代价比普通的扩散模型训练仅多出 16% 左右。
  3. 提前确定 CFG:由于自一致性涉及到生成目标 \(\mathbf s_\text{target}\),因此原本不需要在训练时确定的超参数(如 CFG)需要提前确定下来。
  4. 使用 EMA 权重:同样的道理,生成目标 \(\mathbf s_\text{target}\) 时需要用维护的 EMA 模型生成。
  5. 调节 weight decay:作者发现 weight decay 对训练的稳定性至关重要,这是因为训练早期 \(\mathbf s_\text{target}\) 无法给出有效的目标,影响收敛,调节合适的 weight decay 可以缓解不稳定性。

训练及采样算法如下图所示:

MeanFlow

Shortcut Models 人为确定离散时间步的做法略显笨拙,而 Kaiming 组的 MeanFlow 则给出了一个更优雅的解决方案。直观上,可以认为 MeanFlow 是 Shortcut Models 的连续时间版本,它将后者在离散时间步上的一致性约束推广到了连续时间上的微分方程约束。

平均速度

MeanFlow 的思想依旧是用一个双时间步模型学习 ODE 轨迹的平均速度。形式化地说,设 \(t\in[0,1]\) 表示时间步,\(\mathbf x_t\)\(t\) 的函数,表示一条轨迹。记 \(\mathbf v(\mathbf x_t,t)\) 表示 \(t\) 时刻 \(\mathbf x_t\) 处的瞬时速度,则从 \(r\) 时刻到 \(t\) 时刻的平均速度定义为: \[ \mathbf u(\mathbf x_t,r,t){\;\mathrel{\vcenter{:}}=\;}\frac{1}{t-r}\int_r^t\mathbf v(\mathbf x_\tau,\tau)\mathrm d\tau \] 根据定义,平均速度场自然满足 Shortcut Models 中提到的两条性质:

  1. 极限性质:当 \(r\to t\) 时,\(\mathbf u(\mathbf x_t,r,t)\to\mathbf v(\mathbf x_t,t)\).
  2. 自一致性:设 \(r<s<t\),则 \((t-r)\mathbf u(\mathbf x_t,r,t)=(s-r)\mathbf u(\mathbf x_s,r,s)+(t-s)\mathbf u(\mathbf x_t,s,t)\).

我们的目标就是学习一个模型 \(\mathbf u_\theta(\mathbf x_t,r,t)\) 去近似平均速度: \[ \mathbf u_\theta(\mathbf x_t,r,t)\xrightarrow{\text{approximate}}\mathbf u(\mathbf x_t,r,t) \] 如果可以做到这一点,那么推理时就可以跨任意步长采样了,包括一步采样。

MeanFlow Identity

由于平均速度的定义包含积分,我们无法直接回归它。为此,作者巧妙地将积分等式转换为了微分等式。首先移项得: \[ (t-r)\mathbf u(\mathbf x_t,r,t)=\int_r^t\mathbf v(\mathbf x_\tau,\tau)\mathrm d\tau \]\(r\)\(t\) 无关,两边对 \(t\) 求导得: \[ \mathbf u(\mathbf x_t,r,t)+(t-r)\frac{\mathrm d}{\mathrm dt}\mathbf u(\mathbf x_t,r,t)=\mathbf v(\mathbf x_t,t) \] 整理有: \[ \underbrace{\mathbf u(\mathbf x_t,r,t)}_\text{average vel.}=\underbrace{\mathbf v(\mathbf x_t,t)}_\text{instant. vel.}-(t-r)\underbrace{\frac{\mathrm d}{\mathrm dt}\mathbf u(\mathbf x_t,r,t)}_\text{time derivative} \] 称此式为 MeanFlow Identity. 其中,对时间求导一项可进一步展开如下: \[ \frac{\mathrm d}{\mathrm dt}\mathbf u(\mathbf x_t,r,t)=\partial_\mathbf x\mathbf u\cdot\frac{\mathrm d\mathbf x_t}{\mathrm dt}+\partial_r\mathbf u\cdot\frac{\mathrm dr}{\mathrm dt}+\partial_t\mathbf u\cdot\frac{\mathrm dt}{\mathrm dt}=\partial_\mathbf x\mathbf u\cdot\mathbf v(\mathbf x_t,t)+\partial_t\mathbf u \] 若写作矩阵的形式,则为: \[ \frac{\mathrm d}{\mathrm dt}\mathbf u(\mathbf x_t,r,t)=\underbrace{\bigg[\begin{matrix}\partial_\mathbf x\mathbf u&\partial_r\mathbf u&\partial_t\mathbf u\end{matrix}\bigg]}_{\text{Jacobian of }\mathbf u}\begin{bmatrix}\mathbf v(\mathbf x_t,t)\\0\\1\end{bmatrix} \]\(\mathbf u\) 的 Jacobian 矩阵与一个向量的乘积,这可通过深度学习框架 (PyTorch/JAX) 中封装好的 jvp (Jacobian-vector product) 算子高效实现。

现在,我们只需训练一个模型 \(\mathbf u_\theta(\mathbf x_t,r,t)\) 使其满足 MeanFlow Identity 即可。构造如下损失函数: \[ \begin{align} &\mathcal L(\theta)=\mathbb E_{\mathbf x_t,r,t}\left[\Vert\mathbf u_\theta(\mathbf x_t,r,t)-\text{sg}(\mathbf u_\text{tgt})\Vert_2^2\right]\\ \text{where}\quad&\mathbf u_\text{tgt}=\mathbf v(\mathbf x_t,t)-(t-r)(\partial_\mathbf x\mathbf u_\theta\cdot\mathbf v(\mathbf x_t,t)+\partial_t\mathbf u_\theta) \end{align} \] 其中 \(\text{sg}(\cdot)\) 是停止梯度操作,防止计算 Jacobian 矩阵后还需对网络二次求导,提升训练效率。综上,MeanFlow 模型的训练算法流程如下:

实践中,我们不一定采用 L2 度量,可以根据实验效果更改。作者采取了 \(\Vert\Delta\Vert_2^{2\gamma}\) 的形式,并设置权重 \(w=1/(\Vert\Delta\Vert_2^2+c)^{1-\gamma}\).

条件生成与 Guidance

传统上,CFG 是用在采样过程中的,代价是让推理时延变成了原来的两倍。MeanFlow 作者指出,我们其实可以在训练时就引入 CFG,让模型学习 CFG 后的速度场,即可在保持推理时延的同时享受 CFG 的好处。

具体而言,施加 CFG 后的速度场为: \[ \mathbf v^{\text{cfg}}(\mathbf x_t,t\vert\mathbf c){\;\mathrel{\vcenter{:}}=\;}\omega\mathbf v(\mathbf x_t,t\vert\mathbf c)+(1-\omega)\mathbf v(\mathbf x_t,t) \] 依照上一节的推导,MeanFlow Identity 相应更改为: \[ \mathbf u^\text{cfg}(\mathbf x_t,r,t\vert\mathbf c)=\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)-(t-r)\frac{\mathrm d}{\mathrm dt}\mathbf u^\text{cfg}(\mathbf x_t,r,t\vert\mathbf c) \] 其中: \[ \frac{\mathrm d}{\mathrm dt}\mathbf u^\text{cfg}(\mathbf x_t,r,t\vert\mathbf c)=\partial_\mathbf x\mathbf u^\text{cfg}\cdot\frac{\mathrm d\mathbf x_t}{\mathrm dt}+\partial_r\mathbf u^\text{cfg}\cdot\frac{\mathrm dr}{\mathrm dt}+\partial_t\mathbf u^\text{cfg}\cdot\frac{\mathrm dt}{\mathrm dt}=\partial_\mathbf x\mathbf u^\text{cfg}\cdot\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)+\partial_t\mathbf u^\text{cfg} \] 于是损失函数更改为: \[ \begin{align} &\mathcal L(\theta)=\mathbb E_{(\mathbf x_t,\mathbf c),r,t}\left[\left\Vert\mathbf u^\text{cfg}_\theta(\mathbf x_t,r,t\vert\mathbf c)-\text{sg}(\mathbf u_\text{tgt})\right\Vert_2^2\right]\\ \text{where}\quad&\mathbf u_\text{tgt}=\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)-(t-r)\left(\partial_\mathbf x\mathbf u^\text{cfg}_\theta\cdot\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)+\partial_t\mathbf u^\text{cfg}_\theta\right) \end{align} \] 看似与无条件情形没有什么区别,但这里其实有一个问题:我们在训练过程中只能得到 \(\mathbf v(\mathbf x_t,t\vert\mathbf c)\),无法得到 \(\mathbf v(\mathbf x_t,t)\),因此无法计算 \(\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)\). 不过作者注意到: \[ \begin{align} \mathbf v^\text{cfg}(\mathbf x_t,t)&=\mathbb E_\mathbf c[\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)]\\ &=\mathbb E_\mathbf c[\omega\mathbf v(\mathbf x_t,t\vert\mathbf c)+(1-\omega)\mathbf v(\mathbf x_t,t)]\\ &=\omega\mathbb E_\mathbf c[\mathbf v(\mathbf x_t,t\vert\mathbf c)]+(1-\omega)\mathbf v(\mathbf x_t,t)\\ &=\mathbf v(\mathbf x_t,t) \end{align} \] 也就是说,在施加 CFG 前后,无条件速度场并不会改变。又 \(\mathbf v^\text{cfg}(\mathbf x_t,t)\) 也可以写作 \(\mathbf u^\text{cfg}(\mathbf x_t,t,t)\),因此我们直接用无条件模型去近似之: \[ \mathbf u_\theta^\text{cfg}(\mathbf x_t,t,t)\xrightarrow{\text{approximate}}\mathbf u^\text{cfg}(\mathbf x_t,t,t)=\mathbf v(\mathbf x_t,t) \] 综合起来,训练损失为: \[ \begin{align} &\mathcal L(\theta)=\mathbb E_{(\mathbf x_t,\mathbf c),r,t}\left[\left\Vert\mathbf u^\text{cfg}_\theta(\mathbf x_t,r,t\vert\mathbf c)-\text{sg}(\mathbf u_\text{tgt})\right\Vert_2^2\right]\\ \text{where}\quad&\mathbf u_\text{tgt}=\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)-(t-r)\left(\partial_\mathbf x\mathbf u^\text{cfg}_\theta\cdot\mathbf v^\text{cfg}(\mathbf x_t,t\vert\mathbf c)+\partial_t\mathbf u^\text{cfg}_\theta\right)\\ &\mathbf v^{\text{cfg}}(\mathbf x_t,t\vert\mathbf c)=\omega\mathbf v(\mathbf x_t,t\vert\mathbf c)+(1-\omega)\mathbf u_\theta^\text{cfg}(\mathbf x_t,t,t) \end{align} \] 同时在训练过程中以 10% 的概率丢弃条件以训练无条件模型。


One/Few-step Flow Models
https://xyfjason.github.io/blog-main/2025/11/13/One-Few-step-Flow-Models/
作者
xyfJASON
发布于
2025年11月13日
许可协议