Diffusion meets Self-Supervised Learning

Overview

Denoising MIM CL SD
Classic SSL DAE BEIT, MAE, SimMIM SimCLR, MoCo DINO, iBOT
+ Diffusion l-DAE MDT, MaskDiT Disperse SD-DiT, SRA

Denoising

Classic SSL: DAE

+ Diffusion: l-DAE

Masked Image Modeling

Classic SSL: MAE

+ Diffusion: MDT

整体上看,MDT 在标准 DiT 的基础上增加了掩码操作:网络仅接收未被掩码的带噪 latent 而非全部带噪 latent,但仍然要求预测全部的干净 latent. 这一训练目标结合了去噪与掩码预测,乍一看增加了训练难度,但实际上迫使网络更加着重学习图像语义,反而有利于模型的收敛。

MDT 的网络参考了 MAE 的非对称设计:较大的 encoder 仅接收未被掩码的 token,而较小的 decoder 接收补全后的 token. 与 MAE 不同的是,由于扩散模型在推理时不会施加掩码,因此 decoder 在推理时总是见到非掩码 token;倘若训练时仍然直接补上 mask token,这就会形成一个训练-推理的不匹配。为了解决这个问题,作者设计了一个 side-interpolater 来补充被掩码的 token. 具体而言,side-interpolater 的输入为未被掩码的 token 和补充的可学习 mask token,经过一个 Transformer block 后输出预测的被掩码的 token. Decoder 在训练时接收 side-interpolater 预测的被掩码 token 而非 mask token,因此在一定程度上缓解了不匹配问题。Side-interpolater 仅在训练中使用,推理时不需要。

实践中,作者按照 DiT 的标准设置 MDT 的网络,但 decoder 的层数始终保持为 2,这延续了 MAE 的非对称思想。训练时的掩码比例固定为 30%,远小于 MAE 的 75%. 这可能是因为扩散模型本身的加噪已经是一种破坏,再采用大比例的掩码会让问题变得过于困难,对生成任务来说并不合适。

基于 MDT 的成功,作者进一步提出了 MDTv2,做了以下改进:

  1. 网络架构:给 encoder 增加 long-shortcuts,即 UNet style 的跳跃连接;给 decoder 增加 dense input-shortcuts,即从输入到 decoder 每一个 block 的跳跃连接;将 decoder 层数从 2 增加到 6;
  2. 模型优化:将 Adam 优化器替换为 Adan 优化器;采用 min-SNR weighting;
  3. 掩码策略改进:将掩码比例从固定的 30% 更改为 30%~50% 随机。

更多细节请参阅论文。

+ Diffusion: MaskDiT

Contrastive Learning

Classic SSL: SimCLR

+ Diffusion: Disperse

Self Distillation

Classic SSL: DINO

DINO 与对比学习的一些方法有相似之处,但它更适合从知识蒸馏 (knowledge distillation) 的角度理解。在知识蒸馏中,设有 teacher 和 student 网络 \(g_{\theta_t}(x),g_{\theta_s}(x)\),它们经过 softmax 后分别形成类别分布: \[ P_t(x)^{(k)}=\frac{\exp(g_{\theta_t}(x)^{(k)}/\tau_t)}{\sum_{k=1}^K\exp(g_{\theta_t}(x)^{(k)}/\tau_t)} ,\quad P_s(x)^{(k)}=\frac{\exp(g_{\theta_s}(x)^{(k)}/\tau_s)}{\sum_{k=1}^K\exp(g_{\theta_s}(x)^{(k)}/\tau_s)} \] 其中 \(\tau_t,\tau_s\) 为温度系数。我们的训练目标是让 student 分布 \(P_s\) 逼近 teacher 分布 \(P_t\),即: \[ \min_{\theta_s}~\mathbb E_{p_\text{data}(x)}\left[-\sum_{k=1}^KP_t(x)^{(k)}\log P_s(x)^{(k)}\right] \] 为了将上述知识蒸馏框架适配到自监督学习中,DINO 采用了 multi-crop 的策略:给定图像 \(x\),构造 2 个分辨率为 224 的 global views \(x_1^g,x_2^g\),以及若干分辨率为 96 的 local views. Teacher 网络总是接收 global views,而 student 网络两种都接收,于是训练目标为: \[ \min_{\theta_s}~\mathbb E_{p_\text{data}(x)}\left[\sum_{x^g\in\{x_1^g,x_2^g\}}\sum_{x'\in V(x),\,x'\neq x}\left(-\sum_{k=1}^KP_t(x^g)^{(k)}\log P_s(x')^{(k)}\right)\right] \] 其中 \(V(x)\) 表示对图像 \(x\) 构造的所有 global views 和 local views 的集合。另外,在自监督学习中我们并没有一个训练好的 teacher 网络,因此 DINO 采用 student 网络的 EMA 版本作为 teacher: \[ \theta_t\gets\lambda\theta_t+(1-\lambda)\theta_s \] 其中 \(\lambda\) 为 EMA decay 系数,在训练过程中依 cosine schedule 从 0.996 增长到 1. 进一步地,为避免崩塌,DINO 在 teacher 网络末端接入了 centering 和 sharpening 操作。

+ Diffusion: SD-DiT

如图所示,SD-DiT 建立在 MaskDiT 的基础上,因此扩散网络包含两部分——只接收可见 token 的 encoder 和拼接上 mask token 的 decoder. SD-DiT 视 encoder 部分为 student,通过 EMA 的方式引入了一个 teacher encoder 做自蒸馏。对于扩散模型而言,加噪本身就是一种自然的数据增强方式。SD-DiT 实验发现,teacher 网络应始终接收最低程度 \(\sigma_\min\) 的加噪,以达到最好的蒸馏效果。

另外,与 DINO 只蒸馏 \(\texttt{[CLS]}\) token 不同,SD-DiT 会蒸馏所有 token. 具体而言,设 \(\mathbf e_\text{T},\mathbf e_\text{S}\) 分别为 teacher 和 student encoder 的输出,\(j_\theta\) 为一个三层的 projection head,则对第 \(i\) 个可见 token 有 teacher 和 student 分布: \[ {P_\text{T}}_i^{(k)}=\frac{\exp(j_\theta({\mathbf e_\text{T}}_i)/\tau_\text{T})^{(k)}}{\sum_{k=1}^K\exp(j_\theta({\mathbf e_\text{T}}_i)/\tau_\text{T})^{(k)}},\quad {P_\text{S}}_i^{(k)}=\frac{\exp(j_\theta({\mathbf e_\text{S}}_i)/\tau_\text{S})^{(k)}}{\sum_{k=1}^K\exp(j_\theta({\mathbf e_\text{S}}_i)/\tau_\text{S})^{(k)}} \] 那么自蒸馏损失为: \[ \mathcal L_\text{D}(i)=-\sum_{k=1}^K{P_\text{T}}_i^{(k)}\log{P_\text{S}}_i^{(k)} \] 总自蒸馏损失为所有可见 token 的自蒸馏损失与 \(\texttt{[CLS]}\) token 的自蒸馏损失之和: \[ \mathcal L_\text{D}=\frac{1}{|\bar{\mathcal M}|}\sum_{i\in \bar{\mathcal M}}\mathcal L_\text{D}(i)+\mathcal L_\text{D}(\texttt{[CLS]}) \] 为避免坍塌,SD-DiT 同样采用了 DINO 中的 centering 和 sharpening 技巧。

+ Diffusion: SRA


Diffusion meets Self-Supervised Learning
https://xyfjason.github.io/blog-main/2026/01/15/Diffusion-meets-Self-Supervised-Learning/
作者
xyfJASON
发布于
2026年1月15日
许可协议