扩散模型蒸馏方法

Introduction

尽管扩散模型在生成质量、似然估计和训练稳定性上表现出卓越的性能,但其最大的缺点就是采样耗时。为此,许多采样器被提出以加速采样过程,例如 DDIM, Analytic-DPM, PNDM, DPM-Solver 等等,它们着眼于更精确地求解扩散 ODE,例如采用高阶的求解器并充分利用扩散 ODE 的特殊结构。然而,受制于模型本身的误差,此类 training-free 的方法再精确也难以做到 10 步以内的高质量生成。

随着领域的发展,如今扩散模型使用的网络架构已经基本收敛(特别是文生图应用上基本都用 SD 系列),这使得在原有网络的基础上蒸馏一个新的网络成为了不错的方案。所谓蒸馏,即训练一个 student 模型,其一步去噪的效果相当于原 teacher 模型多步去噪的效果。相比优化采样器的方法,基于蒸馏的方法往往能够实现 4 步、2 步甚至 1 步采样,彻底解决扩散模型采样耗时的问题。

一个自然的问题是,为什么蒸馏可以 work?换句话说,为什么我们不直接在较少步数上训练模型,偏偏要在较多步数上训练之后再蒸馏,这难道不是多此一举吗?这是因为,扩散模型的标准训练流程是依靠不断采样、每次拟合一个 \(\mathbf x_0\) 的方式来拟合 \(\mathbb E[\mathbf x_0\vert\mathbf x_t]\),这种训练方式的损失函数的方差很大,直观上训练曲线会非常振荡,不利于模型收敛。而蒸馏时 teacher 模型给出的监督信号本就是对 \(\mathbb E[\mathbf x_0\vert\mathbf x_t]\) 的近似,所以蒸馏损失的方差显著降低,训练过程更平稳,模型更容易收敛。不过,蒸馏虽然减少了训练损失的方差,但 teacher 模型对 \(\mathbb E[\mathbf x_0\vert\mathbf x_t]\) 的估计是有偏差的,所以这里存在 bias-variance trade-off.

Progressive Distillation

Progressive distillation[1] 的基本思想是每次蒸馏将采样步数缩短到原来的一半,如此进行 \(\log_2 N\) 次迭代,就可以把 \(N\) 步采样模型蒸馏为 1 步采样模型。

具体而言,假设有一个 teacher 扩散模型 \(\mathbf x_\eta(\cdot)\),我们希望从中蒸馏一个 student 模型 \(\mathbf x_\theta(\cdot)\),使其一步去噪相当于 teacher 模型两步去噪的效果。因此蒸馏时,首先从数据集中采样样本 \(\mathbf x_0\),添加噪声得到 \(\mathbf x_t\);然后用 teacher 模型实施两步 DDIM 去噪,得: \[ \begin{gather} \mathbf x_{t'}=\alpha_{t'}\mathbf x_\eta(\mathbf x_t)+\frac{\sigma_{t'}}{\sigma_t}(\mathbf x_t-\alpha_t\mathbf x_\eta(\mathbf x_t))\\ \mathbf x_{t''}=\alpha_{t''}\mathbf x_\eta(\mathbf x_{t'})+\frac{\sigma_{t''}}{\sigma_{t'}}(\mathbf x_{t'}-\alpha_{t'}\mathbf x_\eta(\mathbf x_{t'})) \end{gather} \] 那么对于 student 模型,我们希望它一步去噪就能得到 \(\mathbf x_{t''}\),即: \[ \mathbf x_{t''}=\alpha_{t''}\mathbf x_\theta(\mathbf x_t)+\frac{\sigma_{t''}}{\sigma_t}(\mathbf x_t-\alpha_t\mathbf x_\theta(\mathbf x_t)) \] 整理得: \[ \mathbf x_\theta(\mathbf x_t)=\frac{\mathbf x_{t''}-(\sigma_{t''}/\sigma_t)\mathbf x_t}{\alpha_{t''}-(\sigma_{t''}/\sigma_t)\alpha_t} \] 因此右边就是 student 模型的拟合目标,故损失函数为: \[ \mathcal L_\theta=w_t\left\Vert\mathbf x_\theta(\mathbf x_t)-\frac{\mathbf x_{t''}-(\sigma_{t''}/\sigma_t)\mathbf x_t}{\alpha_{t''}-(\sigma_{t''}/\sigma_t)\alpha_t}\right\Vert_2^2 \] 待 student 模型收敛后就完成了一轮蒸馏,采样步数减少了一半。迭代执行多轮蒸馏即可指数式地减少采样步数。

算法流程的对比图如下所示(符号略有不同,绿色高亮是与标准训练流程相比不同的地方):

Progressive distillation 蒸馏后 4 到 8 步采样的结果就可以与 DDIM 100 步的结果持平,从而大幅加速了采样过程。

Guided Diffusion Distillation

Reflow

Consistency Distillation

Consistency Distillation[4] 是训练 consistency model 的一种方式,详见 Consistency Models 一文。

Latent Consistency Distillation

LCM 和 LCM-LoRA 将 consistency distillation 应用到了 Stable Diffusion 上,基本 4 步能够生成不错的图像。

Adversarial Diffusion Distillation

InstaFlow

Distribution Matching Distillation

Trajectory Consistency Distillation

Reference

  1. Salimans, Tim, and Jonathan Ho. Progressive distillation for fast sampling of diffusion models. arXiv preprint arXiv:2202.00512 (2022). ↩︎
  2. Meng, Chenlin, Robin Rombach, Ruiqi Gao, Diederik Kingma, Stefano Ermon, Jonathan Ho, and Tim Salimans. On distillation of guided diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 14297-14306. 2023. ↩︎
  3. Liu, Xingchao, Chengyue Gong, and Qiang Liu. Flow straight and fast: Learning to generate and transfer data with rectified flow. arXiv preprint arXiv:2209.03003 (2022). ↩︎
  4. Song, Yang, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. Consistency models. arXiv preprint arXiv:2303.01469 (2023). ↩︎
  5. Luo, Simian, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao. Latent consistency models: Synthesizing high-resolution images with few-step inference. arXiv preprint arXiv:2310.04378 (2023). ↩︎
  6. Sauer, Axel, Dominik Lorenz, Andreas Blattmann, and Robin Rombach. Adversarial diffusion distillation. arXiv preprint arXiv:2311.17042 (2023). ↩︎
  7. Liu, Xingchao, Xiwen Zhang, Jianzhu Ma, and Jian Peng. Instaflow: One step is enough for high-quality diffusion-based text-to-image generation. In The Twelfth International Conference on Learning Representations. 2023. ↩︎
  8. Yin, Tianwei, Michaël Gharbi, Richard Zhang, Eli Shechtman, Fredo Durand, William T. Freeman, and Taesung Park. One-step diffusion with distribution matching distillation. arXiv preprint arXiv:2311.18828 (2023). ↩︎
  9. Zheng, Jianbin, Minghui Hu, Zhongyi Fan, Chaoyue Wang, Changxing Ding, Dacheng Tao, and Tat-Jen Cham. Trajectory Consistency Distillation. arXiv preprint arXiv:2402.19159 (2024). ↩︎

扩散模型蒸馏方法
https://xyfjason.github.io/blog-main/2024/05/06/扩散模型蒸馏方法/
作者
xyfJASON
发布于
2024年5月6日
许可协议