[Stable Diffusion]SD代码分析

Stable Diffusion v2.0 ~ v2.1: https://github.com/Stability-AI/StableDiffusion

Stable Diffusion v1.1 ~ v1.5: https://github.com/runwayml/stable-diffusion

Stable Diffusion v1.1 ~ v1.4: https://github.com/compvis/stable-diffusion

Latent Diffusion: https://github.com/CompVis/latent-diffusion

关于 Stable Diffusion 这些版本之间的迭代关系参见模型概览一文。

Overview

Stable Diffusion(下文简称为 SD)是在 Latent Diffusion 代码库的基础上开发的。尽管 SD 迭代了许多版本,但是代码库结构基本保持不变,因此本文只分析 SDv2 的代码库。以下为整个仓库的组织结构(排除了一些非代码文件,如文档、协议等):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
.
├── checkpoints
│   └── checkpoints.txt
├── configs
│   ├── karlo
│   └── stable-diffusion
├── ldm
│   ├── data
│   ├── models
│   ├── modules
│   └── util.py
├── scripts
│   ├── gradio
│   ├── img2img.py
│   ├── streamlit
│   ├── tests
│   └── txt2img.py
└── setup.py

可以看到整个仓库基本由三块组成——configs 存放模型推理时用的配置文件,ldm 是实现 SD 的核心代码,scripts 是一些推理脚本。其中 ldm 是 SD 的核心,因此本文着重分析 ldm 的代码。

ldm

ldm 有四个部分,其中 util.py 定义了一些工具函数和工具类;data 目录下仅有一个类,用于做深度检测,暂且搁置到一边;modules 目录下是模型网络架构的定义;models 目录下是基于 PyTorch Lightning(下文简称为 PL)的封装,便于训练和推理。

ldm/util.py

这里定义了一些工具函数和工具类,其中有两个函数值得特别提及:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))

def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)

在 SD 项目中,模型的实例化都是通过调用 instantiate_from_config(config) 实现的。其中 config 是 OmegaConf 包从配置文件读入的配置字典,字典中 target 字段写要实例化的类,params 字段写类的参数,例如 configs/stable-diffusion/v2-inference-v.yaml 中有这么一段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False

那么代码就会从 ldm.modules.diffusionmodules.openaimodel 模块中 import UNetModel 这个类,并用上述参数实例化。例如:

1
2
3
4
5
6
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

conf = OmegaConf.load('./configs/stable-diffusion/v2-inference-v.yaml')
unet = instantiate_from_config(conf.model.params.unet_config)
print(unet)

执行上述代码后就能看到我们成功实例化了一个 UNet.

这种实例化方式非常灵活,适合快速修改参数和尝试不同的模型架构,同时保持代码简洁;缺点是看代码的时候必须要找到对应配置文件才知道实例化的到底是什么,以及在 IDE 中不能支持快速跳转和类型提示。

ldm/modules

该目录下是各种网络模块的定义(即继承 torch.nn.Module 的类),将其展开:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
ldm/modules
├── attention.py
├── diffusionmodules
│   ├── __init__.py
│   ├── model.py
│   ├── openaimodel.py
│   ├── upscaling.py
│   └── util.py
├── distributions
│   ├── __init__.py
│   └── distributions.py
├── ema.py
├── encoders
│   ├── __init__.py
│   └── modules.py
├── image_degradation
│   ├── __init__.py
│   ├── bsrgan_light.py
│   ├── bsrgan.py
│   ├── utils
│   └── utils_image.py
├── karlo
│   ├── __init__.py
│   ├── diffusers_pipeline.py
│   └── kakao
└── midas
├── api.py
├── __init__.py
├── midas
└── utils.py

后三个目录与 SD 关系不大,这里只分析前面的部分。

ldm/modules/ema.py

该文件定义了 LitEMA 类,用于在训练过程中维护模型的 EMA 权重。

1
2
3
4
5
6
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True): ...
def forward(self, model): ...
def copy_to(self, model): ...
def store(self, parameters): ...
def restore(self, parameters): ...

使用方法为:

  • 初始化时传入模型,则 LitEMA 会将其参数存入自己的 buffer;
  • 每次调用 LitEMA 时传入新的模型,则 LitEMA 按照 EMA 方式更新 buffer;
  • copy_to(self, model) 方法把 buffer 复制给传入的模型;
  • store(self, parameters) 方法将传入的模型参数暂存起来;
  • restore(self, parameters) 方法将暂存的模型参数赋值给传入的参数。

于是,在训练 loop 结束、验证 loop 开始之前,我们可以使用 store()copy_to() 将模型参数换成 EMA 版本,在验证 loop 结束后使用 restore() 恢复成原版本继续训练。

ldm/modules/attention.py

该文件定义了一些与注意力机制相关的网络模块,主要包括以下几种:

1
2
3
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): ...
def forward(self, x, context=None, mask=None): ...

输入的特征 x 与条件 context 之间做交叉注意力,mask 为注意力掩码,输出特征的维度与 x 相同。如果输入 context=None,则 context 会被赋值为 x,等价于自注意力。

1
2
3
class MemoryEfficientCrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): ...
def forward(self, x, context=None, mask=None): ...

利用 xformers 包做交叉注意力,功能与 CrossAttention 相同,但更省内存。

1
2
3
4
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False): ...
def forward(self, x, context=None): ...
def _forward(self, x, context=None): ...

SelfAttention + CrossAttention + FF 串成一个 block.

1
2
3
class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, use_checkpoint=True): ...
def forward(self, x, context=None): ...

输入输出为图像形式 (BCHW) 的一系列 Transformer block.

ldm/modules/encoders

注意这里的 encoders 指各种文本编码器,并不是 SD 的 VAE 或者 UNet 中的编码器。

SD 在 transformer 包和 open_clip 包的基础上进行了进一步的封装,得到四种文本编码器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class FrozenT5Embedder(AbstractEncoder):
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): ...
def encode(self, text): ...

class FrozenCLIPEmbedder(AbstractEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, layer="last", layer_idx=None): ...
def encode(self, text): ...

class FrozenOpenCLIPEmbedder(AbstractEncoder):
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last"): ...
def encode(self, text): ...

class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", clip_max_length=77, t5_max_length=77): ...
def encode(self, text): ...

具体实现不重要,只需要知道它们都有 encode() 方法,调用即可获取输入文本的 embeddings,例如:

1
2
3
4
5
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

encoder = FrozenCLIPEmbedder().cuda()
embed = encoder.encode("Hello World!")
print(embed.shape)

执行上述代码就可以得到一个 [1, 77, 768] 大小的 embedding. 实例化的过程会输出一大段未被加载的权重,这是正常现象。

SD v1.5 使用的是 FrozenCLIPEmbedder,加载的权重是 openai/clip-vit-large-patch14

SD v2.1 使用的是 FrozenOpenCLIPEmbedder,加载的是 ViT-H-14 架构在 laion2b_s32b_b79k 上预训练的权重。

ldm/modules/distributions

该目录下定义了狄拉克分布和各分量独立高斯分布:

1
2
3
4
5
6
7
8
9
10
11
class DiracDistribution(AbstractDistribution):
def __init__(self, value): ...
def sample(self): ...
def mode(self): ...

class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False): ...
def sample(self): ...
def kl(self, other=None): ...
def nll(self, sample, dims=[1,2,3]): ...
def mode(self): ...

它们都有方法 sample()mode() 用于采样以及返回均值。高斯分布还有 kl()nll() 计算 KL 散度和负对数似然。

另外定义了计算两个高斯分布之间的 KL 散度的函数:

1
def normal_kl(mean1, logvar1, mean2, logvar2): ...

ldm/modules/diffusionmodules

该目录下定义了搭建 VAE 和 UNet 的各种网络模块。

其中 model.py 仿照 UNet 设计了图像编码器和解码器,被用作 SD 的 VAE 的编解码器:

1
2
3
4
5
6
7
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, ...): ...
def forward(self, x): ...

class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, ...): ...
def forward(self, z): ...

openaimodel.py 则参考 openai 的代码设计了 SD 的 UNet:

1
2
3
class UNetModel(nn.Module):
def __init__(self, image_size, in_channels, model_channels, out_channels, ...): ...
def forward(self, x, timesteps=None, context=None, y=None, **kwargs): ...

forward() 接受噪声图 x、时间步 timestep、以交叉注意力方式融入的条件 context 和类别标签 y(即以 adm 的方式融入模型的条件,不一定真的表示类别)。

ldm/models

上文中 ldm/modules 目录下定义了 SD 需要用到的所有网络模块,但由于 SD 项目采用的是 PyTorch Lightning 训练,因此还需要对接 PL,这就是 ldm/models 的作用。

PL 要求将整个模型封装为一个继承 pl.LightningModule 的类,其中不仅要在 __init__() 中实例化各个网络模块组件,还要重写一些方法,典型的例子有:

  • configure_optimizers(self):返回一个元组,第一个元素为 optimizers 列表,第二个元素为 schedulers 列表。
  • training_step(self, batch, batch_idx):参数为一个 batch 的数据和当前 batch 的编号,返回一个 loss 张量或一个含有 loss 字段的字典。
  • validation_step(self, batch, batch_idx):参数为一个 batch 的数据和当前 batch 的编号,返回一个 loss 张量或一个含有 loss 字段的字典。
  • on_train_batch_end(self, outputs, batch, batch_idx):这是一个 hook,用于在一个 batch 训练结束后执行里面的内容。例如,在 SD 项目中该 hook 被用来更新 EMA 权重。

具体使用方法可以查阅 Pytorch Lightning 的文档

1
2
3
4
5
6
7
8
9
ldm/models
├── autoencoder.py
└── diffusion
├── __init__.py
├── ddim.py
├── ddpm.py
├── dpm_solver
├── plms.py
└── sampling_util.py

ldm/models/autoencoders.py

该文件定义了 SD 中的 VAE 模块,主要实现了 AutoencoderKL 类:

1
2
3
4
5
6
7
8
9
10
11
class AutoencoderKL(pl.LightningModule):
def __init__(self, ddconfig, lossconfig, embed_dim, ...):
...
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
...
def encode(self, x):
def decode(self, z):
def forward(self, input, sample_posterior=True): ...

这里我特意写明了 AutoEncoderKL 里实例化的四个网络模块组件:

  • encoder:上文 Encoder 类的实例,即编码器;
  • decoder:上文 Decoder 类的实例,即解码器;
  • quant_conv:一层 1x1 卷积 nn.Conv2d,将编码器隐空间维度 (=512) 映射到 embedding 维度 (=4);
  • post_quant_conv:一层 1x1 卷积 nn.Conv2d,将 embedding 维度 (=4) 映射回解码器隐空间维度 (=512)。

forward() 的流程为:encoder 编码均值方差 → quant_conv 降维 → 从服从该均值方差的高斯分布中采样 → post_quant_conv 升维 → decoder 解码输出图像。

ldm/models/diffusion

该目录下定义了扩散模型的各种采样器,包括:

1
2
3
4
5
6
7
8
9
10
11
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): ...
def sample(self, S, batch_size, shape, conditioning=None, callback=None, ...): ...

class PLMSSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): ...
def sample(self, S, batch_size, shape, conditioning=None, callback=None, ...): ...

class DPMSolverSampler(object):
def __init__(self, model, device=torch.device("cuda"), **kwargs): ...
def sample(self, S, batch_size, shape, conditioning=None, callback=None, ...): ...

这些采样器都有 sample() 方法,给定采样步数、batch size、图像大小以及其他参数,就能迭代地从高斯噪声生成最终的图像。

ldm/models/diffusion/ddpm.py

ddpm.py 是整个仓库里最长的文件,有 1800 多行,其中定义了 DDPM 类、LatentDiffusion 类、DiffusionWrapper 类以及一系列继承自 LatentDiffusion 来做微调的类,我们着重看前三个。

首先看 DiffusionWrapper

1
2
3
4
5
6
7
8
9
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): ...

可以看见,forward() 接受的输入为噪声图 x、时间步 t、以 concat 的方式融入的条件 c_concat、以交叉注意力方式融入的条件 c_crossattn 和以 adm 方式融入的条件 c_adm. 融入哪些条件由参数 conditioning_key 决定。

回忆上文中 UNet 本身就支持两种条件——以交叉注意力方式融入和以 adm 方式融入,因此 DiffusionWrapper 就是 UNet 的进一步包装,使得条件融入方式更加的灵活。猜测这个类应该是后来添加的,否则直接在 UNet 的 forward() 函数中处理更简洁。

接下来看 DDPMLatentDiffusion

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class DDPM(pl.LightningModule):
def __init__(self, unet_config, timesteps=1000, beta_schedule="linear", ...): ...
...
self.model = DiffusionWrapper(unet_config, conditioning_key)
if self.use_ema:
self.model_ema = LitEma(self.model)
...

class LatentDiffusion(DDPM):
def __init__(self, first_stage_config, cond_stage_config, num_timesteps_cond=None, ...): ...
...
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
...
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
...

def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False

def instantiate_cond_stage(self, config):
...
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
...

def encode_first_stage(self, x): ...
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
def get_first_stage_encoding(self, encoder_posterior): ...
def apply_model(self, x_noisy, t, cond, return_ids=False): ...

这里依旧特意写明了 LightningModule 里实例化的模块组件。

DDPM 类包含一个网络模块组件 model,是 DiffusionWrapper 的实例;另外还包括模型的 EMA 版本。

LatentDiffusion 类是 SD 最主要的类。它继承自 DDPM 类,在其基础上增加了两个网络模块组件:

  • cond_stage_model:上文 FrozenCLIPEmbedder 类或 FrozenOpenCLIPEmbedder 的实例,即文本编码器;
  • first_stage_model:上文 AutoencoderKL 类的实例,即 VAE.

另外,我还特意写明了几个常用的方法:

  • encode_first_stage() 即得到 VAE Encoder 输出的高斯分布;
  • get_first_stage_encoding():从高斯分布中采样,并且缩放 0.18215 倍使得隐变量方差接近 1;
  • decode_first_stage():将隐变量缩放原大小,然后通过 VAE Decoder 得到解码的图像;
  • apply_model():调用模型,即初始化时实例化的 DiffusionWrapper.

[Stable Diffusion]SD代码分析
https://xyfjason.github.io/blog-main/2023/12/01/Stable-Diffusion-SD代码分析/
作者
xyfJASON
发布于
2023年12月1日
许可协议