ADM: Improved Diffusion Models with Classifier Guidance

释放双眼,带上耳机,听听看~!
了解OpenAI发布的最新研究,ADM模型在图像合成中的性能优势以及使用Classifier Guidance技术的改进。

Classifier Guidance

Improved DDPM虽然对DDPM进行了改进,但在一些大数据集上(如ImageNet 256×256)生成图片的实验效果(FID)仍是低于GAN。因此,OpenAI继续对DDPM进行改进,在2021年随后又发表了论文《Diffusion Models Beat Gans on Image Synthesis》,在模型结构上进一步优化,同时引入Classifier Guidance技术,通过图片分类器的梯度调节反向扩散过程,在尽量保持图片生成多样性的前提下,提升准确性,从而在多个数据集的实验效果(FID)上超过了GAN,实现了SOTA。论文将改进后的模型称为ADM(Ablated Diffusion Model)。

改进

网络结构

DDPM和Improved DDPM中的模型均使用U-Net,ADM在其网络结构的基础上,进一步增加以下数项改进:

  • 增加网络结构的宽度和深度;
  • 在注意力机制上,DDPM原先只在16×16这一层增加单头注意力层(缩放点积注意力),而ADM在32×32、16×16、8×8各层均增加了多头注意力层;
  • 在上下采样上,DDPM原先在下采样使用池化或卷积层、在上采样使用插值或卷积层,而ADM使用残差卷积块;

ADM的代码开源,代码地址是:github.com/openai/guid…,其是在Improved DDPM的代码基础上进行修改。网络结构定义的相关代码在guided-diffusion/unet.py的UNetModel类中,例如,ADM使用残差块卷积进行下采样的代码如下:

            if level != len(channel_mult) - 1:
                # 除编码器最后一层外的其他层,需要进行下采样输出到下一层
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        # 若标记resblock_updown为True,则使用残差卷积块进行下采样
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            # 设置残差卷积块中需进行下采样
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

自适应组归一化

ADM: Improved Diffusion Models with Classifier Guidance

ADM还使用了自适应组归一化(Adaptive Group Normalization,AdaGN),组归一化如图1最右侧所示,即对一个图片样本的所有像素,按通道分组进行归一化,而自适应归一化可表示为以下公式:
AdaGN(h,y)=ysGroupNorm(h)+ybtext{AdaGN}(h,y)=y_stext{GroupNorm(h)}+y_b
其中,hh是残差卷积块中第一个卷积层的输出,ysy_syby_b分别是步数和图片分类的Embedding向量经过线性层后的投影。自适应归一化的代码如下所示:

            # 经过第一个卷积层的输出
            h = self.in_layers(x)
            ......
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # 取y_s和y_b
            scale, shift = th.chunk(emb_out, 2, dim=1)
            # 按y_s * GroupNorm(h) + y_b进行自适应组归一化
            h = out_norm(h) * (1 + scale) + shift
            # 经过第二个卷积层输出
            h = out_rest(h)

论文通过实验发现,使用自适应组归一化能够进一步优化FID。

Classifier Guidance

除了在网络结构上精心设计和优化外,GAN还在有条件(已知图片类别)的图片生成中大量使用了图片类别信息。基于此,ADM一方面在自适应组归一化中引入图片类别的Embedding向量作为模型输入,另一方面设计了Classifier Guidance机制,通过引入一个分类器指导反向扩散过程:预先使用带噪声的图片xtx_t训练分类器pϕ(y∣xt)p_{phi}(y|x_t)实现对类别的预测;在逐步反向扩散生成图片时,DDPM在每一步基于扩散模型预测噪声ϵθ(xt)epsilon_theta(x_t)和方差Σθ(xt)Sigma_theta(x_t),并由公式μθ(xt)=1αt(xt−1−αt1−αˉtϵθ(xt))mu_theta(x_t)=frac{1}{sqrt{alpha_t}}left(x_t-frac{1-alpha_t}{sqrt{1-bar{alpha}_t}}epsilon_theta(x_t)right)计算得到μθ(xt)mu_theta(x_t),即得到了xt−1x_{t-1}高斯分布的均值和方差,在此基础上,ADM使用分类器pϕ(y∣xt)p_{phi}(y|x_t)输出对数的梯度对均值进行调整,并使用调整均值后的高斯分布进行采样得到xt−1x_{t-1},均值调整公式如下所示:

μ^θ(xt∣y)=μθ(xt∣y)+s⋅Σθ(xt∣y)∇xtlog⁡pϕ(y∣xt)hat{mu}_theta(x_t|y)=mu_theta(x_t|y)+scdotSigma_theta(x_t|y)nabla_{x_{t}}log{p_{phi}{(y|x_t)}}

其中,系数ss被称为Guidance Scale,论文通过实验发现随着ss的增加,生成图片的质量会提升,但多样性会减少。引入Classifier Guidance后的采样算法步骤如图2所示。

ADM: Improved Diffusion Models with Classifier Guidance

ADM中使用的分类器网络结构和扩散模型网络结构近似,均采用U-Net,但只使用编码器层和中间层,而没有解码器层,另外,由于分类器的目标是预测类别,因此类别没有作为输入。分类器网络结构定义的相关代码在guided-diffusion/unet.py的UNetModel类中。
在分类器和扩散模型训练完成后,便可使用其进行图片采样。采样时使用分类器输出梯度对均值进行调整的代码在guided-diffusion/gaussian_diffusion.py的condition_mean方法中,如下所示:

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        # 使用分类器输出分类器梯度
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        # 根据分类器梯度、均值和方差计算新均值
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

而其中使用分类器输出分类器梯度的cond_fn方法代码如下所示:

    def cond_fn(x, t, y=None):
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            # 分类器输出分类结果
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            # 计算梯度并返回
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

效果

ADM: Improved Diffusion Models with Classifier Guidance

论文在多个数据集上对ADM(仅做网络结构优化)、ADM-G(同时引入Classifier Guidance机制)和其他模型进行了对比实验,从FID指标上,ADM、特别是ADM-G超过了GAN,实现了SOTA。

Classifier-Free Guidance

在Classifier Guidance机制被提出后,紧接着Google于2021年发表了论文《Classifier-Free Diffusion Guidance》。这篇论文指出Classifier Guidance仍存在以下不足:一是Classifier Guidance需要额外训练分类器,二是Classifier Guidance会导致基于梯度的对抗攻击,欺骗FID、IS这类基于分类器的评估指标。因此,这篇论文提出了一种不需要训练分类器、但仍可以基于类别信息指导反向扩散过程的机制——Classifier-Free Guidance。
之前ADM等扩散模型可使用类别信息进行有条件的图片生成,由模型基于xtx_t和类别yy预测噪声ϵθ(xt,y)epsilon_theta(x_t,y),或是不使用类别信息进行无条件的图片生成,由模型仅基于xtx_t预测噪声ϵθ(xt)epsilon_theta(x_t),但这两种情况需要分别训练模型,而Classifier-Free Guidance的思想是在模型训练时,按一定比例丢弃类别信息,使得模型能够同时学习有条件的图片生成和无条件的图片生成,这样在采样生成图片时,由同一个模型预测ϵθ(xt,y)epsilon_theta(x_t,y)ϵθ(xt)epsilon_theta(x_t),并使用两者的差值等价替换分类器输出的梯度对ϵθ(xt,y)epsilon_theta(x_t,y)进行调整,调整公式如下:

ϵ^θ(xt,y)=ϵθ(xt)+s⋅(ϵθ(xt,y)−ϵθ(xt))hat{epsilon}_theta(x_t,y)=epsilon_theta(x_t)+scdot(epsilon_theta(x_t,y)-epsilon_theta(x_t))

再基于调整后的ϵ^θ(xt,y)hat{epsilon}_theta(x_t,y)计算均值,并从高斯分布中采样得到xt−1x_{t-1}

CLIP Guidance

Classifier Guidance使用类别信息指导反向扩散过程,那是否可以使用除类别外的其他信息指导反向扩散过程?2022年发表的论文《More Control for Free! Image Synthesis with Semantic Diffusion Guidance》就尝试使用了其他信息,其中包括在多模态领域应用比较广泛的CLIP模型。
CLIP模型包括两部分,图片编码器f(x)f(x)和文本编码器g(c)g(c),其中xx为图片,cc为文本。训练阶段,采用对比学习,使得正确图片、文本对(x,c)(x,c)的点积f(x)⋅g(c)f(x)cdot g(c)尽可能大,错误图片、文本对的点积尽可能小。因此,在推理阶段,可以进行文本和图片相关性的比较。关于CLIP模型的详细介绍,可以阅读原论文《Learning Transferable Visual Models From Natural Language Supervision》《AIGC系列-CLIP论文阅读笔记》
在Classifier Guidance中可以使用CLIP模型替换分类器,对于xtx_t,使用f(xt)⋅g(c)f(x_t)cdot g(c)的梯度调整μθ(xt∣c)mu_theta(x_t|c),公式如下所示:

μ^θ(xt∣c)=μθ(xt∣c)+s⋅Σθ(xt∣c)∇xt(f(xt)⋅g(c))hat{mu}_theta(x_t|c)=mu_theta(x_t|c)+scdotSigma_theta(x_t|c)nabla_{x_t}(f(x_t)cdot g(c))

和Classifier Guidance中的分类器类似,需使用带噪声的图片和文本对训练CLIP模型以获得正确的梯度。

GLIDE

在上述工作的基础上,OpenAI于2022年发表了论文《GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models》,其中发布了GLIDE(Guided Language to Image Diffusion for Generation and Editing)模型,用于基于文本的图片生成。图4是使用GLIDE模型基于文本生成的图片。

ADM: Improved Diffusion Models with Classifier Guidance

基于文本的图片生成

一般的扩散模型从随机采样的高斯噪声开始,不能生成特定的图片,而GLIDE在已有扩散模型的基础上,使用文本信息指导扩散过程,对于带噪声的图片xtx_t和文本cc,能够通过模型预测pθ(xt−1∣xt,c)p_theta(x_{t-1}|x_t,c),从而逐步降噪,实现了基于文本的图片生成。
具体实现上,GLIDE基于ADM模型,但模型参数和训练数据规模更大,模型参数达到35亿。GLIDE先将文本cc转化为长度为KK的token序列,再通过Transformer输出文本的Embedding向量,最后使用文本Embedding向量替换原ADM模型输入中的类别Embedding向量。另外,文本Embedding向量还会经过投影与注意力层中的KKVV拼接在一起,通过注意力机制指导扩散过程。
GLIDE的代码开源,代码地址是:github.com/openai/guid…。通过Transformer输出文本的Embedding向量作为模型输入的代码在glide_text2im/text2im_model.py中,如下所示:

    def forward(self, x, timesteps, tokens=None, mask=None):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.xf_width:
            text_outputs = self.get_text_emb(tokens, mask)
            xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
            emb = emb + xf_proj.to(emb)
        else:
            xf_out = None
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, xf_out)
            hs.append(h)
        h = self.middle_block(h, emb, xf_out)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, xf_out)
        h = h.type(x.dtype)
        h = self.out(h)
        return h

注意力层拼接文本Embedding向量的代码在glide_text2im/unet.py中,如下所示:

class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, encoder_kv=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            # 拼接文本Embedding向量到k
            k = th.cat([ek, k], dim=-1)
            # 拼接文本Embedding向量到v
            v = th.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

Classifier-Free Guidance

GLIDE也使用了Classifier-Free Guidance机制,只是将类别替换为文本,因此,对模型所预测噪声ϵθ(xt,y)epsilon_theta(x_t,y)进行调整的公式如下:
ϵ^θ(xt,c)=ϵθ(xt)+s⋅(ϵθ(xt,c)−ϵθ(xt))hat{epsilon}_theta(x_t,c)=epsilon_theta(x_t)+scdot(epsilon_theta(x_t,c)-epsilon_theta(x_t))
GLIDE在上一节已训练得到基于文本的模型、可预测ϵθ(xt,c)epsilon_theta(x_t,c)的基础上,对模型进行微调,将20%的文本Token序列替换成空序列,从而使得模型在具备预测ϵθ(xt,c)epsilon_theta(x_t,c)的基础上,能够进一步预测ϵθ(xt)epsilon_theta(x_t),从而在采样时,能够基于调整后的噪声ϵ^θ(xt,c)hat{epsilon}_theta(x_t,c)降噪生成图片。

CLIP Guidance

GLIDE也使用了CLIP Guidance机制,但从实验结果上,其效果不如Classifier-Free Guidance。

参考文献

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

中康数字科技利用飞桨PaddleNLP构建医学知识图谱,提升医学信息抽取准确性

2023-11-28 0:02:14

AI教程

AIGC应用与创新分享:利用大语言模型开放知识的通用推荐框架

2023-11-28 2:22:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索