ControlNet 1.0:
ControlNet 1.1:
Overview
ControlNet 是基于 Stable Diffusion 的代码库开发的,以下为整个仓库的文件组织结构(排除了非代码文件,如文档、字体、测试图片等):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 . ├── annotator │ ├── canny │ ├── ckpts │ ├── hed │ ├── midas │ ├── mlsd │ ├── openpose │ ├── uniformer │ └── util.py ├── cldm │ ├── cldm.py │ ├── ddim_hacked.py │ ├── hack.py │ ├── logger.py │ └── model.py ├── config.py ├── gradio_annotator.py ├── gradio_canny2image.py ├── gradio_depth2image.py ├── gradio_fake_scribble2image.py ├── gradio_hed2image.py ├── gradio_hough2image.py ├── gradio_normal2image.py ├── gradio_pose2image.py ├── gradio_scribble2image_interactive.py ├── gradio_scribble2image.py ├── gradio_seg2image.py ├── ldm │ ├── data │ ├── models │ ├── modules │ └── util.py ├── models │ ├── cldm_v15.yaml │ └── cldm_v21.yaml ├── share.py ├── tool_add_control.py ├── tool_add_control_sd21.py ├── tool_transfer_control.py ├── tutorial_dataset.py ├── tutorial_dataset_test.py ├── tutorial_train.py └── tutorial_train_sd21.py
annotator
Annotator 指各种条件提取器,如 canny 边缘检测器、midas 深度和法线估计模型、openpose 人体姿态识别器等。这些 annotator 既有传统方法,也有基于深度学习的方法,因此有些 annotator 需要模型权重文件,权重文件应放置在 annotator/ckpts 下。
作者将每一种 annotator 分别放在对应子目录下,并且在各自的 __init__.py
中封装了 xxxDetector
类,例如 CannyDetector
、MidasDetector
、OpenposeDetector
等。这些 xxxDetector
有 __call__()
方法,接受输入图像(numpy uint8 HWC),返回检测图像(numpy uint8 HWC),因此调用起来非常的方便。
作者还在 annotator/util.py 中提供了两个工具方法:resize_image()
将图像按小边成比例缩放到指定大小附近的 64 倍率处,HWC3()
将不同通道数的图像统一为 3 通道。
下面是一个简单的代码片段示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 from PIL import Imageimport numpy as npfrom annotator.canny import CannyDetectorfrom annotator.midas import MidasDetectorfrom annotator.uniformer import UniformerDetectorfrom annotator.util import resize_image, HWC3 img_path = './test_imgs/building.png' img = HWC3(np.array(Image.open (img_path))) apply_canny = CannyDetector() img_canny = apply_canny(img, low_threshold=128 , high_threshold=200 ) img_canny = HWC3(img_canny) Image.fromarray(img_canny).save('img_canny.png' ) apply_midas = MidasDetector() img_r = resize_image(img, 768 ) img_midas_depth, img_midas_normal = apply_midas(img_r) img_midas_depth = HWC3(img_midas_depth) img_midas_normal = HWC3(img_midas_normal) Image.fromarray(img_midas_depth).save('img_midas_depth.png' ) Image.fromarray(img_midas_normal).save('img_midas_normal.png' ) apply_uniformer = UniformerDetector() img_uniformer = apply_uniformer(img) img_uniformer = HWC3(img_uniformer) Image.fromarray(img_uniformer).save('img_uniformer.png' )
ldm
ldm 摘取自 Stable Diffusion 项目,是 SD 的核心,在 SD 代码分析 一文中已经进行了分析,这里不再赘述。
cldm
cldm 里是 ControlNet 相关的代码,包括 5 个文件。作者选择把自己添加的代码单独拿出来而不是直接在 ldm 上改,很大程度上方便了我们阅读和分析。那就让我们看看作者是怎么把 ControlNet 加进 Stable Diffusion 中的吧。
cldm/cldm.py
这里是 ControlNet 的核心,定义了三个类:
1 2 class ControlledUnetModel (UNetModel ): def forward (self, x, timesteps=None , context=None , control=None , only_mid_control=False , **kwargs ): ...
ControlledUNetModel
继承自 UNetModel
(即 SD 的去噪 UNet),重写了 forward()
方法,从而支持把 ControlNet 的输出接入到 SD 之中。实现上并没有什么 tricky 的地方,就是老老实实抄过来改写。
1 2 3 4 class ControlNet (nn.Module): def __init__ (self, image_size, in_channels, model_channels, hint_channels, ... ): ... def make_zero_conv (self, channels ): ... def forward (self, x, hint, timesteps, context, **kwargs ): ...
ControlNet
就是 ControlNet 本体了,其大部分内容就是把 UNetModel
的 Encoder 部分抄过来,但是需要新添加一个将条件输入压缩到隐空间维度的 8 层卷积,称作 input_hint_block
. 另外还加了一些 zero_convs
(注意其中有一个 middle_block_out
也是零卷积)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 class ControlLDM (LatentDiffusion ): def __init__ (self, control_stage_config, control_key, only_mid_control, *args, **kwargs ): super ().__init__(*args, **kwargs) self.control_model = instantiate_from_config(control_stage_config) self.control_key = control_key self.only_mid_control = only_mid_control self.control_scales = [1.0 ] * 13 def get_input (self, batch, k, bs=None , *args, **kwargs ): ... def apply_model (self, x_noisy, t, cond, *args, **kwargs ): ... def get_unconditional_conditioning (self, N ): ... def log_images (self, batch, N=4 , n_row=2 , sample=False , ... ): ... def sample_log (self, cond, batch_size, ddim, ddim_steps, **kwargs ): ... def configure_optimizers (self ): ... def low_vram_shift (self, is_diffusing ): ...
ControlLDM
则继承自 LatentDiffusion
,提供 PyTorch Lightning 的接口。其新增了组件 control_model
,即 ControlNet
的实例。另外,它还复写了一些方法,其中比较重要的包括:
get_input()
:把条件输入放入字典的 c_concat
字段中。
apply_model()
:取出字典 c_concat
字段中的条件输入,经过 ControlNet 后同其他条件一并给到 SD.
configure_optimizers()
:将 ControlNet 的参数加入优化器,另外参数 sd_locked
支持同时优化 SD 的 Decoder 部分。
low_vram_shift()
:由于 SD 的 pipeline 分为若干阶段,在 VAE 阶段可以将 UNet 放回 cpu,在扩散阶段可以将 VAE 放回 cpu,如此来节省显存消耗。
cldm/ddim_hacked.py
该文件主要修复了 ldm/models/diffusion/ddim.py
在采样时的一个 bug,可以直接平替原文件使用。
cldm/logger.py
这里定义了 ImageLogger
类,它是 PL 的一个 callback,用于在训练过程中采样图片,方便监视训练过程。
1 2 3 4 5 6 class ImageLogger (Callback ): def __init__ (self, batch_frequency=2000 , max_images=4 , ... ): ... def log_img (self, pl_module, batch, batch_idx, split="train" ): ... def on_train_batch_end (self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx ): if not self.disabled: self.log_img(pl_module, batch, batch_idx, split="train" )
具体而言,这个 callback 复写了 on_train_batch_end()
方法,因此在每一个训练 batch 结束后会执行里面的代码,即采样一些图片。
cldm/model.py
这里定义了一些加载模型权重的帮助函数:
get_state_dict(d)
:取出字典 d
中 key 为 state_dict
的项;如果没有 state_dict
这个 key,返回字典本身;
load_state_dict(ckpt_path, location='cpu')
:从 ckpt_path
加载模型权重,支持 safetensors
格式或普通的 torch
格式;
create_model(config_path)
:利用 instantiate_from_config()
从配置文件实例化整个模型。
cldm/hack.py
该文件 hack 了很多东西,包括:
屏蔽掉初始化 FrozenCLIPEmbedder 的时候输出的一大段信息(这段信息是未加载的权重)。
将 ldm 定义的 attention 模块的 forward()
函数换成 sliced attention 的 forward()
函数,用更长的运行时间换取更小的显存占用。
为 FrozenCLIPEmbedder 设置一个 clip_skip
参数,支持跳过倒数若干层的 CLIP 特征(许多 SD 微调模型用的是 CLIP 倒数第二层特征而不是最后一层特征)。
当文本 token 长度大于设置长度时,将其拆分成若干段分别编码再拼起来,而不是直接截断。
config.py & share.py
config.py 只有一行代码:
share.py 代码如下:
1 2 3 4 5 6 7 8 import configfrom cldm.hack import disable_verbosity, enable_sliced_attention disable_verbosity()if config.save_memory: enable_sliced_attention()
在文件开头写上 from share import *
,即可屏蔽掉实例化 FrozenCLIPEmbedder 时的冗长输出。若同时将 config.py 中的 disable_verbosity
设置为 True
,则将使用 sliced attention 来节省显存。
gradio_xxx.py
这些文件是基于 gradio 包制作的 webui,每种条件一个 webui.
值得一提的是,作者没有给出非 webui 的推理脚本,不过根据 gradio 代码,稍微改改写一个非 webui 的推理脚本也并不困难。
这些文件是训练用的,具体使用可以看 ControlNet 训练文档 ,写得非常详细。
Model weights
有了上面的基础之后,我们可以深入看一下 ControlNet 发布的权重。
Overview
ControlNet 1.0
SD1.5 (v1-5-pruned.ckpt
) + 微调了 decoder
包含 SD 和 ControlNet
ControlNet 1.1
SD1.5 (v1-5-pruned.ckpt
)
只有 ControlNet
Control-LoRAs
SDXL
只有 LoRA、input hint 等
UniControl
SD1.5 ema (v1-5-pruned-emaonly.ckpt
)
包含 SD 和 UniControl
UniControl 1.1
SD1.5 ema (v1-5-pruned-emaonly.ckpt
)
包含 SD 和 UniControl
ControlNet 1.0
在 1.0 版本中,ControlNet 权重是和 SD1.5 的权重放一起发布 的,所以一个文件就有 5.4GB.
以 control_sd15_canny.pth
为例,它就是 ControlLDM
类的 state_dict,其所有的 keys 包括:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 # Diffusion parameters alphas_cumprod alphas_cumprod_prev betas log_one_minus_alphas_cumprod logvar posterior_log_variance_clipped posterior_mean_coef1 posterior_mean_coef2 posterior_variance sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod sqrt_recip_alphas_cumprod sqrt_recipm1_alphas_cumprod # FrozenCLIPEmbedder's state_dict cond_stage_model.transformer.text_model.embeddings.[xxx] cond_stage_model.transformer.text_model.encoder.[xxx] cond_stage_model.transformer.text_model.final_layer_norm.[xxx] # AutoencoderKL's state_dict first_stage_model.encoder.[xxx] first_stage_model.decoder.[xxx] first_stage_model.quant_conv.[xxx] first_stage_model.post_quant_conv.[xxx] # DiffusionWrapper's state_dict model.diffusion_model.time_embed.[xxx] model.diffusion_model.input_blocks.[xxx] model.diffusion_model.middle_block.[xxx] model.diffusion_model.output_blocks.[xxx] model.diffusion_model.out.[xxx] # ControlNet's state_dict control_model.time_embed.[xxx] control_model.input_blocks.[xxx] control_model.middle_block.[xxx] control_model.middle_block_out.0.weight control_model.middle_block_out.0.bias control_model.input_hint_block.[xxx] control_model.zero_convs.[xxx]
第一部分是扩散模型的参数,并不是网络的权重
第二部分是文本编码器,即 FrozenCLIPEmbedder
的权重
第三部分是自编码器,即 AutoencoderKL
的权重
第四部分是扩散模型 UNet,即 DiffusionWrapper
的权重(前文提及,DiffusionWrapper
包裹了 UNetModel
类)
第五部分是 ControlNet
的权重,可以看到确实是在 SD UNet Encoder 的基础上增加了 input_hint_block
、zero_convs
和 middle_block_out
.(注意 middle_block_out
也是一层 zero convolution,这个名字有迷惑性)
⚠️ 论文里说 SD 是固定的,但是深入探究后发现各条件模型的 SD UNet decoder 部分(即 model.diffusion_model.output_blocks
和 model.diffusion_model.out
)竟然不一样,与 SD1.5 也不一样(最大差异达到 0.01 数量级)。推测这与代码里的配置项 sd_locked
有关,配置 sd_locked=False
时 SD UNet decoder 是可训练的。
ControlNet 1.1
在 1.1 版本中,ControlNet 权重是单独发布 的,所以一个文件只有 1.4GB.
以 control_v11p_sd15_canny.pth
为例,它只是 ControlNet
类的 state_dict,其所有的 keys 包括:
1 2 3 4 5 6 7 control_model.time_embed.[xxx] control_model.input_blocks.[xxx] control_model.middle_block.[xxx] control_model.middle_block_out.0.bias control_model.middle_block_out.0.weight control_model.input_hint_block.[xxx] control_model.zero_convs.[xxx]
Control-LoRAs by StabilityAI
StabilityAI 的 Control-LoRAs 在为 ControlNet 加入了 LoRA,一个 rank128 的文件有 396MB,一个 rank256 的文件有 774 MB.
以 control-lora-canny-rank128.safetensors
为例,它包含的 keys 有:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 lora_controlnet (an empty tensor: Tensor([])) label_emb.0.0.bias label_emb.0.0.up label_emb.0.0.down label_emb.0.2.bias label_emb.0.2.up label_emb.0.2.down time_embed.0.up time_embed.0.down time_embed.0.bias time_embed.2.up time_embed.2.down time_embed.2.bias input_blocks.[xxx].up input_blocks.[xxx].down input_blocks.[xxx].bias middle_block.[xxx].up middle_block.[xxx].down middle_block.[xxx].bias input_hint_block.[xxx].weight input_hint_block.[xxx].bias middle_block_out.0.weight middle_block_out.0.bias zero_convs.[xxx].weight zero_convs.[xxx].bias
对比 ControlNet 的 keys,最明显的是多了两项:label_emb
和 lora_controlnet
,后者只是一个空 Tensor,用于方便判断加载的是否是 Control-LoRAs 的权重;前者的作用目前并不清楚(因为代码也没有开源)。其他 keys 就是在 ControlNet 的基础上为 Linear 层添加了 .up
和 .down
(即 LoRA 的权重),有些层(如归一化层)还调整了 .bias
。
UniControl by Salesforce
Salesforce 的 UniControl 旨在使得一个 ControlNet 适配多种条件输入,为此作者对 ControlNet 的结构稍加改造,引入了:
MoE Adapter:即改造 input hint blocks
Task Aware HyperNet:用于得到 task embedding
Modulated Zero Conv:能够融入 task embedding 的 zero convolution
其发布的权重有 5.8GB 大小,即和 ControlNet 1.0 一样是将 UniControl 和 SD1.5 的权重是放在一起发布的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 # Diffusion parameters alphas_cumprod alphas_cumprod_prev betas log_one_minus_alphas_cumprod logvar posterior_log_variance_clipped posterior_mean_coef1 posterior_mean_coef2 posterior_variance sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod sqrt_recip_alphas_cumprod sqrt_recipm1_alphas_cumprod # FrozenCLIPEmbedder's state_dict cond_stage_model.transformer.text_model.embeddings.[xxx] cond_stage_model.transformer.text_model.encoder.[xxx] cond_stage_model.transformer.text_model.final_layer_norm.[xxx] # AutoencoderKL's state_dict first_stage_model.encoder.[xxx] first_stage_model.decoder.[xxx] first_stage_model.quant_conv.[xxx] first_stage_model.post_quant_conv.[xxx] # DiffusionWrapper's state_dict model.diffusion_model.time_embed.[xxx] model.diffusion_model.input_blocks.[xxx] model.diffusion_model.middle_block.[xxx] model.diffusion_model.output_blocks.[xxx] model.diffusion_model.out.[xxx] # ControlNet's state_dict control_model.time_embed.[xxx] control_model.input_blocks.[xxx] control_model.input_hint_block_list_moe.[xxx] control_model.input_hint_block_share.[xxx] control_model.input_hint_block_zeroconv_0.[xxx] control_model.input_hint_block_zeroconv_1.[xxx] control_model.middle_block.[xxx] control_model.middle_block_out.0.weight control_model.middle_block_out.0.bias control_model.zero_convs.[xxx] control_model.task_id_hypernet.[xxx] control_model.task_id_layernet.[xxx] control_model.task_id_layernet_zeroconv_0.[xxx] control_model.task_id_layernet_zeroconv_1.[xxx]
可以看见,前面几部分确实就是 Stable Diffusion 的参数,与 ControlNet 1.0 是一样的,不同之处在于 ControlNet,input_hint_block
增加了 MoE 相关权重,以及最后多了个 task_id_hypernet
.