type
status
date
slug
summary
tags
category
icon
password
ㅤ | info |
paper | |
github | |
个人blog位置 | |
stability.AI blog |
前置知识:
背景
Diffusion model(
DM
)的提出将图像生成任务推到了前所未有的高度。随着应用人数的增多,人们对文生图模型有了更高的期望:生成质量更高、生成分辨率更大、推理速度更快、生成的图片与文本的alignment更强等。更大的分辨率意味着更大的算力需求。现在的常用做法是先用
DM
先生成一个分辨率相对较低的图片,再结合图片超分的方法使其更加清晰。目前所用的超分算法一般采用GAN
的架构,就效果而言,超分生成的图片细节不如DM
直接在高分辨率生成的图片效果。提高
DM
的生成分辨率,意味着更多的训练成本。虽然Latent diffusion model(LDM
)[1]使得diffusion过程从像素空间转化为隐空间,大大减少训练和推理成本。但LDM
方法在不明显降低图片生成质量的条件下,只能支持到4-16x的压缩比率,训练成本依旧很大。如stable diffusion(SD
) v1.4(512分辨率,隐空间下采样factor为8) 的训练用了150000小时的A100 PCIe 40GB[2],
SD
v2.1(768分辨率,隐空间下采样factor为8)的训练用了200000小时的A100 PCIe 40GB[3]。能不能在采用更大采样率(意味着更低的训练和推理成本)的同时不降低生成图片质量呢? 这就是Wuerstchen (StableCascade)想要解决的问题。
方法大意
区别于二阶段(two-stage)
SD
,StableCascade
是一个三阶段(three-stage)的生成方法。如下所示。可以看到StableCascade
他有两个不同压缩比的text-condition LDM
。下文将详细介绍StableCascade
的推理过程和训练过程是如何利用这两个不同压缩比的text-condition LDM
,以及为何这样的架构能够加速推理的同时保证图片质量。ㅤ | SD | StableCascade |
stageA | VAGAN | VAGAN |
stageB | text-condition LDM (压缩比8x) | text-condition LDM (低压缩比4x) Unet 架构 |
stageC | 无 | text-condition LDM (高压缩比42x)。非Unet 架构,堆叠了16个ConvNext block, 内部没有downsample |
推理过程
推理过程的pipeline如下图所示,从stageC开始,stageB依赖stageC的输出,stageA依赖stageB的输出。为了方便表述,不妨记stageC的text-condition
LDM
为,stageB的text-condition LDM
为,stageA的VAGAN
的decode为STEP1: 通过stageC预测Semantic Compressor的输出。stageC的推理步长
式中为时间步为时的输出;为stageC的text-condition
LDM
;为时间步为时预测的噪声, ;为prompt的text embedding。Semantic Compressor可以理解为是一个特征提取器,如resnet
、vit
…。其作用是给定一张图片,提取其特征,维度变化为:。论文中,作者用EfficientV2(S)
作为Semantic Compressor。
STEP2: 通过StageB来预测图片的隐表征。stageB的推理步长
式中为时间步为时预测的隐表征;为stageB的text-condition
LDM
;为时间步为时预测的噪声,其中 ;为stageC的输出,为了增强模型的robust,会对添加一些noise使其non-perfect;为prompt的text embedding (与stageC中的是一致的)。这里可能会有一个疑问:如何融入?在StableCascade源码实现和Wuerstchen paper不同。paper中是将 reshape后和组合一起作为condition融入。而StableCascade源码的实现是先对进行transform,随后resize为模型内部特征的shape进行相加。源码位置:https://github.com/Stability-AI/StableCascade/blob/master/modules/stage_b.py#L227C22-L227C35
STEP3: 通过StageA的decode 来将隐空间变换到像素空间
式中为生成的图片;为
VAGAN
的decode;为stageB的输出。就此
StableCascade
的推理过程结束。回顾一下:Wuerstchen (StableCascade)想要解决的问题是希望用一个更大采样率(意味着更低的训练和推理成本)的同时不降低生成图片质量呢。 从上面对推理pipeline描述中我们发现StableCascade
相比LDM
还多了一个stage,为什么还能降低推理成本提升生成质量呢?
个人理解:StableCascade
仅在高压缩率的StageC(42x压缩率)采用较大的推理步长(),而在低压缩率的stageB用了较低的推理步长,从而实现加速。之所以能这样做,是由于stageC给stageB所提供的图片特征先验能让stageB仅用较少步长就能实现较好的效果。训练过程
StableCascade
需要训练3个模型- StageA的
VAGAN
。由一个encoder 和一个decoder 构成。参数量18M
- stageB的text-condition
LDM
,参数量1B,架构为Unet
- stageC的text-condition
LDM
,参数量为1B;架构:堆叠了16个ConvNext
block,没有进行下采样(压缩比已经很大了,作者发现再下采样会恶化生成质量)。
前面提到的Semantic Compressor采用的是开源预训练的
EfficientV2
,只额外训练一个1x1卷积调整维度。训练数据格式为图文对数据,来源于improved-aesthetic LAION-5B。数据格式为:
stageA VQGAN的训练目标可参考[4]
StageB和stageC的训练目标和
LDM
[1]中的一致。需要注意的是stageB的预测目标是VQGAN
encoder的输出,StageC的预测目标是Semantic Compressor+1x1 conv
的feature map 。结果
下图展示了
StableCascade
在不同batch下和SD2.1
和SDXL
的推理速度对比结果。下图展示了不同架构生成质量对比结果。
生成图片展示
小结
StableCascade
的核心在于新增了一个图片隐特征的先验,从而带来推理和训练的加速。在训练阶段这个先验来源于pre-training model。在推理阶段这个先验来源于stageC对这个先验信息的估计。reference
[4] (VQGAN) Taming transformers for high-resolution image synthesis
- 作者:莫叶何竹🍀
- 链接:http://www.myhz0606.com/article/stablecascade
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章