MaxViT: Multi-Axis Vision Transformer论文浅析与代码实战

举报
李长安 发表于 2023/02/16 16:20:13 2023/02/16
【摘要】 在该论文中,作者提出了一种新型的Transformer模块,称为多轴自注意力(multi-axis self-attention, Max-SA),它可以作为基本的架构组件,在单个块中执行局部和全局空间交互。与完全自注意力相比,Max-SA具有更大的灵活性和效率,即自然适应不同的输入长度,具有线性复杂度。此外,Max-SA仅具有线性复杂度,可以用作网络任何层的通用独立注意力模块,增加少量的计算量。

MaxViT: Multi-Axis Vision Transformer论文浅析

1、MaxViT主体结构与创新点

1.1 研究动机

  卷积神经网络经历了从AlexNet到ResNet再到Vision Transformer,其在计算机视觉任务中的表现越来越好,通过注意力机制,Vision Transformer取得了非常好的效果。然而,在没有充分的预训练情况下,Vision Transformer通常不会取得很好的效果,并且由于注意力算子需要二次复杂度,因此在层次网络的早期或高分辨率阶段通过完全注意力进行全局交互的计算量很大。如何有效地结合全局和局部交互,在计算预算下平衡模型大小和可推广性仍然是一个具有挑战性的问题。

  在该论文中,作者提出了一种新型的Transformer模块,称为多轴自注意力(multi-axis self-attention, Max-SA),它可以作为基本的架构组件,在单个块中执行局部和全局空间交互。与完全自注意力相比,Max-SA具有更大的灵活性和效率,即自然适应不同的输入长度,具有线性复杂度。此外,Max-SA仅具有线性复杂度,可以用作网络任何层的通用独立注意力模块,增加少量的计算量。

  其主要创新点包含如下三点:

  • MaxViT是一个通用的Transformer结构,在每一个块内都可以实现局部与全局之间的空间交互,同时可适应不同分辨率的输入大小。
  • Max-SA通过分解空间轴得到窗口注意力(Block attention)与网格注意力(Grid attention),将传统计算方法的二次复杂度降到线性复杂度。
  • MBConv作为自注意力计算的补充,利用其固有的归纳偏差来提升模型的泛化能力,避免陷入过拟合。

1.2 Max-SA主要结构




  作者通过引入Max-SA模块,将传统的自注意机制分解为窗格注意力(Block attention)与网格注意力(Grid attention)两种稀疏形式,在不损失非局部性的情况下,将传统注意力机制的计算复杂度从二次复杂度降低到线性。并且Max-SA具有灵活性和可伸缩性,我们可以通过简单地将Max-SA与MBConv在分层体系结构中叠加,从而构建一个称为MaxViT的视觉Backbone,MaxViT主要结构如上图2所示。

class MaxViT(nn.Layer):
    def __init__(self, args):
        super().__init__()
        self.conv_stem = nn.Sequential(nn.Conv2D(args['input_dim'], args['stem_dim'], 3,2,3//2),
                                       nn.BatchNorm2D(args['stem_dim']),
                                       nn.GELU(),
                                       nn.Conv2D(args['stem_dim'], args['stem_dim'], 3,1,3//2),
                                       nn.BatchNorm2D(args['stem_dim']),
                                       nn.GELU())
        in_dim = args['stem_dim']
        self.max_blocks = nn.LayerList([])
        for i,num_block in enumerate(args['stage_num_block']):
            layers = nn.LayerList([])
            out_dim = args['stage_dim']*(2**i)
            num_head = args['num_heads']*(2**i)
            for i in range(num_block):
                pooling_size = args['pooling_size']if i == 0 else 1
                layers.append(Max_Block(in_dim,out_dim,num_head,args['block_size'], 
                                        args['grid_size'],args['mbconv_ksize'],pooling_size,
                                        args['mbconv_expand_rate'],args['se_rate'],args['mlp_ratio'],
                                        args['qkv_bias'],args['qk_scale'], args['drop'], args['attn_drop'],
                                        args['drop_path'],args['act_layer'] ,args['norm_layer']))
                in_dim = out_dim
            self.max_blocks.append(layers)
        self.last_conv = nn.Sequential(nn.Conv2D(in_dim,in_dim,1,),
                                       nn.BatchNorm2D(in_dim),
                                       nn.GELU())
        self.proj = nn.Linear(in_dim,args['num_classes'])
        self.softmax = nn.Softmax(1)
        
    def forward(self, x):
        x = self.conv_stem(x)
        for blocks in self.max_blocks:
            for block in blocks:
                x = block(x)
        x = self.last_conv(x)
        x = self.softmax(self.proj(x.mean([2, 3])))
        return x

1.3 Multi-axis Attention 详解




  与局部卷积相比,全局相互作用是自注意力机制的优势之一。然而,直接将注意力应用于整个空间在计算上是不可行的,因为注意力算子需要二次复杂度,为了解决全局自注意力机制导致的二次计算复杂度,作者通过分解空间轴得到局部(block attention)与全局(grid attention)两种稀疏形式,巧妙的解决了计算复杂度的问题。如上所示,Max-SA模块主要包含Block Attention与Grid Attention两个部分。

class Max_Block(nn.Layer):
    def __init__(self, in_dim, out_dim , num_heads=8.,block_size=(7,7), grid_size=(7,7),
                 mbconv_ksize = 3,pooling_size = 1,mbconv_expand_rate=4,se_reduce_rate=0.25,
                 mlp_ratio=4,qkv_bias=False,qk_scale=None, drop=0., attn_drop=0.,drop_path=0., 
                 act_layer=nn.GELU ,norm_layer=Channel_Layernorm):
        super().__init__()
        self.mbconv = MBConv(in_dim,out_dim,mbconv_ksize,pooling_size,mbconv_expand_rate,se_reduce_rate,drop)
        self.block_attn = Window_Block(out_dim, block_size, num_heads, mlp_ratio, qkv_bias,qk_scale, drop, 
                                       attn_drop,drop_path, act_layer ,norm_layer)
        self.grid_attn = Grid_Block(out_dim, grid_size, num_heads, mlp_ratio, qkv_bias,qk_scale, drop, 
                                    attn_drop,drop_path, act_layer ,norm_layer)
        
    def forward(self, x):
        x = self.mbconv(x)
        x = self.block_attn(x)
        x = self.grid_attn(x)
        return x
  • Block Attention

  将输入特征图划分为不重叠的窗口, 最后在每一个窗口中执行自注意力计算。虽然避免了全局自注意力机制的复杂计算,但是局部注意模型已经被证明不适用于大规模的数据集。所以作者提出一种稀疏的全局自注意力机制,被称作grid attention(网格注意力机制)。

class Window_Block(nn.Layer):
    def __init__(self, dim, block_size=(7,7), num_heads=8, mlp_ratio=4., qkv_bias=False,qk_scale=None, drop=0., 
                 attn_drop=0.,drop_path=0., act_layer=nn.GELU ,norm_layer=Channel_Layernorm):
        super().__init__()
        self.block_size = block_size
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.attn = Rel_Attention(dim, block_size, num_heads, qkv_bias, qk_scale, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
    def forward(self,x):
        assert x.shape[2]%self.block_size[0] == 0 & x.shape[3]%self.block_size[1] == 0, 'image size should be divisible by block_size'
        
        out = block(self.norm1(x),self.block_size)
        out = self.attn(out)
        x = x + self.drop_path(unblock(self.attn(out)))
        out = self.mlp(self.norm2(x))
        x = x + self.drop_path(out)
        return x
  • Grid Attention

  不同于传统使用固定窗口大小来划分特征图的操作,grid attention 使用固定的大小的均匀网格将输人张量网格化, 可以有效平衡局部和全局之间的计算 (且仅具有线性复杂度)。

class Grid_Block(nn.Layer):
    def __init__(self, dim, grid_size=(7,7), num_heads=8, mlp_ratio=4., qkv_bias=False,qk_scale=None, drop=0., 
                 attn_drop=0.,drop_path=0., act_layer=nn.GELU ,norm_layer=Channel_Layernorm):
        super().__init__()
        self.grid_size = grid_size
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.attn = Rel_Attention(dim, grid_size, num_heads, qkv_bias, qk_scale, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
    def forward(self,x):
        assert x.shape[2]%self.grid_size[0] == 0 & x.shape[3]%self.grid_size[1] == 0, 'image size should be divisible by grid_size'
        grid_size = (x.shape[2]//self.grid_size[0], x.shape[3]//self.grid_size[1])
        
        out = block(self.norm1(x),grid_size)
        out = out.transpose([0,4,5,3,1,2])
        out = self.attn(out).transpose([0,4,5,3,1,2])
        x = x + self.drop_path(unblock(out))
        out = self.mlp(self.norm2(x))
        x = x + self.drop_path(out)
        return x

1.4 MBConv

  为了获得更丰富的特征表示,首先使用逐点卷积进行通道升维,在升维后的投影空间中进行Depth-wise卷积,紧随其后的SE用于增强重要通道的表征,最后再次使用逐点卷积恢复维度。可用如下公式表示:

  对于每个阶段的第一个MBConv块,下采样是通过应用stride=2的深度可分离卷积( Depthwise Conv3x3)来完成的,而残差连接分支也 应用pooling 和 channel 映射:

  MBConv包含如下特点:

  • 采用了Depthwise Convlution,因此相比于传统卷积,Depthwise Conv的参数能够大大减少;
  • 采用了“倒瓶颈”的结构,也就是说在卷积过程中,特征经历了升维和降维两个步骤,并利用卷积固有的归纳偏置,在一定程度上提升模型的泛化能力与可训练性。
  • 相比于ViT中的显式位置编码,在Multi-axis Attention则使用MBConv来代替,这是因为深度可分离卷积可被视为条件位置编码(CPE)。
class MBConv(nn.Layer):
    def __init__(self,in_dim,out_dim,kernel_size=3,stride_size=1,expand_rate = 4,se_rate = 0.25,dropout = 0.):
        super().__init__()
        hidden_dim = int(expand_rate * out_dim)
        self.bn = nn.BatchNorm2D(in_dim)
        self.expand_conv = nn.Sequential(nn.Conv2D(in_dim, hidden_dim, 1),
                                         nn.BatchNorm2D(hidden_dim),
                                         nn.GELU())
        self.dw_conv = nn.Sequential(nn.Conv2D(hidden_dim, hidden_dim, kernel_size, stride_size, kernel_size//2, groups=hidden_dim),
                                     nn.BatchNorm2D(hidden_dim),
                                     nn.GELU())
        self.se = SE(hidden_dim,max(1,int(out_dim*se_rate)))
        self.out_conv = nn.Sequential(nn.Conv2D(hidden_dim, out_dim, 1),
                                      nn.BatchNorm2D(out_dim))
        if stride_size > 1:
            self.proj = nn.Sequential(nn.MaxPool2D(kernel_size, stride_size, kernel_size//2),
                                      nn.Conv2D(in_dim, out_dim, 1)) 
        else: 
            self.proj = nn.Identity()
    
    def forward(self, x):
        out = self.bn(x)
        out = self.expand_conv(out)
        out = self.dw_conv(out)
        out = self.se(out)
        out = self.out_conv(out)
        return out + self.proj(x)

1.5 Multi-Axis attention与Axial attention区别

  论文所提出的方法不同于 Axial attention。如图 3 所示, 在 Axial attention 中 首先使用列注意力(column-wise attention),然后使用行注意力( row-wise attention) 来计算全局 注意力, 。然而 Multi-Axis attention 则先采用局部注意力 (block attention), 再使用稀疏的全局注意力 (grid attention), 这样的 设计充分考虑了图像的 2D 结构。

2、整体网络结构复现

在论文中,作者基于Max-SA模块搭建了四种网络结构(MaxViT model family (T/S/B/L)),本项目对这四种结构均进行了复现,其网络结构列表如下:

3、网络模型结构输出

import paddle
from maxvit import MaxViT,tiny_args

print(MaxViT(tiny_args)(paddle.zeros([2,3,224,224])).shape)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:654: UserWarning: When training, we now always track global mean and variance.
  "When training, we now always track global mean and variance.")


[2, 1000]
model = MaxViT(tiny_args)

paddle.summary(model,(1,3,224,224))
{'total_params': 31001840, 'trainable_params': 30893552}

4、总结

  为解决传统自注意力机制在图像大小方面缺乏的可扩展性,论文提出了一种高效的、可扩展的多轴注意力模型,该模型由Block局部注意和Grid全局注意力两部分组成。本文还提出了一个新的架构,通过有效地混合提出的注意力模型与MBConv卷积,并相应地提出了一个简单的分层视觉骨干,称为MaxViT,通过简单地在多个阶段重复基本构建块。MaxViT允许任意分辨率的输入,实现全局-局部空间交互,且只具有线性复杂度。

5、参考资料

讲解 MaxViT: Multi-Axis Vision Transformer

论文 MaxViT: Multi-Axis Vision Transformer

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。