type
status
date
slug
summary
tags
category
icon
password
info
paper
github
个人博客主页
create date
2024-03-08
阅读前需要具备以下前置知识:
DDPM(扩散模型基本原理):知乎地址 个人博客地址 paper
LDM (隐空间扩散模型基本原理,stable diffusion 底层架构) 知乎地址 个人博客地址 paper
classifier-free guided(文生图基本原理) 知乎地址 个人博客地址 paper

Motivate

虽然Transformer架构已经在诸多自然语言处理和计算机视觉任务中展现出卓越的scalable能力,但目前主导扩散模型架构的仍是UNet。本文旨在探讨以Transformer取代UNet在扩散模型中的可行性和潜在方案,并对所提出的Diffusion Transformer (DIT)架构的scalable能力进行了验证和评估。

Method

采用DiT架构替换UNet主要需要探索以下几个关键问题:
  1. Token化处理。Transformer的输入为一维序列,形式为(忽略batch维度),而LDM的latent表征为spatial张量。因此,需要设计合适的Token化方法将二维latent映射为一维序列。
  1. 条件信息嵌入。sable diffusion火出圈的一个关键在于它能够根据用户的文本指令生成高质量的图像。这里面的核心在于需要将文本特征嵌入到扩散模型中协同生成。并且扩散模型的每一个生成还需要融入time-embedding来引入时间步的信息。因此,若要用Transformer架构取代Unet需要系统研究Transformer架构的条件嵌入
DiT这篇paper的核心在于对上述两个问题的系统研究。
notion image

Patchify(token化)

假定原始图片,经过auto-encoder后得到latent表征。首先DiT 用ViT中patch化的方式将隐表征 转化为token序列,随后给序列添加位置编码。图中展示了patch化的过程。patch_size p是一个超参数。文中分别尝试了p=2,4,8。(DiT的输出会将每一个token线性解码成pxpx2C,再reshape为nose和协方差)
notion image

DiT block设计

这个部分系统探究了4中在DiT中引入控制信号的方案。
(一)In-context conditioning
直接将时间步信号、文本控制信号作为addition token和输入sequence进行拼接。其角色类似于类似于ViT里面的[CLS]token。这样做有一个好处,原本的ViT架构都可以不动,并且增加的的计算量可以忽略不计。
(二)Cross-Attention block
这个方法首先将时间步信号和文本信号进行拼接,得到拼接后的控制信号。随后类似文献[1]的做法,在ViT中添加cross attention层,将控制信号作为cross-attention的key,value进行融入。
(三)Adaptive Layer Norm (adaLN) block
作者参考文献[2]提出的adaptive normalization layer(adaLN),将transformer block的layer norm替换为adaLN。简单来说就是,原本的将原本layer norm用于仿射变换的scale parameter和shift parameter 用condition embedding来替代。下面给出了最简的示例代码便于理解。
论文原话:Rather than directly learn dimensionwise scale and shift parameters γ and β, we regress them from the sum of the embedding vectors of t and c.
(四)adaLN-Zero block
这个方法是(三)的延伸。简单来说就是condition embedding除了融入到layer norm中,还作为residual的强度融入到residual连接中。下面给出了最简的示例代码

Result

作者在imagenet数据上,以classifier-free的方式训练DiT(仅做class-control,即text condition embedding为类别embedding)。作者设置了4种不同model size的DiT,并开展实验。
notion image

DiT的scalable能力验证

作者分别尝试了的patch size,不同model size的DiT,从图中不难发现
  • patch size越小生成的效果越好(意味着初始时sequence的token数越多)。这里不太明白为什么作者不实验p=1的情形。因为latent表征本身就可以视作是CNN抽取的隐式token,只要flatten即可,很多hybrid的架构(CNN+ViT)都是这么玩的,或许是为了控制计算量?
  • model size越大生成效果越好。从实验结果中DiT-XLDiT-L的差距很小,可能是因为训练数据量还不够大体现不出大模型的优势
notion image
notion image

DiT Block有效性验证

作者在imagenet数据集上验证上面提出的四种DiT block的的生成效果。ada LN-Zero方案的生成效果最好。
notion image
 

小结

DiT 系统研究了diffusion transformer的token化和条件嵌入两个关键问题,验证了基于transformer架构的扩散模型的scalable能力。

参考文献

[1] Attention is all you need.
[2] Film: Visual reasoning with a general conditioning layer.
 
diffusion model(十四): prompt-to-prompt 深度剖析diffusion model(十二): StableCascade技术小结
  • Twikoo