Consistency Models
Introduction
扩散模型虽然效果很好,但最大的缺点就是迭代式的生成效率太低。为了解决这个问题,宋飏等人在 Probability Flow ODE 的基础上建立了 consistency models,通过将 ODE 轨迹上的任何点(如 \(\mathbf x_t,\mathbf x_{t'},\mathbf x_T\))都映射到 ODE 的端点 \(\mathbf x_0\),从而实现从 \(\mathbf x_T\) 到 \(\mathbf x_0\) 的一步生成,如下图所示:
值得说明的是,扩散模型的一种常见参数化方式就是“预测原图”,这与 consistency models 有什么区别呢?扩散模型中的“原图预测模型”预测的真值是 \(\mathbb E[\mathbf x_0\vert\mathbf x_t]\),即在 \(q(\mathbf x_0\vert\mathbf x_t)\) 意义下 \(\mathbf x_0\) 的平均值,并不是 ODE 的端点 \(\mathbf x_0\),因此并不支持一步生成;相反,consistency models 预测的就是 \(\mathbf x_0\),因此支持一步生成。
Probability Flow ODE
Probability Flow ODE 的一般形式为: \[ \mathrm d\mathbf x_t=\left[\boldsymbol\mu(\mathbf x_t,t)-\frac{1}{2}\sigma^2(t)\nabla\log p_t(\mathbf x_t)\right]\mathrm dt \] 为简单起见,文章采用 EDM 的设置 \(\boldsymbol\mu(\mathbf x_t,t)\equiv \mathbf 0,\,\sigma(t)=\sqrt{2t}\),那么有: \[ \mathrm d\mathbf x_t=-t\nabla\log p_t(\mathbf x_t)\mathrm dt\tag{1}\label{ode} \] 设扩散模型训练得到的 score model 为 \(s_\phi(\mathbf x_t,t)\approx\nabla\log p_t(\mathbf x_t)\),代入上式得到 empirical PF ODE: \[ \mathrm d\mathbf x_t=-ts_\phi(\mathbf x_t,t)\mathrm dt\tag{2}\label{ode-e} \] 则从 \(T\) 到 \(\epsilon>0\) 迭代求解 \(\eqref{ode-e}\) 式的过程就是生成过程。其中 \(\epsilon\) 不取 0 是为了避免 score function 在 \(t=0\) 处的数值不稳定问题。
Consistency Models
设 Probability Flow ODE \(\eqref{ode}\) 式的轨迹为 \(\{\mathbf x_t\}_{t\in[\epsilon,T]}\),定义 consistency function 为 \(\mathbf f:\mathcal D\times[\epsilon,T]\to\mathcal D\),满足: \[ \mathbf f(\mathbf x_t,t)=\mathbf x_\epsilon,\quad\forall t\in[\epsilon,T] \] Consistency function 具有 self-consistency 的性质,即对同一个 ODE 轨迹上的任意两点 \(\mathbf x_t,\mathbf x_{t'}\),有: \[ \mathbf f(\mathbf x_t,t)=\mathbf f(\mathbf x_{t'},t'),\quad\forall t,t'\in[\epsilon,T] \] 另外还满足 boundary condition,即: \[ \mathbf f(\mathbf x_\epsilon,\epsilon)=\mathbf x_\epsilon \] 为此作者将 consistency models 参数化为了如下形式: \[ \mathbf f_\theta(\mathbf x,t)=c_\text{skip}(t)\mathbf x+c_\text{out}(t)F_\theta(\mathbf x_t,t)\approx \mathbf f(\mathbf x,t) \] 其中可微系数 \(c_\text{skip}(t),c_\text{out}(t)\) 满足 \(c_\text{skip}(\epsilon)=1,\,c_\text{out}(\epsilon)=0\),这样就能满足 boundary condition.
训练好 consistency model \(\mathbf f_\theta(\cdot,\cdot)\) 后,只需采样 \(\hat{\mathbf x}_T\sim\mathcal N(\mathbf 0,T^2\mathbf I)\),然后计算 \(\hat{\mathbf x}_\epsilon=\mathbf f_\theta(\hat{\mathbf x}_T,T)\) 即可。同时,我们也可以反复注入噪声来实现多步生成,算法如下所示:
由于 consistency model 定义了从高斯噪声到数据的一一映射,因此可以像 DDIM 一样完成隐空间插值和 zero-shot 的编辑。
Training via Distillation
Consistency models 有两种训练方式,蒸馏预训练的扩散模型 (consistency distillation) 或直接从头训练。
首先来看第一种方式。对于从训练集中采样的 \(\mathbf x\sim p_\text{data}\),按原扩散模型的 perturbation kernel 添加噪声得到 \(\mathbf x_{t_{n+1}}\),然后使用预训练的 score model 去噪计算 \(\hat{\mathbf x}_{t_n}^\phi\): \[ \hat{\mathbf x}_{t_n}^{\phi}=\mathbf x_{t_{n+1}}+(t_n-t_{n+1})\Phi(\mathbf x_{t_{n+1}},t_{n+1};\phi) \] 其中 \(\Phi(\cdot,\cdot,\phi)\) 表示 ODE solver 的一步更新。例如,若使用 Euler solver,那么: \[ \hat{\mathbf x}_{t_n}^{\phi}=\mathbf x_{t_{n+1}}-(t_n-t_{n+1})t_{n+1}s_\phi(\mathbf x_{t_{n+1}},t_{n+1}) \] 这样,我们就得到了数据对 \((\hat{\mathbf x}_{t_n}^{\phi},\mathbf x_{t_{n+1}})\). 根据 self-consistency 性质,\(\mathbf f_\theta(\cdot,\cdot)\) 分别作用在它们之上后应该得到相等的结果,因此定义 consistency distillation loss 如下: \[ \mathcal L_\text{CD}^N(\theta,\theta^-;\phi)=\mathbb E\left[\lambda(t_n)d\left(\mathbf f_\theta(\mathbf x_{t_{n+1}},t_{n+1}),\mathbf f_{\theta^-}(\hat{\mathbf x}_{t_n}^\phi,t_n)\right)\right] \] 其中 \(d(\cdot,\cdot)\) 可以是 L2 距离、L1 距离或 LPIPS. 参数 \(\theta\) 用梯度下降优化,而 \(\theta^-\) 则是通过 EMA 更新,即: \[ \theta^-\gets\text{stopgrad}(\mu\theta^-+(1-\mu)\theta) \] 采用 EMA 更新而不是简单地设置 \(\theta^-=\theta\) 有助于稳定训练过程并提升最后的性能。【为什么?】
值得注意的是,由于我们的参数化使得模型始终满足 boundary condition \(\mathbf f_\theta(\mathbf x,\epsilon)=\mathbf x\),所以这样训练不会发生坍塌的平凡解 \(\mathbf f_\theta(\mathbf x,\epsilon)\equiv\mathbf 0\). 算法如下所示:
Training in Isolation
在 consistency distillation 中,我们用到了预训练的 \(s_\phi(\mathbf x,t)\) 来近似 score function \(\nabla\log p_t(\mathbf x)\). 但事实上,score function 有以下的无偏估计: \[ \nabla\log p_t(\mathbf x_t)=-\mathbb E\left[\frac{\mathbf x_t-\mathbf x}{t^2}\Bigg\vert\mathbf x_t\right],\quad \mathbf x\sim p_\text{data},\,\mathbf x_t\sim\mathcal N(\mathbf x;\mathbf0,t^2\mathbf I) \] 因此,我们可以用上述无偏估计代替 score function,这样就可以直接从头训练 consistency model.
具体而言,采样 \(\mathbf x\sim p_\text{data},\,\mathbf z\sim\mathcal N(\mathbf0,\mathbf I)\),计算 \(\mathbf x_{t_{n+1}}=\mathbf x+t_{n+1}\mathbf z\),将这次采样视作对上述期望的估计,有: \[ \nabla\log p_{t_{n+1}}(\mathbf x_{})\approx -\frac{\mathbf z}{t_{n+1}} \] 所以一步去噪后的样本为: \[ \begin{align} \hat{\mathbf x}_{t_n}&=\mathbf x_{t_{n+1}}-(t_n-t_{n+1})t_{n+1}\cdot\left(-\frac{\mathbf z}{t_{n+1}}\right)\\ &=\mathbf x_{t_{n+1}}+(t_n-t_{n+1})\mathbf z\\ &=\mathbf x+t_{n+1}\mathbf z+(t_n-t_{n+1})\mathbf z\\ &=\mathbf x_t+t_n\mathbf z \end{align} \] 代入 consistency distillation loss,得到 consistency training loss 如下: \[ \mathcal L_\text{CT}^N(\theta,\theta^-)=\mathbb E\left[\lambda(t_n)d\left(\mathbf f_\theta(\mathbf x+t_{n+1}\mathbf z,t_{n+1}),\mathbf f_{\theta^-}(\mathbf x+t_n\mathbf z,t_n)\right)\right],\quad\mathbf z\sim \mathcal N(\mathbf0,\mathbf I) \] 算法如下所示: