轻量级Vision-Transformer:EdgeViTs复现

举报
李长安 发表于 2023/02/16 16:31:07 2023/02/16
【摘要】 一个新的轻量级ViTs家族,也是首次使基于Self-attention的视觉模型在准确性和设备效率之间的权衡中达到最佳轻量级CNN的性能。

轻量级Vision-Transformer:EdgeViTs复现

摘要

  在计算机视觉领域,基于Self-attention的模型(如(ViTs))已经成为CNN之外的一种极具竞争力的架构。尽管越来越强的变种具有越来越高的识别精度,但由于Self-attention的二次复杂度,现有的ViT在计算和模型大小方面都有较高的要求。 虽然之前的CNN的一些成功的设计选择(例如,卷积和分层结构)已经被引入到最近的ViT中,但它们仍然不足以满足移动设备有限的计算资源需求。这促使人们最近尝试开发基于最先进的MobileNet-v2的轻型MobileViT,但MobileViT与MobileNet-v2仍然存在性能差距。 在这项工作中,作者进一步推进这一研究方向,引入了EdgeViTs,一个新的轻量级ViTs家族,也是首次使基于Self-attention的视觉模型在准确性和设备效率之间的权衡中达到最佳轻量级CNN的性能。

1 EdgeViTs

1.1 总体架构

  为了设计适用于移动/边缘设备的轻量级ViT,作者采用了最近ViT变体中使用的分层金字塔结构(图2(a))。Pyramid Transformer模型通常在不同阶段降低了空间分辨率同时也扩展了通道维度。每个阶段由多个基于Transformer Block处理相同形状的张量,类似ResNet的层次设计结构。

  在这项工作中,作者深入到Transformer Block,并引入了一个比较划算的Bottlneck,Local-Global-Local(LGL)(图2(b))。LGL通过一个稀疏注意力模块进一步减少了Self-attention的开销(图2©),实现了更好的准确性-延迟平衡。

1.2 Local-Global-Local bottleneck(LGL)

  与以前在每个空间位置执行Self-attention的Transformer Block相比,LGL Bottleneck只对输入Token的子集计算Self-attention,但支持完整的空间交互,如在标准的Multi-Head Self-attention(MHSA)中。既会减少Token的作用域,同时也保留建模全局和局部上下文的底层信息流。

  为了实现这一点,作者将Self-attention分解为连续的模块,处理不同范围内的空间Token(图2(b))。

  这里引入了3种有效的操作:

  • Local aggregation:仅集成来自局部近似Token信号的局部聚合
  • Global sparse attention:建模一组代表性Token之间的长期关系,其中每个Token都被视为一个局部窗口的代表;
  • Local propagation:将委托学习到的全局上下文信息扩散到具有相同窗口的非代表Token。

  • Local aggregation

  对于每个Token,利用Depth-wise和Point-wise卷积在大小为k×k的局部窗口中聚合信息(图3(a))。

  • Global sparse attention

  对均匀分布在空间中的稀疏代表性Token集进行采样,每个r×r窗口有一个代表性Token。这里,r表示子样本率。然后,只对这些被选择的Token应用Self-attention(图3(b))。这与所有现有的ViTs不同,在那里,所有的空间Token都作为Self-attention计算中的query被涉及到。

  • Local propagation

  通过转置卷积将代表性 Token 中编码的全局上下文信息传播到它们的相邻的 Token 中(图 3©)。

2 代码复现

import paddle
import paddle.nn as nn
from paddle.nn import Conv2D  as Conv2d
from paddle.nn import BatchNorm2D  as BatchNorm2d
from paddle.nn import Linear
from paddle.nn import AvgPool2D as AvgPool2d
from paddle.nn import Conv2DTranspose as ConvTranspose2d
from paddle.nn import LayerNorm, GELU


class Residual(nn.Layer):
    def __init__(self, module):
        super().__init__()
        self.module = module
    
    def forward(self, x):
        return x + self.module(x)

class LocalAgg(nn.Layer):  
    def __init__(self, dim):
        super().__init__()
        self.conv1 = Conv2d(dim, dim, 1)  
        self.conv2 = Conv2d(dim, dim, 3, padding=1, groups=dim)  
        self.conv3 = Conv2d(dim, dim, 1)  
        self.norm1 = BatchNorm2d(dim)  
        self.norm2 = BatchNorm2d(dim)  
          

    def forward(self, x):  
        """  
        [B, C, H, W] = x.shape  
        """  
        x = self.conv1(self.norm1(x))  
        x = self.conv2(x)  
        x = self.conv3(self.norm2(x))  
        return x  

class GlobalSparseAttn(nn.Layer):  
    def __init__(self, dim, sample_rate = 4, scale = 1):
        super().__init__()  
        self.head_dim = int(48)//int(1)
        self.num_heads = int(1)
        self.scale = scale  
        self.qkv = Linear(dim, dim * 3)  
        self.sampler = AvgPool2d(1, stride=sample_rate)  
        self.LocalProp = ConvTranspose2d(dim, dim, kernel_size=sample_rate, stride=sample_rate, groups=dim  
        )  
        self.proj = Linear(dim, dim)  


    def forward(self, x):  
        """  
        [B, C, H, W] = x.shape  
        """  
        x = self.sampler(x)
        [B, C, H, W] = x.shape
        x = x.flatten(2)
        x = x.transpose([0,2,1])

        x = self.qkv(x)
        x = x.transpose([0, 2, 1])
        x = x.reshape([1, 144, 14, 14])
        q, k, v = x.reshape([B, self.num_heads, -1, H*W]).split([self.head_dim, self.head_dim, self.head_dim], axis=2)
       
        attn = (q.transpose([0, 1, 3, 2]) @ k)

        attn = nn.functional.softmax(attn)

        x = v  @  attn.transpose([0, 1, 3, 2])

        x = x.reshape([B, -1, H, W])

        x = self.LocalProp(x)  
       
        x = paddle.nn.functional.layer_norm(x, x.shape[1:])
        x = x.flatten(2)
        x = x.transpose([0,2,1])
        x = self.proj(x)  
        x = x.transpose([0,2,1])
        x = x.reshape([1, 48, 56, 56])
        return x  

class DownSampleLayer(nn.Layer):  
    def __init__(self, dim_in=3, dim_out=48, downsample_rate=4):  
        super().__init__()
        self.downsample = Conv2d(dim_in, dim_out, kernel_size=downsample_rate, stride=  
        downsample_rate)  

    def forward(self, x):  
        x = self.downsample(x)
        x = paddle.nn.functional.layer_norm(x, x.shape[1:])

        return x  

class PatchEmbed(nn.Layer):  
    def __init__(self, dim):
        super().__init__()
        self.embed = Conv2d(dim, dim, 3, padding=1, groups=dim)  
    def forward(self, x):  
        return x + self.embed(x)  

class FFN(nn.Layer):  
    def __init__(self, dim=3156):
        super().__init__()  
        self.fc1 = nn.Linear(dim, dim*4)  
        self.fc2 = nn.Linear(dim*4, dim)  
          

    def forward(self, x):
        x = x.flatten(2)
        x = x.transpose([0,2,1])
       
        x = self.fc1(x)  
        x = nn.functional.gelu(x) 
        x = self.fc2(x) 
       
        x = x.transpose([0,2,1])
        x = x.reshape([1, 48, 56, 56])
        return x  

class EdgeViT(nn.Layer):
    def __init__(self, dim_in=3, dim_out=48, downsample_rate=4, dim=48):
        super().__init__()

       
        self.downsample1 = DownSampleLayer(dim_in=3, dim_out=48, downsample_rate=4)
        self.patchembeding1 = PatchEmbed(dim=48)
        self.residual_add1 = Residual(LocalAgg(dim=48))
        self.residual_add1_1 = Residual(FFN(dim=48))

        self.patchembeding2 = PatchEmbed(dim=48)
        self.residual_add2 = Residual(GlobalSparseAttn(dim=48))

    def forward(self, x):

        x = self.downsample1(x)
        x = self.patchembeding1(x)
        x = self.residual_add1(x)
        x = self.residual_add1_1(x)
        x = self.patchembeding2(x)
        x = self.residual_add2(x)

        return x

cnn = EdgeViT()



paddle.summary(cnn,(1,3,224,224))
------------------------------------------------------------------------------
   Layer (type)        Input Shape          Output Shape         Param #    
==============================================================================
     Conv2D-1       [[1, 3, 224, 224]]    [1, 48, 56, 56]         2,352     
DownSampleLayer-1   [[1, 3, 224, 224]]    [1, 48, 56, 56]           0       
     Conv2D-2       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480      
   PatchEmbed-1     [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
  BatchNorm2D-1     [[1, 48, 56, 56]]     [1, 48, 56, 56]          192      
     Conv2D-3       [[1, 48, 56, 56]]     [1, 48, 56, 56]         2,352     
     Conv2D-4       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480      
  BatchNorm2D-2     [[1, 48, 56, 56]]     [1, 48, 56, 56]          192      
     Conv2D-5       [[1, 48, 56, 56]]     [1, 48, 56, 56]         2,352     
    LocalAgg-1      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
    Residual-1      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
     Linear-1        [[1, 3136, 48]]       [1, 3136, 192]         9,408     
     Linear-2        [[1, 3136, 192]]      [1, 3136, 48]          9,264     
      FFN-1         [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
    Residual-2      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
     Conv2D-6       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480      
   PatchEmbed-2     [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
   AvgPool2D-1      [[1, 48, 56, 56]]     [1, 48, 14, 14]           0       
     Linear-3         [[1, 196, 48]]       [1, 196, 144]          7,056     
Conv2DTranspose-1   [[1, 48, 14, 14]]     [1, 48, 56, 56]          816      
     Linear-4        [[1, 3136, 48]]       [1, 3136, 48]          2,352     
GlobalSparseAttn-1  [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
    Residual-3      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
==============================================================================
Total params: 37,776
Trainable params: 37,392
Non-trainable params: 384
------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 27.85
Params size (MB): 0.14
Estimated Total Size (MB): 28.57
------------------------------------------------------------------------------






{'total_params': 37776, 'trainable_params': 37392}

总结

  在本工作中,作者通过引入一个基于Self-attention和卷积的最优集成的高成本的local-global-local(LGL)信息交换瓶颈来实现的。对于移动设备专用的评估,不依赖于不准确的proxies,如FLOPs的数量或参数,而是采用了一种直接关注设备延迟和能源效率的实用方法。

  本次复现主要参考了作者的伪代码,成功使用飞桨复现了其主要结构LGL。大家感兴趣的可以跟随论文也进行一次代码复现,根据论文的描述,参考其伪代码,能够更加深入的了解作者的思想,并且也能提升Coding能力。

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。