使用nn.Sequential()对象和nn.ModuleList建立模型

举报
AI浩 发表于 2021/12/23 01:10:04 2021/12/23
【摘要】 1、使用nn.Sequential()建立模型的三种方式  import torch as tfrom torch import nn # Sequential的三种写法net1 = nn.Sequential()net1.add_module('conv', nn.Conv2d(3, 3, 3)) # Conv2D(输入...

1、使用nn.Sequential()建立模型的三种方式 


  
  1. import torch as t
  2. from torch import nn
  3. # Sequential的三种写法
  4. net1 = nn.Sequential()
  5. net1.add_module('conv', nn.Conv2d(3, 3, 3)) # Conv2D(输入通道数,输出通道数,卷积核大小)
  6. net1.add_module('batchnorm', nn.BatchNorm2d(3)) # BatchNorm2d(特征数)
  7. net1.add_module('activation_layer', nn.ReLU())
  8. net2 = nn.Sequential(nn.Conv2d(3, 3, 3),
  9. nn.BatchNorm2d(3),
  10. nn.ReLU()
  11. )
  12. from collections import OrderedDict
  13. #注意字典的key不能重复
  14. net3 = nn.Sequential(OrderedDict([
  15. ('conv1', nn.Conv2d(3, 3, 3)),
  16. ('bh1', nn.BatchNorm2d(3)),
  17. ('al', nn.ReLU())
  18. ]))
  19. print('net1', net1)
  20. print('net2', net2)
  21. print('net3', net3)
  22. # 可根据名字或序号取出子module
  23. print(net1.conv, net2[0], net3.conv1)

输出:


  
  1. net1 Sequential(
  2. (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  3. (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  4. (activation_layer): ReLU()
  5. )
  6. net2 Sequential(
  7. (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  8. (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  9. (2): ReLU()
  10. )
  11. net3 Sequential(
  12. (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  13. (bh1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  14. (al): ReLU()
  15. )
  16. Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))

2、使用nn.ModuleList建立模型。


  
  1. class MyModule(nn.Module):
  2. def __init__(self):
  3. super(MyModule, self).__init__()
  4. self.list = [nn.Linear(3, 4), nn.ReLU()]
  5. self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])
  6. def forward(self):
  7. pass
  8. model = MyModule()
  9. print(model)

输出:


  
  1. MyModule(
  2. (module_list): ModuleList(
  3. (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  4. (1): ReLU()
  5. )
  6. )

3、二者结合构造更复杂的网络模型

例如cisnet中的Decoder模型


  
  1. class Decoder(nn.Module):
  2. num_quan_bits = 4
  3. def __init__(self, feedback_bits):
  4. super(Decoder, self).__init__()
  5. self.feedback_bits = feedback_bits
  6. self.dequantize = DequantizationLayer(self.num_quan_bits)
  7. self.multiConvs = nn.ModuleList()
  8. self.fc = nn.Linear(int(feedback_bits / self.num_quan_bits), 768)
  9. self.out_cov = conv3x3(2, 2)
  10. self.sig = nn.Sigmoid()
  11. for _ in range(3):
  12. self.multiConvs.append(nn.Sequential(
  13. conv3x3(2, 8),
  14. nn.ReLU(),
  15. conv3x3(8, 16),
  16. nn.ReLU(),
  17. conv3x3(16, 2),
  18. nn.ReLU()))
  19. def forward(self, x):
  20. out = self.dequantize(x)
  21. out = out.contiguous().view(-1, int(self.feedback_bits / self.num_quan_bits)) #需使用contiguous().view(),或者可修改为reshape
  22. out = self.sig(self.fc(out))
  23. out = out.contiguous().view(-1, 2, 24, 16) #需使用contiguous().view(),或者可修改为reshape
  24. for i in range(3):
  25. residual = out
  26. out = self.multiConvs[i](out)
  27. out = residual + out
  28. out = self.out_cov(out)
  29. out = self.sig(out)
  30. out = out.permute(0, 2, 3, 1)
  31. return out

 

文章来源: wanghao.blog.csdn.net,作者:AI浩,版权归原作者所有,如需转载,请联系作者。

原文链接:wanghao.blog.csdn.net/article/details/114476467

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。