type
status
date
slug
summary
tags
category
icon
password
背景
对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,如何生成特定类别的图片呢?这就是classifier guide需要解决的问题。
方法大意
为了实现带类别标签y的DM的推导,进行了以下定义
虽然上式定义了以为条件的噪声过程,但我们还可以证明当不以为条件时的行为与完全相同,即
同样的思路:
根据上式同样可以推导出
由上述推导可见带条件的DM的前向过程与DDPM完全相同。并且根据贝叶斯公式,不带逆向过程也满足
与此同时我们可以证明分类分布只和当前时刻的输入有关,与无关
基于条件的去噪过程
将带类别信息的去噪过程定义为
由于是已知的,这个概率分布与无关,可以将视为常数。此时上式可以表述为
上式的右边第二项)很容易得到,我们可以根据的pair对训练一个分类模型
上式的右边第三项在DDPM中也能够通过一个neural network进行估计
故采样分布
下面来看有了上面这个式子如何进行采样
直接对上面的式子进行采样是很难解决的。论文参考文献1将上式近似为perturbed Gaussian distribution。
根据前文DM的推导可知 ,对其取对数
对于作者假设其curvature比低。这个假设是合理的,对于当diffusion steps足够大时,。在该情况下,对在处进行泰勒展开
(附录给出了验证性证明)
通过上述推导,我们得到了带类别条件的采样过程也可以用高斯分布来近似,只是均值需要加上。具体的算法如下
代码实现
p_mean_var_ddpm
是DDPM对高斯分布均值、方差的计算函数p_mean_var_ddpm_with_classifier
是引入类别控制后的对高斯分布均值、方差的计算函数有了均值方差就可以进行采样了
DDIM 中基于条件的去噪过程
上述条件抽样推导仅对随机扩散采样过程有效,不能应用于DDIM2等确定性采样方法(因为DDIM中设定了方差为0,故无法推导出式11)。为此,作者在研究中采用score-based的思路,参考了Song等人 [3]的方法,并利用了扩散模型和score matching之间的联系[4]。
首先根据贝叶斯公式
具体来说,如果我们有一个模型来预测添加到样本中的噪声,那么可以利用它来推导出一个score function:
代入式(13)得
定义在条件y下的估计噪声为:
只需将DDIM中的 替换为就得到了基于条件的去噪过程。同样也可以引入gradient score 来控制条件的强度,此时式16改写为
代码上也很直观
一些细节
classifier的训练
classifier的训练与扩散模型的训练可以是独立的。在训练classifier的时候可以噪声预测模型(Unet)的encode部分作为主干,在后面接了一个分类层。并且需要与相应的扩散模型相同的噪声分布对classifier进行训练。训练数据集如。是对时间步的采样,是在时间步的输出。训练完成后,采用上面的算法集成到采样过程中。
gradient score的作用
在上面的采样算法我们看到有一个gradient scale s来对梯度进行拉伸。
实验视角
一般来说当时,大约能保证生成的图片50%是想要的类别5,随着的增大,这个比例也能够增加。如下图,当增加到10,此时生成的图片都是期望的类别。因此也称之为guidance scale。
其实理解这个scale还有另一个视角
,当他相当于对分布进行了一个指数拉升,从而带来更大的梯度更新收益。
根据DM的采样过程,当没有classifier guided时,在时刻,的采样过程应当是
当加了classifier guided相当于将向预测类别为的方向更新了一小步。是控制更新的幅值。
参考文献
附录
式12推导验证
- 作者:莫叶何竹🍀
- 链接:http://www.myhz0606.com/article/guided
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章