使用nn.Sequential()对象和nn.ModuleList建立模型

举报
AI浩 发表于 2021/12/23 01:10:04 2021/12/23
953 0 0
【摘要】 1、使用nn.Sequential()建立模型的三种方式  import torch as tfrom torch import nn # Sequential的三种写法net1 = nn.Sequential()net1.add_module('conv', nn.Conv2d(3, 3, 3)) # Conv2D(输入...

1、使用nn.Sequential()建立模型的三种方式 


      import torch as t
      from torch import nn
      # Sequential的三种写法
      net1 = nn.Sequential()
      net1.add_module('conv', nn.Conv2d(3, 3, 3))  # Conv2D(输入通道数,输出通道数,卷积核大小)
      net1.add_module('batchnorm', nn.BatchNorm2d(3))  # BatchNorm2d(特征数)
      net1.add_module('activation_layer', nn.ReLU())
      net2 = nn.Sequential(nn.Conv2d(3, 3, 3),
                           nn.BatchNorm2d(3),
                           nn.ReLU()
                           )
      from collections import OrderedDict
      #注意字典的key不能重复
      net3 = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(3, 3, 3)),
          ('bh1', nn.BatchNorm2d(3)),
          ('al', nn.ReLU())
      ]))
      print('net1', net1)
      print('net2', net2)
      print('net3', net3)
      # 可根据名字或序号取出子module
      print(net1.conv, net2[0], net3.conv1)
  
 

输出:


      net1 Sequential(
        (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
        (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation_layer): ReLU()
      )
      net2 Sequential(
        (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      net3 Sequential(
        (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
        (bh1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (al): ReLU()
      )
      Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  
 

2、使用nn.ModuleList建立模型。


      class MyModule(nn.Module):
         def __init__(self):
             super(MyModule, self).__init__()
              self.list = [nn.Linear(3, 4), nn.ReLU()]
              self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])
         def forward(self):
             pass
      model = MyModule()
      print(model)
  
 

输出:


      MyModule(
        (module_list): ModuleList(
          (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
          (1): ReLU()
        )
      )
  
 

3、二者结合构造更复杂的网络模型

例如cisnet中的Decoder模型


      class Decoder(nn.Module):
          num_quan_bits = 4
         def __init__(self, feedback_bits):
             super(Decoder, self).__init__()
              self.feedback_bits = feedback_bits
              self.dequantize = DequantizationLayer(self.num_quan_bits)
              self.multiConvs = nn.ModuleList()
              self.fc = nn.Linear(int(feedback_bits / self.num_quan_bits), 768)
              self.out_cov = conv3x3(2, 2)
              self.sig = nn.Sigmoid()
             for _ in range(3):
                  self.multiConvs.append(nn.Sequential(
                      conv3x3(2, 8),
                      nn.ReLU(),
                      conv3x3(8, 16),
                      nn.ReLU(),
                      conv3x3(16, 2),
                      nn.ReLU()))
         def forward(self, x):
              out = self.dequantize(x)
              out = out.contiguous().view(-1, int(self.feedback_bits / self.num_quan_bits)) #需使用contiguous().view(),或者可修改为reshape
              out = self.sig(self.fc(out))
              out = out.contiguous().view(-1, 2, 24, 16) #需使用contiguous().view(),或者可修改为reshape
             for i in range(3):
                  residual = out
                  out = self.multiConvs[i](out)
                  out = residual + out
              out = self.out_cov(out)
              out = self.sig(out)
              out = out.permute(0, 2, 3, 1)
             return out
  
 

文章来源: wanghao.blog.csdn.net,作者:AI浩,版权归原作者所有,如需转载,请联系作者。

原文链接:wanghao.blog.csdn.net/article/details/114476467

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

作者其他文章

评论(0

抱歉,系统识别当前为高风险访问,暂不支持该操作

    全部回复

    上滑加载中

    设置昵称

    在此一键设置昵称,即可参与社区互动!

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。