Sun Yat-sen University Proposes SimAM: Parameter-Free Attention that Boosts Classification, Detection, and Segmentation!

风吹稻花香, posted 2021/08/03 01:21:21

Paper: http://proceedings.mlr.press/v139/yang21o.html

Code: https://github.com/ZjjConan/SimAM

Before formally introducing the attention module proposed in this paper, we first briefly review representative existing attention modules (e.g., SE, CBAM, GC); we then present the proposed module, which takes an entirely different approach.

Overview of Existing Attention Modules

[Figure: existing channel attention (a) and spatial attention (b) versus the proposed full 3D attention (c)]

Panels (a) and (b) in the figure above show the two existing types of attention modules:

  • Channel attention: 1D attention that treats different channels differently but all spatial positions equally;

  • Spatial attention: 2D attention that treats different positions differently but all channels equally.

Taking the figure below as an example, SE-style weighting misses some important parts of the "grey_whale". We argue that 3D attention is preferable to 1D or 2D attention, and we therefore propose the 3D attention module shown in panel (c) above. (A minimal shape sketch follows the figure.)

[Figure: SE-weighted features miss parts of the "grey_whale"]
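To make the distinction concrete, here is a minimal shape sketch (my own illustration, not from the paper) of how the three kinds of weights are broadcast onto a feature map X of shape (B, C, H, W):

import torch

x = torch.randn(2, 64, 32, 32)         # feature map of shape (B, C, H, W)

w_channel = torch.rand(2, 64, 1, 1)    # 1D channel attention: one weight per channel (SE-style)
w_spatial = torch.rand(2, 1, 32, 32)   # 2D spatial attention: one weight per position
w_full    = torch.rand(2, 64, 32, 32)  # 3D attention: one weight per neuron (what SimAM produces)

out_c  = x * w_channel                 # every position in a channel shares the same weight
out_s  = x * w_spatial                 # every channel at a position shares the same weight
out_3d = x * w_full                    # each neuron is weighted individually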

Another important design factor in existing attention modules is how the weights are generated. Existing modules typically generate attention weights with an extra sub-network; SE, for example, uses GAP + FC + ReLU + FC + Sigmoid. The operators and parameter counts of more attention modules are summarized in the table below. In short, the structural design of existing attention modules requires a large amount of engineering experimentation. We argue that the implementation of an attention mechanism should instead be guided by unified principles from neuroscience. (A minimal sketch of the SE-style weight generation follows the table.)

[Table: operators and parameter counts of representative attention modules]
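For reference, here is a minimal sketch of the SE-style weight generation mentioned above (GAP + FC + ReLU + FC + Sigmoid). The reduction ratio of 16 is the usual default and an assumption here, not something specified in this article:

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Squeeze-and-Excitation: GAP -> FC -> ReLU -> FC -> Sigmoid, producing one weight per channel."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(x.mean(dim=[2, 3]))   # global average pooling, then the two FC layers
        return x * w.view(b, c, 1, 1)     # broadcast the 1D channel weights over H and W

se = SEBlock(64)
out = se(torch.randn(2, 64, 32, 32))      # same shape as the input

The two FC layers are where the extra parameters counted in the table come from; SimAM, by contrast, adds none.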

Our Attention Module

Prior work such as BAM and CBAM combines spatial attention and channel attention in parallel or in series. In the human brain, however, these two forms of attention typically work together, so we propose an attention module that assigns a single unified weight to each neuron.

To implement attention well, we need to estimate the importance of each neuron. In neuroscience, information-rich neurons usually exhibit firing patterns that differ from those of surrounding neurons; moreover, an active neuron tends to suppress its neighbors, a phenomenon known as spatial suppression. In other words, a neuron that exhibits a spatial-suppression effect should be given higher importance. The simplest way to find such important neurons is to measure the linear separability between a neuron and the others. We therefore define the following energy function for each neuron t:

$$e_t(w_t, b_t, y, x_i) = (y_t - \hat{t})^2 + \frac{1}{M-1}\sum_{i=1}^{M-1}\left(y_o - \hat{x}_i\right)^2$$

where $\hat{t} = w_t t + b_t$ and $\hat{x}_i = w_t x_i + b_t$ are linear transforms of the target neuron $t$ and the other neurons $x_i$ in the same channel of the input feature $X \in \mathbb{R}^{C \times H \times W}$, $i$ indexes the spatial dimension, and $M = H \times W$ is the number of neurons on that channel. Minimizing this function is equivalent to finding the linear separability between neuron $t$ and the other neurons in the same channel. For simplicity, we adopt binary labels ($1$ and $-1$) for $y_t$ and $y_o$ and add a regularizer, which gives the final energy function:

$$e_t(w_t, b_t, y, x_i) = \frac{1}{M-1}\sum_{i=1}^{M-1}\left(-1 - (w_t x_i + b_t)\right)^2 + \left(1 - (w_t t + b_t)\right)^2 + \lambda w_t^2$$

In theory, each channel has $M = H \times W$ such energy functions, and solving them all with an iterative solver would be expensive. Fortunately, the above formula has the following closed-form solution:

$$w_t = -\frac{2(t - \mu_t)}{(t - \mu_t)^2 + 2\sigma_t^2 + 2\lambda}, \qquad b_t = -\frac{1}{2}(t + \mu_t)\, w_t$$

where $\mu_t = \frac{1}{M-1}\sum_{i=1}^{M-1} x_i$ and $\sigma_t^2 = \frac{1}{M-1}\sum_{i=1}^{M-1}(x_i - \mu_t)^2$ are the mean and variance computed over the other neurons on the channel. Since all neurons on a channel can be assumed to follow the same distribution, the mean and variance only need to be computed once per channel (over all $M$ neurons, denoted $\hat{\mu}$ and $\hat{\sigma}^2$) and reused for every neuron. The minimal energy can therefore be obtained as:

$$e_t^{*} = \frac{4(\hat{\sigma}^2 + \lambda)}{(t - \hat{\mu})^2 + 2\hat{\sigma}^2 + 2\lambda}$$

This formula implies that the lower the energy $e_t^{*}$, the more neuron $t$ differs from its surrounding neurons and the more important it is. The importance of each neuron can therefore be obtained as $1/e_t^{*}$.
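As a quick numerical sanity check (my own sketch, not part of the released code), the inverse minimal energy $1/e_t^{*}$ computed from the closed-form expression above matches the quantity y that the Simam_module class in the test code below computes in its forward pass:

import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)                     # toy feature map (B, C, H, W)
lam = 1e-4                                      # the lambda in the energy function
n = x.shape[2] * x.shape[3] - 1                 # M - 1

mu = x.mean(dim=[2, 3], keepdim=True)           # per-channel spatial mean
var = (x - mu).pow(2).sum(dim=[2, 3], keepdim=True) / n   # per-channel variance, as in the code

e_star = 4 * (var + lam) / ((x - mu).pow(2) + 2 * var + 2 * lam)   # minimal energy e_t* per neuron
y_formula = 1.0 / e_star                        # importance = 1 / e_t*

y_code = (x - mu).pow(2) / (4 * (var + lam)) + 0.5   # the expression used in Simam_module.forward

print(torch.allclose(y_formula, y_code))        # True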

So far we have derived the energy function and used it to quantify the importance of each neuron. Following the definition of an attention mechanism, the features are refined by scaling them with these importance values:

$$\tilde{X} = \mathrm{sigmoid}\!\left(\frac{1}{E}\right) \odot X$$

where $E$ groups all $e_t^{*}$ across the channel and spatial dimensions. The sigmoid restricts overly large values and, being monotonic, does not change the relative importance of the neurons.

A PyTorch-style implementation of SimAM takes only a few lines; it is reproduced as the Simam_module class in the test code below.


Experiments

[Figure: experimental results reported in the paper]

With the last channel count changed to 320, the saved model is about 6.21 MB, and a 64×64 input runs in roughly 25 ms on an Honor 9 phone.

Test code:


  
import functools
import os
import time
from typing import Callable, Any, Optional, List

import torch
import torch.nn as nn
from torch import Tensor
# from .utils import load_state_dict_from_url


class Simam_module(torch.nn.Module):
    """SimAM: parameter-free attention; computes y = 1/e_t* per neuron and returns x * sigmoid(y)."""

    def __init__(self, channels=None, e_lambda=1e-4):
        super(Simam_module, self).__init__()
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        # per-channel squared deviation from the spatial mean
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # y = (t - mu)^2 / (4 * (sigma^2 + lambda)) + 0.5, i.e. the inverse of the minimal energy e_t*
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activaton(y)


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNActivation(nn.Sequential):
    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3, stride: int = 1, groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None,
                 attention_module: Optional[Callable[..., nn.Module]] = None) -> None:
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6
        if attention_module is not None:
            if type(attention_module) == functools.partial:
                module_name = attention_module.func.get_module_name()
            else:
                module_name = attention_module.get_module_name()
            if module_name == "simam":
                # SimAM is inserted right after the convolution, before BN and the activation
                super().__init__(
                    nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
                    Simam_module(e_lambda=0.1), norm_layer(out_planes), activation_layer(inplace=True))
            else:
                super().__init__(
                    nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
                    norm_layer(out_planes), activation_layer(inplace=True))
        else:
            super().__init__(
                nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
                norm_layer(out_planes), activation_layer(inplace=True))


# necessary for backwards compatibility
ConvBNReLU = ConvBNActivation


class InvertedResidual(nn.Module):
    def __init__(self, inp: int, oup: int, stride: int, expand_ratio: int,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 attention_module: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup
        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # dw (SimAM, if requested, is inserted inside this block)
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
                       attention_module=attention_module),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        if attention_module is not None:
            if type(attention_module) == functools.partial:
                module_name = attention_module.func.get_module_name()
            else:
                module_name = attention_module.get_module_name()
            if module_name != "simam":
                # parametric attention modules are appended after the block instead
                # print(attention_module)
                layers.append(attention_module(oup))
        self.conv = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, num_classes: int = 1000, width_mult: float = 1.0,
                 inverted_residual_setting: Optional[List[List[int]]] = None, round_nearest: int = 8,
                 attention_module: Optional[Callable[..., nn.Module]] = None) -> None:
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        norm_layer = nn.BatchNorm2d
        input_channel = 32
        last_channel = 320  # reduced from the usual 1280 to shrink the model
        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1], [6, 32, 3, 2], [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 2, 2],
                [6, 320, 1, 1],
            ]
        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))
        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer,
                                      attention_module=attention_module))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        # building classifier
        self.classifier = nn.Sequential(
            # nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


if __name__ == '__main__':
    kwargs = {}
    kwargs["num_classes"] = 6
    # an instance works here because get_module_name() is a staticmethod;
    # a functools.partial(Simam_module, ...) would work just as well
    kwargs["attention_module"] = Simam_module(e_lambda=0.1)
    model = MobileNetV2(**kwargs)
    size = 64
    # model.cuda()
    model.eval()
    model_path = "dicenet.pth"
    torch.save(model.state_dict(), model_path)
    fsize = os.path.getsize(model_path)
    fsize = fsize / float(1024 * 1024)
    print(f"model size {round(fsize, 2)} m")
    input = torch.rand(2, 3, size, size)  # .cuda()
    for i in range(15):
        t1 = time.time()
        loc = model(input)
        cnt = time.time() - t1
        print(cnt, loc.size())
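For completeness, here is a hedged usage sketch of the attention_module hook, assuming the definitions above are in scope. As the code is written, SimAM is inserted after every depthwise convolution, and the e_lambda passed through a functools.partial is ignored because ConvBNActivation hardcodes Simam_module(e_lambda=0.1):

import functools

# MobileNetV2 with SimAM injected into each depthwise block
model_simam = MobileNetV2(num_classes=6,
                          attention_module=functools.partial(Simam_module, e_lambda=0.1))

# plain MobileNetV2 baseline for comparison
model_plain = MobileNetV2(num_classes=6, attention_module=None)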

Source: blog.csdn.net, author: AI视觉网奇. Copyright belongs to the original author; please contact the author for reprint permission.

Original link: blog.csdn.net/jacke121/article/details/119331832
