type
status
date
slug
summary
tags
category
icon
password
💡
这篇文章将从ddpm的原理推导和代码实现两个方向带你深入理解扩散模型。

1 从直觉上理解DDPM

在详细推到公式之前,我们先从直觉上理解一下什么是扩散
对于常规的生成模型,如GAN,VAE,它直接从噪声数据生成图像,我们不妨记噪声数据为,其生成的图片为
对于常规的生成模型
学习一个解码函数(即我们需要学习的模型)p,实现
常规方法只需要一次预测即能实现噪声到目标的映射,虽然速度快,但是效果不稳定。
常规生成模型的训练过程(以VAE为例)
对于diffusion model
它将噪声到目标的过程进行了多步拆解。不妨假设一共有个时间步,第个时间步 是噪声数据,第0个时间步的输出是目标图片。其过程可以表述为:
对于DDPM它采用的是一种自回归式的重建方法,每次的输入是当前的时刻及当前时刻的噪声图片。也就是说它把噪声到目标图片的生成分成了T步,这样每一次的预测相当于是对残差的预测。优势是重建效果稳定,但速度较慢。
训练整体pipeline包含两个过程

2 diffusion pipeline

2.1前置知识:

高斯分布的一些性质
(1)如果,且是实数,那么
(2)如果是统计独立的正态随机变量,则它们的和也满足高斯分布(高斯分布可加性).
均值为方差为的高斯分布的概率密度函数为

2.2 加噪过程

1 前向过程:将图片数据映射为噪声
每一个时刻都要添加高斯噪声,后一个时刻都是由前一个时刻加噪声得到。(其实每一个时刻加的噪声就是训练所用的标签)。即
下面我们详细来看
的增加而增大(论文中[2]从0.0001 -> 0.02) (这是因为一开始加一点噪声就很明显,后面需要增大噪声的量才明显).DDPM将加噪声过程建模为一个马尔可夫过程其中
为在t时刻的图片,当时为原图;为在t时刻所加的噪声,服从标准正态分布是常数,是自己定义的变量;从上式可见,随着增大,越来越接近纯高斯分布.
同理:
将式(8)代入式(7)可得:
由于服从均值为0,方差为1的高斯分布(即标准正态分布),根据定义服从的是均值为0,方差为的高斯分布.即.同理可得.则(高斯分布可加性,可以通过定义推得,不赘述)
我们不妨记则式(10)最终可改写为
通过递推,容易得到
其中为原图.从式(13)可见,我们可以从得到任意时刻的的分布,而无需按照时间顺序递推!这极大提升了计算效率.
⚠️加噪过程是确定的,没有模型的介入. 其目的是制作训练时标签

2.3 去噪过程

给定如何求出呢?直接求解是很难的,作者给出的方案是:我们可以一步一步求解.即学习一个解码函数,这个能够知道的映射规则.如何定义这个是问题的关键.有了,只需从逐步迭代,即可得出.
去噪过程是加噪过程的逆向.如果说加噪过程是求给定初始分布求任意时刻的分布,即那么去噪过程所求的分布就是给定任意时刻的分布求其初始时刻的分布,通过马尔可夫假设,可以对上述问题进行化简
如何求呢?前面的加噪过程我们大力气推到出了我们可以通过贝叶斯公式把它利用起来
⚠️这里的(去噪)和上面的(加噪)只是对分布的一种符号记法。
有了式(17)还是一头雾水,都不知道啊!该怎么办呢?这就要借助模型的威力了.下面来看如何构建我们的模型.
延续加噪过程的推导我们是可以知道的.因此若我们知道初始分布,则
 
结合高斯分布的定义(5)来看式(21),不难发现也是服从高斯分布的.并且结合式(5)我们可以求出其方差和均值
⚠️式17做了一个近似能做这个近似原因是一阶马尔科夫假设,当前时间点只依赖前一个时刻的时间点.
可以求得:
通过上式,我们可得
该式是真实的条件分布.我们目标是让模型学到的条件分布尽可能的接近真实的条件分布从上式可以看到方差是个固定量,那么我们要做的就是让的均值尽可能的对齐,即
(这个结论也可以通过最小化上述两个分布的KL散度推得)
下面的问题变为:如何构造来使我们的优化尽可能的简单 
我们注意到都是关于的函数,不妨让他们的保持一致,则可将写成
是我们需要训练的模型.这样对齐均值的问题就转化成了: 给定来预测原始图片输入根据上文的加噪过程,我们可以很容易制造训练所需的数据对! (Dalle2的训练采用的是这个方式).事情到这里就结束了吗?
DDPM作者表示直接从的预测数据跨度太大了,且效果一般.我们可以将式(12)做一下变形
代入到式(24)中
经过这次化简,我们将其中可以将式(29)转变为
此时对齐均值的问题就转化成:给定预测加入的噪声, 也就是说我们的模型预测的是噪声

2.3.1 训练与采样过程

训练的目标就是这所有时刻两个噪声的差异的期望越小越好(用MSE或L1-loss).
下图为论文提供的训练和采样过程
notion image

2.3.2 采样过程

通过以上讨论,我们推导出高斯分布的均值和方差.,根据文献[1]从一个高斯分布中采样一个随机变量可用一个重参数化技巧进行近似
式(32)和论文给出的采样递推公式一致.
至此,已完成DDPM整体的pipeline.
因为和式(7)的并不是独立同分布,所以不能根据(7)的变形来进行采样计算。

3 从代码理解训练&预测过程

3.1 训练过程

已知项: 我们假定有一批N张图片
第一步: 随机采样K组成batch,如
第二步: 随机采样一些时间步
第三步: 随机采样噪声
第四步: 计算在所采样的时间步的输出(即加噪声).(根据公式12)
第五步: 预测噪声.输入到噪声预测模型,来预测此时的噪声.论文用到的模型结构是Unet,与传统Unet的输入有所不同的是增加了一个时间步的输入.
这里面有一个需要注意的点:模型是如何对时间步进行编码并使用的
  • 首先会对时间步进行一个编码,将其变为一个向量,以正弦编码为例
  • 将时间步的embedding嵌入到Unet的block中,使模型能够学习到时间步的信息
第六步:计算损失,反向传播.计算预测的噪声与实际的噪声的损失,损失函数可以是L1或mse
通过不断迭代上述6步即可完成模型的训练

3.2采样过程

第一步:随机从高斯分布采样一张噪声图片,并给定采样时间步
第二步: 根据预测的当前时间步的噪声,通过公式计算当前时间步的均值和方差
第三步: 根据公式(32)计算得到前一个时刻图片
通过迭代以上三步,直至完成采样.

思考和讨论

DDPM区别与传统的VAE与GAN采用了一种新的范式实现了更高质量的图像生成.但实践发现,需要较大的采样步数才能得到较好的生成结果.由于其采样过程是一个马尔可夫的推理过程,导致会有较大的耗时.后续工作如DDIM针对该特性做了优化,数十倍降低采样所用时间。

参考文献

diffusion model(二):DDIM技术小结 (denoising diffusion implicit model)Segment Anything(SAM)
  • Twikoo