TransGAN:两个纯粹的Transformer可以组成一个强大的GAN-论文精读

举报
中杯可乐多加冰 发表于 2022/11/03 11:30:46 2022/11/03
【摘要】 TransGAN是UT-Austin、加州大学、 IBM研究院的华人博士生构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。该论文已被NeruIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年12...

TransGAN是UT-Austin、加州大学、 IBM研究院的华人博士生构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。该论文已被NeruIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年12月。
该文章旨在仅使用Transformer网络设计GAN。Can we build a strong GAN completely free of convolutions?

论文地址:https://arxiv.org/abs/2102.07074
代码地址:https://github.com/VITA-Group/TransGAN

本博客是精读这篇论文的报告,包含一些个人理解、知识拓展和总结。

一、原文摘要

最近,人们对Transformer产生了爆炸性的兴趣,这表明Transformer有可能成为计算机视觉任务(如分类、检测和分割)的强大“通用”模型。虽然这些尝试主要研究区分模型,但我们探索了一些更为困难的视觉任务,例如生成性对抗网络(GAN)。
我们的目标是进行第一次试验性研究,仅使用纯Transformer架构,构建完全没有卷积的GAN,我们的vanilla-GAN架构被称为TransGAN,包括一个基于内存友好的Transformer的生成器,该生成器可逐渐提高特征分辨率,并相应地包含一个多尺度鉴别器,可同时捕获语义上下文和低级纹理。
在此基础上,我们引入了新的网格自关注模块,以进一步缓解内存瓶颈,从而将TransGAN扩展到高分辨率生成。我们还开发了一个独特的训练配方,包括一系列可以缓解TransGAN训练不稳定性问题的技术,如数据增强、修改的标准化和相对位置编码
与目前最先进的使用卷积主干的GANs相比,我们的架构实现了极具竞争力的性能。TransGAN能够生成具有高保真度和合理纹理细节的各种视觉示例。此外,通过可视化训练动态,我们深入研究了基于Transformer的生成模型,以了解它们的行为与卷积模型的区别。

二、介绍

而本文主要创新点如下:

  1. 新颖的结构设计:第一次使用纯粹的Transformer来构建无卷积的GAN。TransGAN定制了一个便于记忆的生成器和多尺度鉴别器,并进一步配备了一种新的网格自我注意机制。这些体系结构组件经过深思熟虑的设计,以平衡内存效率、全局特征统计和局部细节与空间差异。
  2. 新的训练方法:研究了一些技术来更好地训练TransGAN,包括利用数据增强、修改层规范化,以及对生成器和鉴别器采用相对位置编码。且进行了广泛的消融研究和讨论。
  3. 与当前最先进的GANs相比,TransGAN实现了极具竞争力的性能

三、为什么提出TransGAN?

  1. 一方面传统GAN存在模式崩溃问题,训练不稳定,在这几年里为了致力于稳定GAN训练,研究者们引入了各种正规化术语,更好的损失函数,以及各种变体训练方法,但是从2015的DC-GAN使用CNN架构来扩展GAN以来,每一个成功的GAN都依赖于基于CNN的生成器和鉴别器。
  2. 另一方面传统GAN都是基于卷积的,缺点是只有局部感受野,深层次会丢失细节

最初的transformer是为NLP设计的,在NLP中,多头自我注意层和前向反馈网络层层被堆叠起来,以捕捉单词之间的长期相关性,最近,Transformer在图像生成方面也有进展,通过替换CNN的某些组件,将Transformer模块结合到图像生成模型中,然而其CNN的整体架构仍然存在(包括用于发生器的CNN编码器/解码器,以及完全基于CNN的鉴别器)。

四、主要框架

在这里插入图片描述

4.1、生成器

如果以逐个像素作为输入,32*32的低分辨率图像也会导致1024长度的序列,与单词序列相比,数据指数级增长,如果再加入注意力,则参数爆炸式增长。于是作者的策略是分阶段迭代提高分辨率,即增加输入序列同时逐渐降低维数。
在这里插入图片描述

  1. 输入为随机噪声,首先通过多层感知器(MLP)形成一段长序列。
  2. 然后序列经过transformer的encoder块,输出长序列。
  3. 上采样模块包括重塑、上采样和重塑阶段。其首先将该长序列重塑为8×8×C(将1D的序列转换成了2D的图像特征 X i R H i × W i × C X_{i} \in \mathbb{R}^{H_{i} \times W_{i} \times C} ),然后使用双三次插值的方法进行上采样,使之在维度不减的情况下提高采样分辨率,变成16×16×C的图像特征。然后又一次重塑成1D的序列。
  4. 将重塑后的1D的序列再次经过步骤2、步骤3,生成32×32×C的图像特征,重塑成1D序列。又一次经过步骤2、步骤3生成64×64×C的图像特征,重塑成1D序列。然后再次经过步骤2,但不急着做步骤3。
  5. 此时的上采样模块进行了改进,与3不同的是双三次插值改成了pixel shuffle模块,将4的长序列输入重塑为64×64×C的图像特征( X i R H 64 × W 64 × C X_{i} \in \mathbb{R}^{H_{64} \times W_{64} \times C} ),使用pixel shuffle进行上采样,变为 128 × 128 × c 4 128×128×\frac{c}{4} 的图像特征( X i R H 128 × W 128 × c 4 X_{i} \in \mathbb{R}^{H_{128} \times W_{128} \times \frac{c}{4}} ),然后将其又重塑为1D的序列。
  6. 后面即重复一次transformer的encoder块,然后将生成的长序列重塑为 128 × 128 × c 4 128×128×\frac{c}{4} ,再次经过一次pixel shuffle,从 128 × 128 × c 4 128×128×\frac{c}{4} 特征变为 256 × 256 × c 16 256×256×\frac{c}{16} 特征,然后最后进行一次线性加权,得到256×256×3的图像。

4.2、鉴别器

鉴别器的任务是区分真假图像,也就是分类任务。作者设计了一个多尺度的鉴别器,在不同的阶段以不同大小的面片作为输入。(因为三种不同的序列能够同时提取语义结构和纹理细节。)
在这里插入图片描述

  1. 首先将图像分割成同样大小的P×P、2P×2P、4P×4P个块,作为不同的尺度。
  2. 图像第一个尺度的大小为 ( H P × W P ) × 3 \left(\frac{H}{P} \times \frac{W}{P}\right) \times 3 ,首先通过线性加权将其转换成 ( H P × W P ) × C 4 \left(\frac{H}{P} \times \frac{W}{P}\right) \times \frac{C}{4} ,同样第二个尺度 ( H 2 P × W 2 P ) × 3 \left(\frac{H}{2P} \times \frac{W}{2P}\right) \times 3 转换成 ( H 2 P × W 2 P ) × C 4 \left(\frac{H}{2P} \times \frac{W}{2P}\right) \times \frac{C}{4} ,第三个尺度 ( H 4 P × W 4 P ) × 3 \left(\frac{H}{4P} \times \frac{W}{4P}\right) \times 3 转换成 ( H 4 P × W 4 P ) × C 2 \left(\frac{H}{4P} \times \frac{W}{4P}\right) \times \frac{C}{2} 。第一个尺度转成的token是为了给第一个的transformer块作为输入,第二个和第三个分别连接到第二、三阶段的token后(捕捉更多的纹理信息)。
  3. 与生成器反过来类似,我们首先将token输入transformer块,然后将输出的一维向量重塑为二维特征图,并在每个阶段之间采用平均池层对特征图分辨率进行降采样。
  4. 在这些块的末尾,在1D序列的开始处附加[cls]标记,以输出真/假预测。

4.3、Self-Attention的一种变体:Grid Self-Attention

self-attn

Self-attention虽然使生成器能够捕获全局对应关系,但在建模高的分辨率时,会出现超长序列,会极大影响效率,于是作者提出了Grid Self-Attention:
在这里插入图片描述
Grid Self-Attention将全尺寸特征映射划分为几个非重叠网格,网格内进行Self-attention(分成多个块,块内做标准的self-attention,然后将每个块相连)。

Grid Self-Attention在TransGAN中,只被运用在64×64以上分辨率以减少消耗,64以下的仍然采用标准的self-attention。这样的做法从战略上平衡局部细节和全局效率。

五、改进性策略

5.1、数据增强

对比卷积来说,Transforme是更需要数据的,不同类型的强大数据增强可以为Transformer提供高效的训练。
作者从三个角度进行了数据增强:Translation, Cutout, Color,让TransGAN的性能有了惊人的提高
Translation是做些许偏移,Cutout在图像上加一些纯白或者纯黑的像素点,Color就是改变图像的对比度、饱和度。

5.2、相对位置编码

虽然经典的transformer已经有相对位置编码,但是其发挥出的作用不够明显。
作者将 Attention ( Q , K , V ) = softmax ( ( Q K T d k V ) \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\left(\frac{Q K^{T}}{\sqrt{d_{k}}} V\right)\right. 改为 Attention ( Q , K , V ) = softmax ( ( ( Q K T d k + E ) V ) \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\left(\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+E\right) V\right)\right. ,其中E取自矩阵M,并作为残差项添加(M是同时考虑H轴和W轴,用表示相对位置的参数化矩阵 M R ( 2 H 1 ) × ( 2 W 1 ) M \in \mathbb{R}^{(2 H-1) \times(2 W-1)}

相对位置编码学习了内容之间更强的“关系”,能够极大提升性能。

5.3、修正后的归一化

归一化层(Normalization )有助于稳定深层神经网络的深层学习训练,效果显著,原版标准归一化使用的是layer normalization,作者提出了一种 Y = X / 1 C i = 0 C 1 ( X i ) 2 + ϵ Y=X / \sqrt{\frac{1}{C} \sum_{i=0}^{C-1}\left(X^{i}\right)^{2}+\epsilon} ,其中 ϵ = 1 e 8 {\epsilon}=1e-8 ,X和Y表示缩放层前后的标记,C代表嵌入维度。(类似于AlexNet中曾经使用的局部响应规范化)

六、实验

6.1、数据集

CIFAR-10、STL10和CelebA数据集。

6.2、实验设置

遵循WGAN的设置,并使用WGAN-GP损失, 生成器的batch大小为128,鉴别器的batch大小为64,选择DiffAug作为培训过程中的基本增强策略。评价指标使用IS和FID。

6.3、实验结果

实验细节见论文
在这里插入图片描述
在这里插入图片描述

6.4、消融实验

实验细节见论文
在这里插入图片描述

6.5、实验消耗

实验细节见论文
在这里插入图片描述

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。