type
status
date
slug
summary
tags
category
icon
password
Project web: https://emu-edit.metademolab.com/
Code: have not opensource
1 核心思想
作者将intruction-base image editing任务建模为生成任务,并用diffusion model进行求解。核心创新点有两个
- 详细定义了instruction-based image edit处理的任务,并设计了一个高效高质量的数据构建方法。
- 为提升模型对instruction的理解能力,引入learnable task embedding,能较好的解决上述问题。并且提出task inversion的训练方法,只需少量数据就能有效将模型扩展到新的task(类似textual inversion的思想)。
2 方法
2.1 方法建模
前面提到,作者将一系列的intruction-base image editing任务建模为生成任务,并用diffusion model来求解。具体来看intruction-base image editing任务做的是这么一件事:给定一张参考图片和一段表述文本,输出符合上述两个条件的图片。从上述描述可知:intruction-base image editing的训练数据应当至少是一个三元组,其中
: 参考图片(condition of image)
: 参考文本(condition of text)
: 目标图片
这样,基于diffusion model的优化目标可建模为:
和经典的classifier-free有所区别的是,此处多了一个参考图片的condition 。条件融入的方法上,
- 作者参考
Instructpix2pix
将image condition融入到输入层(在通道维度进行concat)。
- 参考classifier-free将text condition的融入在cross-attention。
通过实验,作者发现用上述方法训练的模型对task的理解不够准确如下图所示。为此,作者引入learnable task embedding来增强模型对task的理解。此时的优化目标建模为:
为了求解上述目标方程,构造的训练数据集的每一个元素应当是一个四元组为这条数据所属的task类别。并且此时的diffusion model的噪声预测模型多了一个task embedding 的条件。作者的融入方式是将其与time-step的embedding进行相加,共同融入到cross-attention中。这样设计还保留了可扩展性:当有一个新的task时,可以将优化目标转化为
此时训练的参数仅为新增的task embedding,其它参数都freeze。作者将其称之为task inversion(类似textual inversion)。
在用户层的推理阶段,用户无需输入task index,作者基于
Flan-T5-XL
训练了一个task index预测模型,来根据用户输入的instruction预测出相应的task index。从实现原理上,上述方法不难想到。论文取得的卓越的效果取决于训练的数据集。下面来看作者是如何用一种高效的方法构建高质量的数据集。
2.2 数据工程
前文提到,训练一个image-edit diffusion model训练数据至少是一个三元组(其中: 参考图片(condition of image) : 参考文本(condition of text): 目标图片)。手动构建数据集的成本非常大,开源数据规模又不够大,一些规模大的合成数据多样性和质量又不高,因此需要探寻如何用cheap的方法来构建一个高质量、大规模、高多样的image-edit数据集。为了结合task inversion,新构建的数据集应当是一个四元组为这条数据所属的task类别。
2.2.1 image-edit任务类别定义
作者将image-edit分为了三大类,分别是Region-based Editing、Free-From Editing、Vision tasks,每个大类中有若干小类。下图展示了每一个image-edit任务所做的事
2.2.2 指令集生成
任务定义:已知image caption和编辑任务,输出满足编辑任务新的caption
- 输入:image caption + edit任务
- 输出:edit instruction, edit instruction应当包含:1)edit指令;2)edit的目标(edited object);3)新的image caption;4)原始目标(original object)(7.2节提到有这个字段,但在7.1中的示例没有,实际上应当要有这个字段,否则后续的mask提取无法进行)
举个例子(对于
add
的image-edit任务)输入:{"image_caption": "Beautiful cat with mojito sitting in a cafe on the street", "task": "Add"}
输出:{"edit": "include a hat", "edited object": "hat", "output": "Beautiful cat wearing a hat with mojito sitting in a cafe on the street", "original object": "cat"}
作者用context learning的方法来实现上述任务的目标。作者构建的prompt方案如下:(作者所用的LLM是微调了的70B LLama2,我用chatgpt尝试了一下,也能实现类似的效果)
2.2.3 图片对的生成
通过上面的步骤我们拿到了4元组 ,中的,其中还有很多附加信息:
如:编辑的对象,新的image caption,如:
{"edit": "include a hat", "edited object": "hat", "output": "Beautiful cat wearing a hat with mojito sitting in a cafe on the street", "original object": "cat"}
此处需要进行的是根据上面的条件,得到对应的图片pair (x)。
任务目标:根据输入图片、instruction信息生成对应的图片pair (x)并且除了编辑的区域,x与c_I的差异应当尽可能的小。
为了解决上述的任务目标,作者提出一种
mask-based attention control
的方法(相当于DiffEdit和P2P的结合)。具体分为以下几个步骤:已知条件:
- : image caption 。
example:Beautiful cat with mojito sitting in a cafe on the street
- : image caption用DM生成的图片
- : 编辑后的image caption。
Beautiful cat wearing a hat with mojito sitting in a cafe on the street
- : image caption的原始目标(original object)。
cat
- 编辑目标(edited object):
hat
STEP1: 提取mask。将与送入到SAM+DINO模型中,得到3类mask
- 精确的mask,有sam+dino的生成
- 将1中的mask进行膨胀,在进行高斯模糊,作为新的mask
- 取第一步mask的bounding box作为新的mask
SETP2: 通过
mask-based attention control
进行图片生成。具体为:先用P2P的cross-attention control的方法将common token的对应的attention map进行注入,随后用diffedit的根据mask融合方法进行融合。STEP3: 图片Filter。通过上述步骤得到3个目标图片,留存最好的一个。filter的规则为
- 用CLIP filtering metrics,留存最相关的一个
- 留存edit image与input image在深度图上的L1 距离最小的图片。
- ...
每一类Edit方法的详细的数据构造细节见论文7.2.3
最后得到的各类训练数据比例如下:
3 结果
EmuEdit在多个测试集取得了SOTA。并且作者公开了一个新的基于EmuEdit的benchmark:https://huggingface.co/datasets/facebook/emu_edit_test_set
一些惊艳结果:
- 作者:莫叶何竹🍀
- 链接:http://www.myhz0606.com/article/emu_edit
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章