type
status
date
slug
summary
tags
category
icon
password
URL
paper
github

背景

在classifier-guided这篇博客我们提到对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,classifier-guided通过额外训练一个分类器来不断矫正每一个时间步的生成图片,最终实现特定类别图片的生成。
Classifier-free的核心思路是:我们无需训练额外的分类器,直接训练带类别信息的噪声预测模型来实现特定类别图片的生成,即。从而简化整体的pipeline。
此外,classifier-free方法不局限于类别信息的融入,它还能实现将语义信息融入到diffusion model中,实现更为灵活的文生图。这用classifier-guide是很难做到的。目前的很多工作如DALLE,Stable Diffusion, Imagen等都是Classifier-free形式。如:
下面我们来看他是怎么做的吧!

方法大意

classifier-free diffusion的实现非常简单。下面对比普通的diffusion model,classifier-guided与classifier-free三种方式的差异。
模型
训练目标
实现功能
训练数据
DM (DDPM, DDIM)
从服从高斯分布的噪声中生成图片
图片
classifier-guided DM
和分类器
从服从高斯分布的噪声中生成特定类别的图片
DM:图片 分类器:图片-标签对
classifier-free DM
从服从高斯分布的噪声中生成符合文本描述的图片
图片-文本对
  • 对于训练来估计在时间上添加的噪声,再根据采样公式推出,从而实现图片生成。训练数据只需要准备图片即可。
  • 对于classifier-guided DM是在普通DM的基础上,额外再训练一个Classifier来获得当前时间步生成的图片类别概率分布,从而实现特定类别的图片生成。
  • 对于classifier-free DM将类别信息(或语义信息)集成到diffusion model的训练过程中,训练。训练的时候也会加入无类别信息(或语义信息)的图片进行训练。
回答3个问题深入理解classifier-free DM
  1. 模型如何融入类别信息(或语义信息)
  1. 如何训练
  1. 如何进行采样生成

模型如何融入类别信息(或语义信息)

采用交叉注意力机制融入

我们知道,深度学习模型推理的本质可以理解为一系列的数值计算,因此将类别信息(或语义信息)融入到模型中需要预先将其转化为数值。转化的方法有很多,如可以用一个embedding layer。也可以用NLP模型,如Bert、T5、CLIP的text encoder等将类别信息(或语义信息)转化为数值向量,一般称为text embedding。随后需要将text embedding和原本模型中的image representation进行融合。最为常见且有效的方法是用交叉注意力机制CrossAttention。具体来说就是将text embedding作为注意力机制中的keyvalue,原始的图片表征作为query。大家熟知的Stable Diffusion用的就是这个融入方法。交叉注意力机制融入语义信息的本质是spatial-wise attention。

基于channel-wise attention融入

该融入方法与time-embedding的融入方法相同,在时间中往往会预先和time-embedding进行融合,再融入到图片特征中,伪代码如下:
基于channel-wise的融入粒度没有CrossAttention细。一般适用类别数量有限的特征融入,如时间embedding,类别embedding。而语义信息的融入更推荐上面CrossAttention的方法。

如何训练

的训练需要图文对,但互联网上具备文本描述的图片只是浩如烟海的图片海洋中的一小部分。仅用具备图文对数据训练将会大大束缚DM的生成多样性。另外,为了使得模型更好的捕获图文的联系的数据不宜过多,否则模型生成结果的保真度会降低。反之,若数据过少,将会影响生成结果的多样性。需要根据实际的场景进行调整。
有两个实践中的trick需要注意
  • 在实践中,为了统一 两种情形,通常会给定一个的embedding(可以随机初始化,也可以人为给定),来统一两种情形的建模。
  • 即使所有的数据都有图片对也没有关系,只需在每一个batch中随机将某些数据的标签编码替换为的embedding即可。另外

如何进行采样生成

classifier-free diffusion的采样生成过程与前面介绍的DDPM,DDIM类似。唯一有所区别的是将原本的用下式代替。
下面给出详细的推导过程:
首先根据贝叶斯公式有
当我们得到,参考classifier-guided的式(17)
可得
后面的采样过程与之前的方式一致。

结语

本文详细介绍了classifier-free的提出背景与具体实现方案。它是后续一系列如stable diffusion,DALLE等文生图工作的基石。

参考文献

附录

式12推导验证
 
Nougat 深度剖析diffusion model(三):classifier guided diffusion model
  • Twikoo