A Summary of Several Useful CNN Modules (PyTorch)

AI浩 · Published 2021/12/23 01:40:36
[Abstract] This post summarizes several useful CNN modules (SEBlock, ACBlock, eca_layer, ChannelAttention, ConvBN, ResBlock), with PyTorch code for each.

This post collects several useful CNN modules; the PyTorch implementation of each is given below.

  • SEBlock

Code (the imports below are shared by all of the modules in this post):


  
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from collections import OrderedDict


class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: recalibrates channels with globally pooled statistics."""

    def __init__(self, input_channels, internal_neurons):
        super(SEBlock, self).__init__()
        # 1x1 convolutions need no padding ('same' is not a valid padding_mode for nn.Conv2d).
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
                              kernel_size=1, stride=1, bias=True)
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
                            kernel_size=1, stride=1, bias=True)

    def forward(self, inputs):
        # Squeeze: global average pooling over the full spatial extent (handles non-square inputs).
        x = F.avg_pool2d(inputs, kernel_size=(inputs.size(2), inputs.size(3)))
        x = self.down(x)
        x = F.leaky_relu(x)
        x = self.up(x)
        x = torch.sigmoid(x)  # per-channel gates in (0, 1); F.sigmoid is deprecated
        # The [b, c, 1, 1] gate broadcasts over the spatial dimensions.
        return inputs * x
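
A quick smoke test of the block above (the channel counts and spatial size are illustrative, not from the original post; squeezing to a quarter of the channels is a common choice):

se = SEBlock(input_channels=64, internal_neurons=16)
x = torch.randn(2, 64, 32, 32)
y = se(x)
print(y.shape)  # torch.Size([2, 64, 32, 32]) -- same shape, channels reweighted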
  • ACBlock

Code:


  
class CropLayer(nn.Module):
    # E.g., (-1, 0) means this layer crops the first and last rows of the feature map,
    # and (0, -1) crops the first and last columns.
    def __init__(self, crop_set):
        super(CropLayer, self).__init__()
        self.rows_to_crop = -crop_set[0]
        self.cols_to_crop = -crop_set[1]
        assert self.rows_to_crop >= 0
        assert self.cols_to_crop >= 0

    def forward(self, input):
        if self.rows_to_crop == 0 and self.cols_to_crop == 0:
            return input
        elif self.rows_to_crop > 0 and self.cols_to_crop == 0:
            return input[:, :, self.rows_to_crop:-self.rows_to_crop, :]
        elif self.rows_to_crop == 0 and self.cols_to_crop > 0:
            return input[:, :, :, self.cols_to_crop:-self.cols_to_crop]
        else:
            return input[:, :, self.rows_to_crop:-self.rows_to_crop, self.cols_to_crop:-self.cols_to_crop]


class ACBlock(nn.Module):
    """Asymmetric Convolution Block: a square k x k conv plus parallel k x 1 and 1 x k convs,
    each followed by BN. At deploy time the three branches are fused into a single conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=1, groups=1,
                 padding_mode='zeros', deploy=False,
                 use_affine=True, reduce_gamma=False, use_last_bn=False, gamma_init=None):
        # Note: padding_mode must be one of 'zeros', 'reflect', 'replicate', 'circular';
        # 'same' is not accepted by nn.Conv2d.
        super(ACBlock, self).__init__()
        self.deploy = deploy
        if deploy:
            # Inference form: a single conv holding the fused weights of the three branches.
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=(kernel_size, kernel_size), stride=stride,
                                        padding=padding, dilation=dilation, groups=groups, bias=True,
                                        padding_mode=padding_mode)
        else:
            self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
            # If padding < kernel_size // 2 the strip branches would see a larger map than the
            # square branch, so the input is cropped instead of padded.
            center_offset_from_origin_border = padding - kernel_size // 2
            ver_pad_or_crop = (padding, center_offset_from_origin_border)
            hor_pad_or_crop = (center_offset_from_origin_border, padding)
            if center_offset_from_origin_border >= 0:
                self.ver_conv_crop_layer = nn.Identity()
                ver_conv_padding = ver_pad_or_crop
                self.hor_conv_crop_layer = nn.Identity()
                hor_conv_padding = hor_pad_or_crop
            else:
                self.ver_conv_crop_layer = CropLayer(crop_set=ver_pad_or_crop)
                ver_conv_padding = (0, 0)
                self.hor_conv_crop_layer = CropLayer(crop_set=hor_pad_or_crop)
                hor_conv_padding = (0, 0)
            self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(kernel_size, 1), stride=stride,
                                      padding=ver_conv_padding, dilation=dilation, groups=groups,
                                      bias=False, padding_mode=padding_mode)
            self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(1, kernel_size), stride=stride,
                                      padding=hor_conv_padding, dilation=dilation, groups=groups,
                                      bias=False, padding_mode=padding_mode)
            self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
            self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
            if reduce_gamma:
                assert not use_last_bn
                self.init_gamma(1.0 / 3)
            if use_last_bn:
                assert not reduce_gamma
                self.last_bn = nn.BatchNorm2d(num_features=out_channels, affine=True)
            if gamma_init is not None:
                assert not reduce_gamma
                self.init_gamma(gamma_init)

    def init_gamma(self, gamma_value):
        init.constant_(self.square_bn.weight, gamma_value)
        init.constant_(self.ver_bn.weight, gamma_value)
        init.constant_(self.hor_bn.weight, gamma_value)
        print('init gamma of square, ver and hor as ', gamma_value)

    def single_init(self):
        init.constant_(self.square_bn.weight, 1.0)
        init.constant_(self.ver_bn.weight, 0.0)
        init.constant_(self.hor_bn.weight, 0.0)
        print('init gamma of square as 1, ver and hor as 0')

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.square_conv(input)
            square_outputs = self.square_bn(square_outputs)
            vertical_outputs = self.ver_conv_crop_layer(input)
            vertical_outputs = self.ver_conv(vertical_outputs)
            vertical_outputs = self.ver_bn(vertical_outputs)
            horizontal_outputs = self.hor_conv_crop_layer(input)
            horizontal_outputs = self.hor_conv(horizontal_outputs)
            horizontal_outputs = self.hor_bn(horizontal_outputs)
            # Sum of the three BN'd branches; equivalent at inference time to one fused conv.
            result = square_outputs + vertical_outputs + horizontal_outputs
            if hasattr(self, 'last_bn'):
                return self.last_bn(result)
            return result
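
A hedged usage sketch (shapes illustrative). With deploy=False the three branches run in parallel during training; deploy=True assumes the trained branch weights have already been fused into fused_conv by a separate conversion step that this post does not include:

acb = ACBlock(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
x = torch.randn(2, 32, 56, 56)
print(acb(x).shape)  # torch.Size([2, 64, 56, 56]); all three branches produce 56x56 maps

With padding = kernel_size // 2, center_offset_from_origin_border is 0, so the strip branches pad rather than crop and all three branch outputs align for the sum.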
  • eca_layer

Code:


  
class eca_layer(nn.Module):
    """Constructs an ECA module.
    Args:
        channel: number of channels of the input feature map
        k_size: adaptively selected kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 1D conv across channels: local cross-channel interaction without dimensionality reduction.
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        b, c, h, w = x.size()
        # Feature descriptor on the global spatial information
        y = self.avg_pool(x)
        # Squeeze to [b, 1, c], convolve across channels, then restore [b, c, 1, 1]
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Gate in (0, 1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)
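
Like SEBlock, eca_layer returns the input already reweighted, so it drops into a network as a single layer. A quick check (illustrative shapes; note that the channel argument is unused in this implementation):

eca = eca_layer(channel=64, k_size=3)
x = torch.randn(2, 64, 32, 32)
print(eca(x).shape)  # torch.Size([2, 64, 32, 32])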
  • ChannelAttention

Code:


  
class ChannelAttention(nn.Module):
    """CBAM-style channel attention: a shared MLP over average- and max-pooled descriptors."""

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Use the ratio argument rather than a hard-coded 16, so ratio actually takes effect.
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.LeakyReLU(negative_slope=0.01, inplace=False)
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        # Returns the [b, c, 1, 1] attention map; the caller multiplies it with the features.
        return self.sigmoid(out)
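
Unlike SEBlock and eca_layer, this module returns the attention map itself rather than the reweighted features, so the caller does the multiplication (the CBAM convention). An illustrative call:

ca = ChannelAttention(in_planes=64, ratio=16)
x = torch.randn(2, 64, 32, 32)
x = x * ca(x)  # ca(x) is [2, 64, 1, 1] and broadcasts over h and w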
  • ConvBN

Code:


  
class ConvBN(nn.Sequential):
    """Conv2d + BatchNorm + activation, with 'same'-style padding derived from the kernel size."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if not isinstance(kernel_size, int):
            padding = [(i - 1) // 2 for i in kernel_size]
        else:
            padding = (kernel_size - 1) // 2
        super(ConvBN, self).__init__(OrderedDict([
            ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                               padding=padding, groups=groups, bias=False)),
            ('bn', nn.BatchNorm2d(out_planes)),
            # Activation: LeakyReLU in place of the commented-out Mish();
            # the submodule key 'Mish' is kept unchanged.
            ('Mish', nn.LeakyReLU(negative_slope=0.3, inplace=False))
        ]))
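
A quick check of the padding logic (illustrative shapes): an int kernel_size k gets (k - 1) // 2 padding and a tuple gets per-dimension padding, so odd kernels preserve spatial size at stride 1:

layer = ConvBN(in_planes=3, out_planes=32, kernel_size=3)
x = torch.randn(2, 3, 64, 64)
print(layer(x).shape)  # torch.Size([2, 32, 64, 64])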
  • ResBlock

Code:
class ResBlock(nn.Module):
    """
    Sequential residual blocks, each of which consists of two convolution layers.
    Args:
        ch (int): number of input and output channels.
        nblocks (int): number of residual blocks.
        shortcut (bool): if True, residual tensor addition is enabled.
    """
    def __init__(self, ch, nblocks=1, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.module_list = nn.ModuleList()
        for i in range(nblocks):
            resblock_one = nn.ModuleList()
            # 1x1 conv to mix channels, then a 3x3 conv; nn.Mish (PyTorch >= 1.9)
            # stands in for the custom Mish class referenced in the original post.
            resblock_one.append(ConvBN(ch, ch, 1))
            resblock_one.append(nn.Mish())
            resblock_one.append(ConvBN(ch, ch, 3))
            resblock_one.append(nn.Mish())
            self.module_list.append(resblock_one)

    def forward(self, x):
        for module in self.module_list:
            h = x
            for res in module:
                h = res(h)
            x = x + h if self.shortcut else h
        return x
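
ResBlock depends on the ConvBN class above, and the nn.Mish used here requires PyTorch 1.9+. An illustrative stack of two residual blocks (shapes are preserved because both convolutions keep ch channels):

block = ResBlock(ch=64, nblocks=2, shortcut=True)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)  # torch.Size([2, 64, 32, 32])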

Source: wanghao.blog.csdn.net. Author: AI浩. All rights belong to the original author; please contact the author for permission to reprint.

Original link: wanghao.blog.csdn.net/article/details/113103089
