type
status
date
slug
summary
tags
category
icon
password
ㅤ | URL |
paper | |
github |
背景
在classifier-guided这篇博客我们提到对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,classifier-guided通过额外训练一个分类器来不断矫正每一个时间步的生成图片,最终实现特定类别图片的生成。
Classifier-free的核心思路是:我们无需训练额外的分类器,直接训练带类别信息的噪声预测模型来实现特定类别图片的生成,即。从而简化整体的pipeline。
此外,classifier-free方法不局限于类别信息的融入,它还能实现将语义信息融入到diffusion model中,实现更为灵活的文生图。这用classifier-guide是很难做到的。目前的很多工作如DALLE,Stable Diffusion, Imagen等都是Classifier-free形式。如:
下面我们来看他是怎么做的吧!
方法大意
classifier-free diffusion的实现非常简单。下面对比普通的diffusion model,classifier-guided与classifier-free三种方式的差异。
模型 | 训练目标 | 实现功能 | 训练数据 |
DM (DDPM, DDIM) | 从服从高斯分布的噪声中生成图片 | 图片 | |
classifier-guided DM | 和分类器 | 从服从高斯分布的噪声中生成特定类别的图片 | DM:图片 分类器:图片-标签对 |
classifier-free DM | 从服从高斯分布的噪声中生成符合文本描述的图片 | 图片-文本对 |
- 对于训练来估计在时间上添加的噪声,再根据采样公式推出,从而实现图片生成。训练数据只需要准备图片即可。
- 对于classifier-guided DM是在普通DM的基础上,额外再训练一个Classifier来获得当前时间步生成的图片类别概率分布,从而实现特定类别的图片生成。
- 对于classifier-free DM将类别信息(或语义信息)集成到diffusion model的训练过程中,训练。训练的时候也会加入无类别信息(或语义信息)的图片进行训练。
回答3个问题深入理解classifier-free DM
- 模型如何融入类别信息(或语义信息)
- 如何训练与
- 如何进行采样生成
模型如何融入类别信息(或语义信息)
采用交叉注意力机制融入
我们知道,深度学习模型推理的本质可以理解为一系列的数值计算,因此将类别信息(或语义信息)融入到模型中需要预先将其转化为数值。转化的方法有很多,如可以用一个
embedding layer
。也可以用NLP模型,如Bert、T5、CLIP的text encoder等将类别信息(或语义信息)转化为数值向量,一般称为text embedding
。随后需要将text embedding
和原本模型中的image representation
进行融合。最为常见且有效的方法是用交叉注意力机制CrossAttention
。具体来说就是将text embedding
作为注意力机制中的key
和value
,原始的图片表征作为query
。大家熟知的Stable Diffusion用的就是这个融入方法。交叉注意力机制融入语义信息的本质是spatial-wise attention。class SpatialCrossAttention(nn.Module): def __init__(self, dim, context_dim, heads=4, dim_head=32) -> None: super(SpatialCrossAttention, self).__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.proj_in = nn.Conv2d(dim, context_dim, kernel_size=1, stride=1, padding=0) self.to_q = nn.Linear(context_dim, hidden_dim, bias=False) self.to_k = nn.Linear(context_dim, hidden_dim, bias=False) self.to_v = nn.Linear(context_dim, hidden_dim, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x, context=None): x_q = self.proj_in(x) b, c, h, w = x_q.shape x_q = rearrange(x_q, "b c h w -> b (h w) c") if context is None: context = x_q if context.ndim == 2: context = rearrange(context, "b c -> b () c") q = self.to_q(x_q) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # attention, what we cannot get enough of attn = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads) out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w) out = self.to_out(out) return out
基于channel-wise attention融入
该融入方法与
time-embedding
的融入方法相同,在时间中往往会预先和time-embedding
进行融合,再融入到图片特征中,伪代码如下:# mixture time-embedding and label embedding t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] emb_out = self.emb_layers(emb).type(h.dtype) # h is image feature scale, shift = th.chunk(emb_out, 2, dim=1) # One half of the embedding is used for scaling and the other half for offset h = h * (1 + scale) + shift
基于channel-wise的融入粒度没有
CrossAttention
细。一般适用类别数量有限的特征融入,如时间embedding,类别embedding。而语义信息的融入更推荐上面CrossAttention
的方法。如何训练与
的训练需要图文对,但互联网上具备文本描述的图片只是浩如烟海的图片海洋中的一小部分。仅用具备图文对数据训练将会大大束缚DM的生成多样性。另外,为了使得模型更好的捕获图文的联系的数据不宜过多,否则模型生成结果的保真度会降低。反之,若数据过少,将会影响生成结果的多样性。需要根据实际的场景进行调整。
有两个实践中的trick需要注意:
- 在实践中,为了统一和 两种情形,通常会给定一个的embedding(可以随机初始化,也可以人为给定),来统一两种情形的建模。
- 即使所有的数据都有图片对也没有关系,只需在每一个batch中随机将某些数据的标签编码替换为的embedding即可。另外
如何进行采样生成
classifier-free diffusion的采样生成过程与前面介绍的DDPM,DDIM类似。唯一有所区别的是将原本的用下式代替。
下面给出详细的推导过程:
首先根据贝叶斯公式有
当我们得到,参考classifier-guided的式(17)
可得
后面的采样过程与之前的方式一致。
结语
本文详细介绍了classifier-free的提出背景与具体实现方案。它是后续一系列如stable diffusion,DALLE等文生图工作的基石。
参考文献
附录
式12推导验证
- 作者:莫叶何竹🍀
- 链接:http://myhz0606.com/article/classifier_free
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章
diffusion model(一):DDPM技术小结 (denoising diffusion probabilistic)

diffusion model(二):DDIM技术小结 (denoising diffusion implicit model)

diffusion model(三):classifier guided diffusion model

diffusion model(五):LDM: 在隐空间用diffusion model合成高质量图片

diffusion model(六):Dalle2 技术小结

diffusion model(七):diffusion model is a zero-shot classifier
