ControlNet 1.0:
ControlNet 1.1:
Overview
ControlNet 是基于 Stable Diffusion 的代码库开发的,以下为整个仓库的文件组织结构(排除了非代码文件,如文档、字体、测试图片等):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 . ├── annotator │ ├── canny │ ├── ckpts │ ├── hed │ ├── midas │ ├── mlsd │ ├── openpose │ ├── uniformer │ └── util.py ├── cldm │ ├── cldm.py │ ├── ddim_hacked.py │ ├── hack.py │ ├── logger.py │ └── model.py ├── config.py ├── gradio_annotator.py ├── gradio_canny2image.py ├── gradio_depth2image.py ├── gradio_fake_scribble2image.py ├── gradio_hed2image.py ├── gradio_hough2image.py ├── gradio_normal2image.py ├── gradio_pose2image.py ├── gradio_scribble2image_interactive.py ├── gradio_scribble2image.py ├── gradio_seg2image.py ├── ldm │ ├── data │ ├── models │ ├── modules │ └── util.py ├── models │ ├── cldm_v15.yaml │ └── cldm_v21.yaml ├── share.py ├── tool_add_control.py ├── tool_add_control_sd21.py ├── tool_transfer_control.py ├── tutorial_dataset.py ├── tutorial_dataset_test.py ├── tutorial_train.py └── tutorial_train_sd21.py
annotator
Annotator 指各种条件提取器,如 canny 边缘检测器、midas 深度和法线估计模型、openpose 人体姿态识别器等。这些 annotator 既有传统方法,也有基于深度学习的方法,因此有些 annotator 需要模型权重文件,权重文件应放置在 annotator/ckpts 下。
作者将每一种 annotator 分别放在对应子目录下,并且在各自的 __init__.py
中封装了 xxxDetector
类,例如 CannyDetector
、MidasDetector
、OpenposeDetector
等。这些 xxxDetector
有 __call__()
方法,接受输入图像(numpy uint8 HWC),返回检测图像(numpy uint8 HWC),因此调用起来非常的方便。
作者还在 annotator/util.py 中提供了两个工具方法:resize_image()
将图像按小边成比例缩放到指定大小附近的 64 倍率处,HWC3()
将不同通道数的图像统一为 3 通道。
下面是一个简单的代码片段示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 from PIL import Imageimport numpy as npfrom annotator.canny import CannyDetectorfrom annotator.midas import MidasDetectorfrom annotator.uniformer import UniformerDetectorfrom annotator.util import resize_image, HWC3 img_path = './test_imgs/building.png' img = HWC3(np.array(Image.open (img_path))) apply_canny = CannyDetector() img_canny = apply_canny(img, low_threshold=128 , high_threshold=200 ) img_canny = HWC3(img_canny) Image.fromarray(img_canny).save('img_canny.png' ) apply_midas = MidasDetector() img_r = resize_image(img, 768 ) img_midas_depth, img_midas_normal = apply_midas(img_r) img_midas_depth = HWC3(img_midas_depth) img_midas_normal = HWC3(img_midas_normal) Image.fromarray(img_midas_depth).save('img_midas_depth.png' ) Image.fromarray(img_midas_normal).save('img_midas_normal.png' ) apply_uniformer = UniformerDetector() img_uniformer = apply_uniformer(img) img_uniformer = HWC3(img_uniformer) Image.fromarray(img_uniformer).save('img_uniformer.png' )
ldm
ldm 摘取自 Stable Diffusion 项目,是 SD 的核心,在 SD 代码分析 一文中已经进行了分析,这里不再赘述。
cldm
cldm 里是 ControlNet 相关的代码,包括 5 个文件。作者选择把自己添加的代码单独拿出来而不是直接在 ldm 上改,很大程度上方便了我们阅读和分析。那就让我们看看作者是怎么把 ControlNet 加进 Stable Diffusion 中的吧。
cldm/cldm.py
这里是 ControlNet 的核心,定义了三个类:
1 2 class ControlledUnetModel (UNetModel ): def forward (self, x, timesteps=None , context=None , control=None , only_mid_control=False , **kwargs ): ...
ControlledUNetModel
继承自 UNetModel
(即 SD 的去噪 UNet),重写了 forward()
方法,从而支持把 ControlNet 的输出接入到 SD 之中。实现上并没有什么 tricky 的地方,就是老老实实抄过来改写。
1 2 3 4 class ControlNet (nn.Module): def __init__ (self, image_size, in_channels, model_channels, hint_channels, ... ): ... def make_zero_conv (self, channels ): ... def forward (self, x, hint, timesteps, context, **kwargs ): ...
ControlNet
就是 ControlNet 本体了,其大部分内容就是把 UNetModel
的 Encoder 部分抄过来,但是需要新添加一个将条件输入压缩到隐空间维度的 8 层卷积,称作 input_hint_block
. 另外还加了一些 zero_convs
(注意其中有一个 middle_block_out
也是零卷积)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 class ControlLDM (LatentDiffusion ): def __init__ (self, control_stage_config, control_key, only_mid_control, *args, **kwargs ): super ().__init__(*args, **kwargs) self.control_model = instantiate_from_config(control_stage_config) self.control_key = control_key self.only_mid_control = only_mid_control self.control_scales = [1.0 ] * 13 def get_input (self, batch, k, bs=None , *args, **kwargs ): ... def apply_model (self, x_noisy, t, cond, *args, **kwargs ): ... def get_unconditional_conditioning (self, N ): ... def log_images (self, batch, N=4 , n_row=2 , sample=False , ... ): ... def sample_log (self, cond, batch_size, ddim, ddim_steps, **kwargs ): ... def configure_optimizers (self ): ... def low_vram_shift (self, is_diffusing ): ...
ControlLDM
则继承自 LatentDiffusion
,提供 PyTorch Lightning 的接口。其新增了组件 control_model
,即 ControlNet
的实例。另外,它还复写了一些方法,其中比较重要的包括:
get_input()
:把条件输入放入字典的 c_concat
字段中。
apply_model()
:取出字典 c_concat
字段中的条件输入,经过 ControlNet 后同其他条件一并给到 SD.
configure_optimizers()
:将 ControlNet 的参数加入优化器,另外参数 sd_locked
支持同时优化 SD 的 Decoder 部分。
low_vram_shift()
:由于 SD 的 pipeline 分为若干阶段,在 VAE 阶段可以将 UNet 放回 cpu,在扩散阶段可以将 VAE 放回 cpu,如此来节省显存消耗。
cldm/ddim_hacked.py
该文件主要修复了 ldm/models/diffusion/ddim.py
在采样时的一个 bug,可以直接平替原文件使用。
cldm/logger.py
这里定义了 ImageLogger
类,它是 PL 的一个 callback,用于在训练过程中采样图片,方便监视训练过程。
1 2 3 4 5 6 class ImageLogger (Callback ): def __init__ (self, batch_frequency=2000 , max_images=4 , ... ): ... def log_img (self, pl_module, batch, batch_idx, split="train" ): ... def on_train_batch_end (self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx ): if not self.disabled: self.log_img(pl_module, batch, batch_idx, split="train" )
具体而言,这个 callback 复写了 on_train_batch_end()
方法,因此在每一个训练 batch 结束后会执行里面的代码,即采样一些图片。
cldm/model.py
这里定义了一些加载模型权重的帮助函数:
get_state_dict(d)
:取出字典 d
中 key 为 state_dict
的项;如果没有 state_dict
这个 key,返回字典本身;
load_state_dict(ckpt_path, location='cpu')
:从 ckpt_path
加载模型权重,支持 safetensors
格式或普通的 torch
格式;
create_model(config_path)
:利用 instantiate_from_config()
从配置文件实例化整个模型。
cldm/hack.py
该文件 hack 了很多东西,包括:
屏蔽掉初始化 FrozenCLIPEmbedder 的时候输出的一大段信息(这段信息是未加载的权重)。
将 ldm 定义的 attention 模块的 forward()
函数换成 sliced attention 的 forward()
函数,用更长的运行时间换取更小的显存占用。
为 FrozenCLIPEmbedder 设置一个 clip_skip
参数,支持跳过倒数若干层的 CLIP 特征(许多 SD 微调模型用的是 CLIP 倒数第二层特征而不是最后一层特征)。
当文本 token 长度大于设置长度时,将其拆分成若干段分别编码再拼起来,而不是直接截断。
config.py & share.py
config.py 只有一行代码:
share.py 代码如下:
1 2 3 4 5 6 7 8 import configfrom cldm.hack import disable_verbosity, enable_sliced_attention disable_verbosity()if config.save_memory: enable_sliced_attention()
在文件开头写上 from share import *
,即可屏蔽掉实例化 FrozenCLIPEmbedder 时的冗长输出。若同时将 config.py 中的 disable_verbosity
设置为 True
,则将使用 sliced attention 来节省显存。
gradio_xxx.py
这些文件是基于 gradio 包制作的 webui,每种条件一个 webui.
值得一提的是,作者没有给出非 webui 的推理脚本,不过根据 gradio 代码,稍微改改写一个非 webui 的推理脚本也并不困难。
这些文件是训练用的,具体使用可以看 ControlNet 训练文档 ,写得非常详细。
Model weights
有了上面的基础之后,我们可以深入看一下 ControlNet 发布的权重。
Overview
模型
基于 SD 权重
发布权重
ControlNet 1.0
SD1.5 (v1-5-pruned.ckpt
) + 微调了 decoder
包含 SD 和 ControlNet
ControlNet 1.1
SD1.5 (v1-5-pruned.ckpt
)
只有 ControlNet
Control-LoRAs
SDXL
只有 LoRA、input hint 等
UniControl
SD1.5 ema (v1-5-pruned-emaonly.ckpt
)
包含 SD 和 UniControl
UniControl 1.1
SD1.5 ema (v1-5-pruned-emaonly.ckpt
)
包含 SD 和 UniControl
ControlNet 1.0
在 1.0 版本中,ControlNet 权重是和 SD1.5 的权重放一起发布 的,所以一个文件就有 5.4GB.
以 control_sd15_canny.pth
为例,它就是 ControlLDM
类的 state_dict,其所有的 keys 包括:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 # Diffusion parameters alphas_cumprod alphas_cumprod_prev betas log_one_minus_alphas_cumprod logvar posterior_log_variance_clipped posterior_mean_coef1 posterior_mean_coef2 posterior_variance sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod sqrt_recip_alphas_cumprod sqrt_recipm1_alphas_cumprod # FrozenCLIPEmbedder's state_dict cond_stage_model.transformer.text_model.embeddings.[xxx] cond_stage_model.transformer.text_model.encoder.[xxx] cond_stage_model.transformer.text_model.final_layer_norm.[xxx] # AutoencoderKL's state_dict first_stage_model.encoder.[xxx] first_stage_model.decoder.[xxx] first_stage_model.quant_conv.[xxx] first_stage_model.post_quant_conv.[xxx] # DiffusionWrapper's state_dict model.diffusion_model.time_embed.[xxx] model.diffusion_model.input_blocks.[xxx] model.diffusion_model.middle_block.[xxx] model.diffusion_model.output_blocks.[xxx] model.diffusion_model.out.[xxx] # ControlNet's state_dict control_model.time_embed.[xxx] control_model.input_blocks.[xxx] control_model.middle_block.[xxx] control_model.middle_block_out.0.weight control_model.middle_block_out.0.bias control_model.input_hint_block.[xxx] control_model.zero_convs.[xxx]
第一部分是扩散模型的参数,并不是网络的权重
第二部分是文本编码器,即 FrozenCLIPEmbedder
的权重
第三部分是自编码器,即 AutoencoderKL
的权重
第四部分是扩散模型 UNet,即 DiffusionWrapper
的权重(前文提及,DiffusionWrapper
包裹了 UNetModel
类)
第五部分是 ControlNet
的权重,可以看到确实是在 SD UNet Encoder 的基础上增加了 input_hint_block
、zero_convs
和 middle_block_out
.(注意 middle_block_out
也是一层 zero convolution,这个名字有迷惑性)
⚠️ 论文里说 SD 是固定的,但是深入探究后发现各条件模型的 SD UNet decoder 部分(即 model.diffusion_model.output_blocks
和 model.diffusion_model.out
)竟然不一样,与 SD1.5 也不一样(最大差异达到 0.01 数量级)。推测这与代码里的配置项 sd_locked
有关,配置 sd_locked=False
时 SD UNet decoder 是可训练的。
ControlNet 1.1
在 1.1 版本中,ControlNet 权重是单独发布 的,所以一个文件只有 1.4GB.
以 control_v11p_sd15_canny.pth
为例,它只是 ControlNet
类的 state_dict,其所有的 keys 包括:
1 2 3 4 5 6 7 control_model.time_embed.[xxx] control_model.input_blocks.[xxx] control_model.middle_block.[xxx] control_model.middle_block_out.0.bias control_model.middle_block_out.0.weight control_model.input_hint_block.[xxx] control_model.zero_convs.[xxx]
Control-LoRAs by StabilityAI
StabilityAI 的 Control-LoRAs 在为 ControlNet 加入了 LoRA,一个 rank128 的文件有 396MB,一个 rank256 的文件有 774 MB.
以 control-lora-canny-rank128.safetensors
为例,它包含的 keys 有:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 lora_controlnet (an empty tensor: Tensor([])) label_emb.0.0.bias label_emb.0.0.up label_emb.0.0.down label_emb.0.2.bias label_emb.0.2.up label_emb.0.2.down time_embed.0.up time_embed.0.down time_embed.0.bias time_embed.2.up time_embed.2.down time_embed.2.bias input_blocks.[xxx].up input_blocks.[xxx].down input_blocks.[xxx].bias middle_block.[xxx].up middle_block.[xxx].down middle_block.[xxx].bias input_hint_block.[xxx].weight input_hint_block.[xxx].bias middle_block_out.0.weight middle_block_out.0.bias zero_convs.[xxx].weight zero_convs.[xxx].bias
对比 ControlNet 的 keys,最明显的是多了两项:label_emb
和 lora_controlnet
,后者只是一个空 Tensor,用于方便判断加载的是否是 Control-LoRAs 的权重;前者的作用目前并不清楚(因为代码也没有开源)。其他 keys 就是在 ControlNet 的基础上为 Linear 层添加了 .up
和 .down
(即 LoRA 的权重),有些层(如归一化层)还调整了 .bias
。
UniControl by Salesforce
Salesforce 的 UniControl 旨在使得一个 ControlNet 适配多种条件输入,为此作者对 ControlNet 的结构稍加改造,引入了:
MoE Adapter:即改造 input hint blocks
Task Aware HyperNet:用于得到 task embedding
Modulated Zero Conv:能够融入 task embedding 的 zero convolution
其发布的权重有 5.8GB 大小,即和 ControlNet 1.0 一样是将 UniControl 和 SD1.5 的权重是放在一起发布的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 # Diffusion parameters alphas_cumprod alphas_cumprod_prev betas log_one_minus_alphas_cumprod logvar posterior_log_variance_clipped posterior_mean_coef1 posterior_mean_coef2 posterior_variance sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod sqrt_recip_alphas_cumprod sqrt_recipm1_alphas_cumprod # FrozenCLIPEmbedder's state_dict cond_stage_model.transformer.text_model.embeddings.[xxx] cond_stage_model.transformer.text_model.encoder.[xxx] cond_stage_model.transformer.text_model.final_layer_norm.[xxx] # AutoencoderKL's state_dict first_stage_model.encoder.[xxx] first_stage_model.decoder.[xxx] first_stage_model.quant_conv.[xxx] first_stage_model.post_quant_conv.[xxx] # DiffusionWrapper's state_dict model.diffusion_model.time_embed.[xxx] model.diffusion_model.input_blocks.[xxx] model.diffusion_model.middle_block.[xxx] model.diffusion_model.output_blocks.[xxx] model.diffusion_model.out.[xxx] # ControlNet's state_dict control_model.time_embed.[xxx] control_model.input_blocks.[xxx] control_model.input_hint_block_list_moe.[xxx] control_model.input_hint_block_share.[xxx] control_model.input_hint_block_zeroconv_0.[xxx] control_model.input_hint_block_zeroconv_1.[xxx] control_model.middle_block.[xxx] control_model.middle_block_out.0.weight control_model.middle_block_out.0.bias control_model.zero_convs.[xxx] control_model.task_id_hypernet.[xxx] control_model.task_id_layernet.[xxx] control_model.task_id_layernet_zeroconv_0.[xxx] control_model.task_id_layernet_zeroconv_1.[xxx]
可以看见,前面几部分确实就是 Stable Diffusion 的参数,与 ControlNet 1.0 是一样的,不同之处在于 ControlNet,input_hint_block
增加了 MoE 相关权重,以及最后多了个 task_id_hypernet
.