type
status
date
slug
summary
tags
category
icon
password
URL
paper
github

背景

在classifier-guided这篇博客我们提到对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,classifier-guided通过额外训练一个分类器来不断矫正每一个时间步的生成图片,最终实现特定类别图片的生成。
Classifier-free的核心思路是:我们无需训练额外的分类器,直接训练带类别信息的噪声预测模型来实现特定类别图片的生成,即ϵθ(xt,t)ϵ^θ(xt,y,t)\epsilon_{\theta}(x_t, t) \rightarrow \hat{\epsilon}_{\theta}(x_t, y, t)。从而简化整体的pipeline。
此外,classifier-free方法不局限于类别信息的融入,它还能实现将语义信息融入到diffusion model中,实现更为灵活的文生图。这用classifier-guide是很难做到的。目前的很多工作如DALLE,Stable Diffusion, Imagen等都是Classifier-free形式。如:
下面我们来看他是怎么做的吧!

方法大意

classifier-free diffusion的实现非常简单。下面对比普通的diffusion model,classifier-guided与classifier-free三种方式的差异。
模型
训练目标
实现功能
训练数据
DM (DDPM, DDIM)
ϵθ(xt,t)\epsilon_{\theta}(x_t, t)
从服从高斯分布的噪声中生成图片
图片
classifier-guided DM
ϵθ(xt,t)\epsilon_{\theta}(x_t, t)和分类器p(yxt)p(y|x_t)
从服从高斯分布的噪声中生成特定类别的图片
DM:图片 分类器:图片-标签对
classifier-free DM
ϵθ(xt,y,t),ϵθ(xt,t)\epsilon_{\theta}(x_t, y,t), \epsilon_{\theta}(x_t, t)
从服从高斯分布的噪声中生成符合文本描述的图片
图片-文本对
  • 对于训练ϵθ(xt,t)\epsilon_{\theta}(x_t, t)来估计xtx_t在时间tt上添加的噪声,再根据采样公式推出xt1x_{t-1},从而实现图片生成。训练数据只需要准备图片即可。
  • 对于classifier-guided DM是在普通DM的基础上,额外再训练一个Classifier来获得当前时间步生成的图片类别概率分布,从而实现特定类别的图片生成。
  • 对于classifier-free DM将类别信息(或语义信息)集成到diffusion model的训练过程中,训练ϵθ(xt,y,t)ϵθ(xt,y=,t)(ϵθ(xt,t))\epsilon_{\theta}(x_t, y,t),\epsilon_{\theta}(x_t, y=\empty,t)(\text{即}\epsilon_{\theta}(x_t,t))。训练的时候也会加入无类别信息(或语义信息)的图片进行训练。
回答3个问题深入理解classifier-free DM
  1. 模型如何融入类别信息(或语义信息)
  1. 如何训练ϵθ(xt,y,t)\epsilon_{\theta}(x_t, y,t)ϵθ(xt,y=,t)\epsilon_{\theta}(x_t, y=\empty,t)
  1. 如何进行采样生成

模型如何融入类别信息(或语义信息)

采用交叉注意力机制融入

我们知道,深度学习模型推理的本质可以理解为一系列的数值计算,因此将类别信息(或语义信息)融入到模型中需要预先将其转化为数值。转化的方法有很多,如可以用一个embedding layer。也可以用NLP模型,如Bert、T5、CLIP的text encoder等将类别信息(或语义信息)转化为数值向量,一般称为text embedding。随后需要将text embedding和原本模型中的image representation进行融合。最为常见且有效的方法是用交叉注意力机制CrossAttention。具体来说就是将text embedding作为注意力机制中的keyvalue,原始的图片表征作为query。大家熟知的Stable Diffusion用的就是这个融入方法。交叉注意力机制融入语义信息的本质是spatial-wise attention。
class SpatialCrossAttention(nn.Module): def __init__(self, dim, context_dim, heads=4, dim_head=32) -> None: super(SpatialCrossAttention, self).__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.proj_in = nn.Conv2d(dim, context_dim, kernel_size=1, stride=1, padding=0) self.to_q = nn.Linear(context_dim, hidden_dim, bias=False) self.to_k = nn.Linear(context_dim, hidden_dim, bias=False) self.to_v = nn.Linear(context_dim, hidden_dim, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x, context=None): x_q = self.proj_in(x) b, c, h, w = x_q.shape x_q = rearrange(x_q, "b c h w -> b (h w) c") if context is None: context = x_q if context.ndim == 2: context = rearrange(context, "b c -> b () c") q = self.to_q(x_q) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # attention, what we cannot get enough of attn = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads) out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w) out = self.to_out(out) return out

基于channel-wise attention融入

该融入方法与time-embedding的融入方法相同,在时间中往往会预先和time-embedding进行融合,再融入到图片特征中,伪代码如下:
# mixture time-embedding and label embedding t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] emb_out = self.emb_layers(emb).type(h.dtype) # h is image feature scale, shift = th.chunk(emb_out, 2, dim=1) # One half of the embedding is used for scaling and the other half for offset h = h * (1 + scale) + shift
基于channel-wise的融入粒度没有CrossAttention细。一般适用类别数量有限的特征融入,如时间embedding,类别embedding。而语义信息的融入更推荐上面CrossAttention的方法。

如何训练ϵθ(xt,y,t)\epsilon_{\theta}(x_t, y,t)ϵθ(xt,y=,t)\epsilon_{\theta}(x_t, y=\emptyset,t)

ϵθ(xt,y,t)\epsilon_{\theta}(x_t, y,t)的训练需要图文对,但互联网上具备文本描述的图片只是浩如烟海的图片海洋中的一小部分。仅用具备图文对数据训练ϵθ(xt,y,t)\epsilon_{\theta}(x_t, y,t)将会大大束缚DM的生成多样性。另外,为了使得模型更好的捕获图文的联系ϵθ(xt,y=,t)\epsilon_{\theta}(x_t, y=\empty,t)的数据不宜过多,否则模型生成结果的保真度会降低。反之,若ϵθ(xt,y=,t)\epsilon_{\theta}(x_t, y=\empty,t)数据过少,将会影响生成结果的多样性。需要根据实际的场景进行调整。
有两个实践中的trick需要注意
  • 在实践中,为了统一y=y=\emptyyy \neq \empty 两种情形,通常会给定一个y=y=\empty的embedding(可以随机初始化,也可以人为给定),来统一两种情形的建模。
  • 即使所有的数据都有图片对也没有关系,只需在每一个batch中随机将某些数据的标签编码替换为y=y=\empty的embedding即可。另外

如何进行采样生成

classifier-free diffusion的采样生成过程与前面介绍的DDPM,DDIM类似。唯一有所区别的是将原本的ϵ(xt,t)\epsilon(x_t, t)用下式代替。
ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)] \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{1} \end{align} 
下面给出详细的推导过程:
首先根据贝叶斯公式有
p(yxt)=p(xty)p(y)先验分布p(xt)p(yxt)p(xty)/p(xt)取对数logp(yxt)=logp(xty)logp(xt)xt求导xtlogp(yxt)=xtlogp(xty)xtlogp(xt)根据score functionxtlogpθ(xt)=11αtϵθ(xt)xtlogp(yxt)=11αt(ϵθ(xt,y,t)ϵθ(xt,y=,t))(2)\begin{aligned} p (y| x_t) & = \frac{p (x_t|y) \overbrace{p(y)}^{\text{先验分布}} } {p(x_t) } \\ \Rightarrow p (y| x_t) & \propto p (x_t|y) / {p (x_t) } \\ \stackrel{取对数} \Rightarrow \log{p (y| x_t)} & = \log{p (x_t|y)} - \log{{p (x_t) }} \\ \stackrel{对x_t求导} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = \nabla_{x_t}\log{p (x_t|y)} - \nabla_{x_t}\log{{p (x_t) }} \\ \stackrel{\text{根据score function} \nabla_{x_t} \log p_\theta (x_t) = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t)} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}}(\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ) \end{aligned} \tag{2}
当我们得到xtlogp(yxt)\nabla_{x_t}\log{p (y| x_t)},参考classifier-guided的式(17)
ϵ^(xty)本文中的ϵ^θ(xt,y,t):=ϵθ(xt)本文中的ϵθ(xt,y=,t)s1αtxtlogpϕ(yxt)(3)\underbrace{\hat{\epsilon}(x_t|y)}_{\text{本文中的}\hat{\epsilon}_{\theta}(x_t, y, t)} := \underbrace{\epsilon_\theta(x_t)}_{\text{本文中的}\epsilon_{\theta}(x_t, y=\empty, t)} - s\sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{3}
可得
ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)]\begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{4} \end{align}
后面的采样过程与之前的方式一致。

结语

本文详细介绍了classifier-free的提出背景与具体实现方案。它是后续一系列如stable diffusion,DALLE等文生图工作的基石。

参考文献

附录

式12推导验证
12(xtμΣg)TΣ1(xtμΣg)+12gTΣg+C2=12(xtTμTgTΣT)Σ1(xtμΣg)+12gTΣg+C2=12(xtTμTgTΣT)Σ1(xtμΣg)+12gTΣg+C2 =12(xtTΣ1μTΣ1gTΣTΣ1gT)(xtμΣg)+12gTΣg+C2=12(xtTΣ1(xtμΣg)  μTΣ1(xtμΣg) gT(xtμΣg))+12gTΣg+C2=12(xtTΣ1(xtμ)μTΣ1(xtμ))(xtμ)TΣ1(xtμ)12(gT(xtμΣg)+(xtTΣ1Σg)xtTg+ μTΣ1ΣgμTg)+12gTΣg+C2=12(xtμ)TΣ1(xtμ)+(xtμ)g+C2 \begin{align*}&- \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\= &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\= &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ \\  = & - \frac{1}{2} (x_t^T \Sigma^{-1} - \mu^T \Sigma^{-1} - \underbrace{g^T \Sigma^T \Sigma^{-1}}_{g^T} )(x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} (x_t^T \Sigma^{-1} (x_t - \mu - \Sigma g)    - \mu^T \Sigma^{-1} (x_t - \mu - \Sigma g)  - g^T (x_t - \mu - \Sigma g)) + \frac{1}{2}g^T\Sigma g + C_2 \\= & - \frac{1}{2} \underbrace{(x_t^T \Sigma^{-1} (x_t - \mu ) - \mu^T \Sigma^{-1} (x_t - \mu))}_{(x_t - \mu)^T \Sigma^{-1} (x_t - \mu)} - \frac{1}{2} ( - g^T (x_t - \mu - \Sigma g) + \underbrace{(- x_t^T \Sigma^{-1}\Sigma g)}_{-x_t^Tg} +  \underbrace{\mu^T \Sigma^{-1}\Sigma g}_{\mu^Tg}) + \frac{1}{2}g^T\Sigma g + C_2 \\= & - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + (x_t - \mu) g + C_2  \\ \end{align*}
 
相关文章
diffusion model(一):DDPM技术小结 (denoising diffusion probabilistic)
Lazy loaded image
diffusion model(二):DDIM技术小结 (denoising diffusion implicit model)
Lazy loaded image
diffusion model(三):classifier guided diffusion model
Lazy loaded image
diffusion model(五):LDM: 在隐空间用diffusion model合成高质量图片
Lazy loaded image
diffusion model(六):Dalle2 技术小结
Lazy loaded image
diffusion model(七):diffusion model is a zero-shot classifier
Lazy loaded image
diffusion model(三):classifier guided diffusion modeldiffusion model(五):LDM: 在隐空间用diffusion model合成高质量图片
Loading...
莫叶何竹🍀
莫叶何竹🍀
非淡泊无以明志,非宁静无以致远
最新发布
Nougat 深度剖析
2025-3-18
表格结构还原——SLANet
2025-2-27
KV-Cache技术小结(MHA,GQA,MQA,MLA)
2025-2-24
diffusion model(十九) :SDE视角下的扩散模型
2024-12-31
🔥Lit: 进一步提升多模态模型Zero-Shot迁移学习的能力
2024-11-22
RNN并行化——《Were RNNs All We Needed?》论文解读
2024-11-21
hexo