type
status
date
slug
summary
tags
category
icon
password
info
paper
github
个人blog位置
前置知识:
latent diffusion model(stable diffusion)知乎 个人博客 paper
classifier-free guided: 知乎 个人博客 paper

背景

Diffusion model(DM)的提出将图像生成任务推到了前所未有的高度。随着应用人数的增多,人们对文生图模型有了更高的期望:生成质量更高、生成分辨率更大、推理速度更快、生成的图片与文本的alignment更强等。
更大的分辨率意味着更大的算力需求。现在的常用做法是先用DM先生成一个分辨率相对较低的图片,再结合图片超分的方法使其更加清晰。目前所用的超分算法一般采用GAN的架构,就效果而言,超分生成的图片细节不如DM直接在高分辨率生成的图片效果。
提高DM的生成分辨率,意味着更多的训练成本。虽然Latent diffusion model(LDM)[1]使得diffusion过程从像素空间转化为隐空间,大大减少训练和推理成本。但LDM方法在不明显降低图片生成质量的条件下,只能支持到4-16x的压缩比率,训练成本依旧很大。如stable diffusion(SD) v1.4(512分辨率,隐空间下采样factor为8) 的训练用了150000小时的A100 PCIe 40GB[
2],SD v2.1(768分辨率,隐空间下采样factor为8)的训练用了200000小时的A100 PCIe 40GB[3]。
能不能在采用更大采样率(意味着更低的训练和推理成本)的同时不降低生成图片质量呢? 这就是Wuerstchen (StableCascade想要解决的问题。

方法大意

区别于二阶段(two-stage) SDStableCascade是一个三阶段(three-stage)的生成方法。如下所示。可以看到StableCascade他有两个不同压缩比的text-condition LDM。下文将详细介绍StableCascade推理过程训练过程是如何利用这两个不同压缩比的text-condition LDM,以及为何这样的架构能够加速推理的同时保证图片质量。
SD
StableCascade
stageA
VAGAN
VAGAN
stageB
text-condition LDM(压缩比8x)
text-condition LDM(低压缩比4x) Unet架构
stageC
text-condition LDM (高压缩比42x)。非Unet架构,堆叠了16个ConvNext block, 内部没有downsample

推理过程

推理过程的pipeline如下图所示,从stageC开始,stageB依赖stageC的输出,stageA依赖stageB的输出。为了方便表述,不妨记stageC的text-condition LDM,stageB的text-condition LDM,stageA的VAGAN的decode为
STEP1: 通过stageC预测Semantic Compressor的输出。stageC的推理步长
式中为时间步为时的输出;为stageC的text-condition LDM为时间步为时预测的噪声, 为prompt的text embedding。
Semantic Compressor可以理解为是一个特征提取器,如resnetvit…。其作用是给定一张图片,提取其特征,维度变化为:。论文中,作者用EfficientV2(S)作为Semantic Compressor。
 
notion image
 
STEP2: 通过StageB来预测图片的隐表征。stageB的推理步长
式中为时间步为时预测的隐表征;为stageB的text-condition LDM为时间步为时预测的噪声,其中 为stageC的输出,为了增强模型的robust,会对添加一些noise使其non-perfect;为prompt的text embedding (与stageC中的是一致的)。
这里可能会有一个疑问:如何融入?在StableCascade源码实现和Wuerstchen paper不同。paper中是将 reshape后和组合一起作为condition融入。而StableCascade源码的实现是先对进行transform,随后resize为模型内部特征的shape进行相加。源码位置:https://github.com/Stability-AI/StableCascade/blob/master/modules/stage_b.py#L227C22-L227C35
STEP3: 通过StageA的decode 来将隐空间变换到像素空间
式中为生成的图片;VAGAN的decode;为stageB的输出。
就此StableCascade的推理过程结束。回顾一下:Wuerstchen (StableCascade想要解决的问题是希望用一个更大采样率(意味着更低的训练和推理成本)的同时不降低生成图片质量呢。 从上面对推理pipeline描述中我们发现StableCascade相比LDM还多了一个stage,为什么还能降低推理成本提升生成质量呢? 个人理解:StableCascade仅在高压缩率的StageC(42x压缩率)采用较大的推理步长(),而在低压缩率的stageB用了较低的推理步长,从而实现加速。之所以能这样做,是由于stageC给stageB所提供的图片特征先验能让stageB仅用较少步长就能实现较好的效果。

训练过程

StableCascade需要训练3个模型
  1. StageA的VAGAN 。由一个encoder 和一个decoder 构成。参数量18M
  1. stageB的text-condition LDM,参数量1B,架构为Unet
  1. stageC的text-condition LDM,参数量为1B;架构:堆叠了16个ConvNext block,没有进行下采样(压缩比已经很大了,作者发现再下采样会恶化生成质量)。
前面提到的Semantic Compressor采用的是开源预训练的EfficientV2,只额外训练一个1x1卷积调整维度。
训练数据格式为图文对数据,来源于improved-aesthetic LAION-5B。数据格式为: stageA VQGAN的训练目标可参考[4]
StageB和stageC的训练目标和LDM [1]中的一致。需要注意的是stageB的预测目标是VQGAN encoder的输出,StageC的预测目标是Semantic Compressor+1x1 conv的feature map
notion image

结果

下图展示了StableCascade 在不同batch下和SD2.1SDXL的推理速度对比结果。
notion image
下图展示了不同架构生成质量对比结果。
notion image
生成图片展示
notion image

小结

StableCascade的核心在于新增了一个图片隐特征的先验,从而带来推理和训练的加速。在训练阶段这个先验来源于pre-training model。在推理阶段这个先验来源于stageC对这个先验信息的估计。

reference

[4] (VQGAN) Taming transformers for high-resolution image synthesis
 
 
diffusion model(十三):DiT技术小结Matryoshka Representation Learning (俄罗斯套娃表征学习)技术小结
  • Twikoo