FCHarDNet

Posted by 风吹稻花香 on 2021/06/05 22:31:34
[Abstract] A full copy of the FCHarDNet model definition (ptsemseg/models/hardnet.py) from https://github.com/PingoLH/FCHarDNet, with a simple timing test under __main__.

 

https://github.com/PingoLH/FCHarDNet/blob/master/ptsemseg/models/hardnet.py
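
A quick orientation before the full listing: in a HarDBlock, layer k takes its input only from layers k - 2**i for every i with k % 2**i == 0, and its output width starts at the growth rate and is multiplied by grmul once for each such i > 0, then rounded to an even number. The snippet below is a standalone sketch of my own (it is not part of hardnet.py) that reproduces the get_link logic so you can print the connection pattern of the first 4-layer encoder block (input width 48, gr=10, grmul=1.7):

# Standalone sketch: reproduce HarDBlock.get_link to inspect the harmonic wiring.
def get_link(layer, base_ch, growth_rate, grmul):
    if layer == 0:
        return base_ch, 0, []
    out_channels = growth_rate
    link = []
    for i in range(10):
        dv = 2 ** i
        if layer % dv == 0:
            link.append(layer - dv)
            if i > 0:
                out_channels *= grmul
    out_channels = int(int(out_channels + 1) / 2) * 2   # round to an even width
    in_channels = sum(get_link(k, base_ch, growth_rate, grmul)[0] for k in link)
    return out_channels, in_channels, link

# First encoder stage of the model below: 4 layers on a 48-channel input, gr=10, grmul=1.7.
for layer in range(1, 5):
    out_ch, in_ch, link = get_link(layer, 48, 10, 1.7)
    print(f"layer {layer}: reads layers {sorted(link)}, in_ch={in_ch}, out_ch={out_ch}")
# Expected, e.g.: layer 4 reads layers [0, 2, 3] with in_ch=76 and out_ch=28.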


  
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import collections


class ConvLayer(nn.Sequential):
    """Conv2d + BatchNorm + ReLU."""

    def __init__(self, in_channels, out_channels, kernel=3, stride=1, dropout=0.1):
        super().__init__()
        self.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size=kernel,
                                          stride=stride, padding=kernel // 2, bias=False))
        self.add_module('norm', nn.BatchNorm2d(out_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        # print(kernel, 'x', kernel, 'x', in_channels, 'x', out_channels)

    def forward(self, x):
        return super().forward(x)


class BRLayer(nn.Sequential):
    """BatchNorm + ReLU, applied by HarDBlock_v2 after its slice additions."""

    def __init__(self, in_channels):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(True))

    def forward(self, x):
        return super().forward(x)


class HarDBlock_v2(nn.Module):
    """Inference-oriented HarDBlock: each conv also produces the slices that later
    linked layers will consume, so concatenation is replaced by slice addition."""

    def get_link(self, layer, base_ch, growth_rate, grmul):
        if layer == 0:
            return base_ch, 0, []
        out_channels = growth_rate
        link = []
        for i in range(10):
            dv = 2 ** i
            if layer % dv == 0:
                k = layer - dv
                link.insert(0, k)
                if i > 0:
                    out_channels *= grmul
        out_channels = int(int(out_channels + 1) / 2) * 2
        in_channels = 0
        for i in link:
            ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul)
            in_channels += ch
        return out_channels, in_channels, link

    def get_out_ch(self):
        return self.out_channels

    def __init__(self, in_channels, growth_rate, grmul, n_layers, dwconv=False):
        super().__init__()
        self.links = []
        conv_layers_ = []
        bnrelu_layers_ = []
        self.layer_bias = []
        self.out_channels = 0
        self.out_partition = collections.defaultdict(list)

        for i in range(n_layers):
            outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, grmul)
            self.links.append(link)
            for j in link:
                self.out_partition[j].append(outch)

        cur_ch = in_channels
        for i in range(n_layers):
            accum_out_ch = sum(self.out_partition[i])
            real_out_ch = self.out_partition[i][0]
            # print( self.links[i],  self.out_partition[i], accum_out_ch)
            conv_layers_.append(nn.Conv2d(cur_ch, accum_out_ch, kernel_size=3, stride=1, padding=1, bias=True))
            bnrelu_layers_.append(BRLayer(real_out_ch))
            cur_ch = real_out_ch
            if (i % 2 == 0) or (i == n_layers - 1):
                self.out_channels += real_out_ch
        # print("Blk out =", self.out_channels)
        self.conv_layers = nn.ModuleList(conv_layers_)
        self.bnrelu_layers = nn.ModuleList(bnrelu_layers_)

    def transform(self, blk, trt=False):
        # Transform weight matrix from a pretrained HarDBlock v1
        in_ch = blk.layers[0][0].weight.shape[1]
        for i in range(len(self.conv_layers)):
            link = self.links[i].copy()
            link_ch = [blk.layers[k - 1][0].weight.shape[0] if k > 0 else
                       blk.layers[0][0].weight.shape[1] for k in link]
            part = self.out_partition[i]
            w_src = blk.layers[i][0].weight
            b_src = blk.layers[i][0].bias

            self.conv_layers[i].weight[0:part[0], :, :, :] = w_src[:, 0:in_ch, :, :]
            self.layer_bias.append(b_src)

            if b_src is not None:
                if trt:
                    self.conv_layers[i].bias[1:part[0]] = b_src[1:]
                    self.conv_layers[i].bias[0] = b_src[0]
                    self.conv_layers[i].bias[part[0]:] = 0
                    self.layer_bias[i] = None
                else:
                    # for pytorch, adding the bias as a standalone tensor is more efficient than via conv.bias
                    # this is because the amount of non-zero bias is small,
                    # but if we use conv.bias, the number of bias values will be much larger
                    self.conv_layers[i].bias = None
            else:
                self.conv_layers[i].bias = None

            in_ch = part[0]
            link_ch.reverse()
            link.reverse()

            if len(link) > 1:
                for j in range(1, len(link)):
                    ly = link[j]
                    part_id = self.out_partition[ly].index(part[0])
                    chos = sum(self.out_partition[ly][0:part_id])
                    choe = chos + part[0]
                    chis = sum(link_ch[0:j])
                    chie = chis + link_ch[j]
                    self.conv_layers[ly].weight[chos:choe, :, :, :] = w_src[:, chis:chie, :, :]

            # update BatchNorm or remove it if there is no BatchNorm in the v1 block
            self.bnrelu_layers[i] = None
            if isinstance(blk.layers[i][1], nn.BatchNorm2d):
                self.bnrelu_layers[i] = nn.Sequential(
                    blk.layers[i][1],
                    blk.layers[i][2])
            else:
                self.bnrelu_layers[i] = blk.layers[i][1]

    def forward(self, x):
        layers_ = []
        outs_ = []
        xin = x
        for i in range(len(self.conv_layers)):
            link = self.links[i]
            part = self.out_partition[i]

            xout = self.conv_layers[i](xin)
            layers_.append(xout)

            xin = xout[:, 0:part[0], :, :] if len(part) > 1 else xout
            if self.layer_bias[i] is not None:
                xin += self.layer_bias[i].view(1, -1, 1, 1)

            if len(link) > 1:
                for j in range(len(link) - 1):
                    ly = link[j]
                    part_id = self.out_partition[ly].index(part[0])
                    chs = sum(self.out_partition[ly][0:part_id])
                    che = chs + part[0]
                    xin += layers_[ly][:, chs:che, :, :]

            xin = self.bnrelu_layers[i](xin)

            if i % 2 == 0 or i == len(self.conv_layers) - 1:
                outs_.append(xin)

        out = torch.cat(outs_, 1)
        return out


class HarDBlock(nn.Module):
    """Training-time HarDBlock: layer k reads layers k - 2**i for every i with k % 2**i == 0."""

    def get_link(self, layer, base_ch, growth_rate, grmul):
        if layer == 0:
            return base_ch, 0, []
        out_channels = growth_rate
        link = []
        for i in range(10):
            dv = 2 ** i
            if layer % dv == 0:
                k = layer - dv
                link.append(k)
                if i > 0:
                    out_channels *= grmul
        out_channels = int(int(out_channels + 1) / 2) * 2
        in_channels = 0
        for i in link:
            ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul)
            in_channels += ch
        return out_channels, in_channels, link

    def get_out_ch(self):
        return self.out_channels

    def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, residual_out=False):
        super().__init__()
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.grmul = grmul
        self.n_layers = n_layers
        self.keepBase = keepBase
        self.links = []
        layers_ = []
        self.out_channels = 0  # if upsample else in_channels
        for i in range(n_layers):
            outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, grmul)
            self.links.append(link)
            use_relu = residual_out
            layers_.append(ConvLayer(inch, outch))
            if (i % 2 == 0) or (i == n_layers - 1):
                self.out_channels += outch
        # print("Blk out =", self.out_channels)
        self.layers = nn.ModuleList(layers_)

    def forward(self, x):
        layers_ = [x]
        for layer in range(len(self.layers)):
            link = self.links[layer]
            tin = []
            for i in link:
                tin.append(layers_[i])
            if len(tin) > 1:
                x = torch.cat(tin, 1)
            else:
                x = tin[0]
            out = self.layers[layer](x)
            layers_.append(out)

        t = len(layers_)
        out_ = []
        for i in range(t):
            if (i == 0 and self.keepBase) or \
                    (i == t - 1) or (i % 2 == 1):
                out_.append(layers_[i])
        out = torch.cat(out_, 1)
        return out


class TransitionUp(nn.Module):
    """Bilinearly upsample x to the skip's spatial size and (optionally) concatenate the skip."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # print("upsample", in_channels, out_channels)

    def forward(self, x, skip, concat=True):
        out = F.interpolate(
            x,
            size=(skip.size(2), skip.size(3)),
            mode="bilinear",
            align_corners=True,
        )
        if concat:
            out = torch.cat([out, skip], 1)
        return out


class Hardnet(nn.Module):
    """FCHarDNet segmentation network: HarDNet encoder with shortcut stages and a light decoder."""

    def __init__(self, n_classes=19):
        super(Hardnet, self).__init__()

        first_ch = [16, 24, 32, 48]
        ch_list = [64, 96, 160, 224, 320]
        grmul = 1.7
        gr = [10, 16, 18, 24, 32]
        n_layers = [4, 4, 8, 8, 8]

        blks = len(n_layers)
        self.shortcut_layers = []

        # Stem: two stride-2 convs bring the input down to 1/4 resolution.
        self.base = nn.ModuleList([])
        self.base.append(
            ConvLayer(in_channels=3, out_channels=first_ch[0], kernel=3, stride=2))
        self.base.append(ConvLayer(first_ch[0], first_ch[1], kernel=3))
        self.base.append(ConvLayer(first_ch[1], first_ch[2], kernel=3, stride=2))
        self.base.append(ConvLayer(first_ch[2], first_ch[3], kernel=3))

        # Encoder: HarDBlock + 1x1 conv (+ AvgPool except after the last stage).
        skip_connection_channel_counts = []
        ch = first_ch[3]
        for i in range(blks):
            blk = HarDBlock(ch, gr[i], grmul, n_layers[i])
            ch = blk.get_out_ch()
            skip_connection_channel_counts.append(ch)
            self.base.append(blk)
            if i < blks - 1:
                self.shortcut_layers.append(len(self.base) - 1)

            self.base.append(ConvLayer(ch, ch_list[i], kernel=1))
            ch = ch_list[i]

            if i < blks - 1:
                self.base.append(nn.AvgPool2d(kernel_size=2, stride=2))

        cur_channels_count = ch
        prev_block_channels = ch
        n_blocks = blks - 1
        self.n_blocks = n_blocks

        #######################
        #   Upsampling path   #
        #######################
        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        self.conv1x1_up = nn.ModuleList([])

        for i in range(n_blocks - 1, -1, -1):
            self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels))
            cur_channels_count = prev_block_channels + skip_connection_channel_counts[i]
            self.conv1x1_up.append(ConvLayer(cur_channels_count, cur_channels_count // 2, kernel=1))
            cur_channels_count = cur_channels_count // 2

            blk = HarDBlock(cur_channels_count, gr[i], grmul, n_layers[i])
            self.denseBlocksUp.append(blk)

            prev_block_channels = blk.get_out_ch()
            cur_channels_count = prev_block_channels

        self.finalConv = nn.Conv2d(in_channels=cur_channels_count,
                                   out_channels=n_classes, kernel_size=1, stride=1,
                                   padding=0, bias=True)

    def v2_transform(self, trt=False):
        # Replace every HarDBlock (encoder and decoder) with a HarDBlock_v2 carrying the same weights.
        for i in range(len(self.base)):
            if isinstance(self.base[i], HarDBlock):
                blk = self.base[i]
                self.base[i] = HarDBlock_v2(blk.in_channels, blk.growth_rate, blk.grmul, blk.n_layers)
                self.base[i].transform(blk, trt)

        for i in range(self.n_blocks):
            blk = self.denseBlocksUp[i]
            self.denseBlocksUp[i] = HarDBlock_v2(blk.in_channels, blk.growth_rate, blk.grmul, blk.n_layers)
            self.denseBlocksUp[i].transform(blk, trt)

    def forward(self, x):
        skip_connections = []
        size_in = x.size()

        for i in range(len(self.base)):
            x = self.base[i](x)
            if i in self.shortcut_layers:
                skip_connections.append(x)
            print('x', i, x.size())  # debug: per-stage encoder feature-map size

        out = x
        for i in range(self.n_blocks):
            skip = skip_connections.pop()
            out = self.transUpBlocks[i](out, skip, True)
            out = self.conv1x1_up[i](out)
            out = self.denseBlocksUp[i](out)
            print(i, out.size())  # debug: per-stage decoder feature-map size

        out = self.finalConv(out)
        # out = F.interpolate(
        #     out,
        #     size=(size_in[2], size_in[3]),
        #     mode="bilinear",
        #     align_corners=True)
        return out


if __name__ == '__main__':
    model = Hardnet().cuda()
    torch.save(model.state_dict(), 'v2.pth')

    inputs = torch.randn(1, 3, 640, 640).cuda()

    # Leftover ONNX export snippet from another script (pelee_net, output_onnx, input_names
    # and output_names are not defined here):
    # torch_out = torch.onnx._export(pelee_net, inputs, output_onnx, export_params=True, verbose=False,
    #                                input_names=input_names, output_names=output_names, opset_version=11)
    # print("==> Exporting model to ONNX format at '{}'".format(output_onnx))

    for i in range(5):
        start = time.time()
        output = model(inputs)
        print('output.size ', time.time() - start, output.size())
        # print('output.size ', time.time()-start, output[0].size(), output[1].size(), output[2].size())
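
One thing the listing itself never exercises is v2_transform, which swaps every HarDBlock in the encoder and decoder for a HarDBlock_v2 carrying the same weights, so that linked layers are combined by slice addition instead of concatenation at inference time. Below is a minimal usage sketch of my own (not from the original post); since transform() copies weights into the new blocks' Parameters in place, the conversion is wrapped in torch.no_grad() here:

import torch

# Assumes the Hardnet class defined above is available in the current scope.
model = Hardnet(n_classes=19)
model.eval()

with torch.no_grad():                  # transform() writes into Parameters in place
    model.v2_transform()               # replace each HarDBlock with an equivalent HarDBlock_v2
    dummy = torch.randn(1, 3, 640, 640)
    logits = model(dummy)              # the debug prints in forward() show per-stage sizes

# With the final F.interpolate left commented out in forward(), the logits stay at 1/4 of the
# input resolution, i.e. they should be (1, 19, 160, 160) for a 640x640 input.
print(logits.shape)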

 

Source: blog.csdn.net. Author: 网奇. Copyright belongs to the original author; please contact the author before reprinting.

Original post: blog.csdn.net/jacke121/article/details/106036437
