mxnet可视化模型中间层feature map输出

举报
风吹稻花香 发表于 2021/06/04 23:10:12 2021/06/04
【摘要】   注: model输入 112x112保存的图片可能是白色的(这个还没有修复),但是在pycharm中运行时可以通过scientific tool窗口看到 #构造辅助函数做预处理, 注意mxnet中为通道在前格式即BCHW, 输入时要对通道维度调整,#其预训练模型采用减均值除方差的标准化预处理(均值标准差使用imagenet数据集的[0.485, 0.456,...

 

注:

  1. model输入 112x112
  2. 保存的图片可能是白色的(这个还没有修复),但是在pycharm中运行时可以通过scientific tool窗口看到

  
  1. #构造辅助函数做预处理, 注意mxnet中为通道在前格式即BCHW, 输入时要对通道维度调整,
  2. #其预训练模型采用减均值除方差的标准化预处理(均值标准差使用imagenet数据集的[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  3. #mxnet使用专有数据类型nd.array
  4. import cv2
  5. from mxnet import nd
  6. from mxnet import gluon
  7. import mxnet.gluon.nn as nn
  8. import numpy as np
  9. import mxnet as mx
  10. from collections import namedtuple
  11. import matplotlib.pyplot as plt
  12. def preprocess_img(img_path,data_shape,ctx):
  13. # 读取图片
  14. img = cv2.imread(img_path)
  15. # 将图片缩放到与bind中shape的width和height一致
  16. img = cv2.resize(img, (data_shape[2], data_shape[3]))
  17. # 将图片由BGR转为RGB
  18. img = img[:, :, ::-1]
  19. # 将numpy array转为ndarray
  20. nd_img = mx.nd.array(img,ctx=ctx).transpose((2, 0, 1))
  21. # 将图片的格式转为NCHW
  22. nd_img = mx.nd.expand_dims(nd_img, axis=0)
  23. return nd_img
  24. def get_mod(model_str,ctx,data_shape):
  25. _vec = model_str.split(",")
  26. prefix = _vec[0]
  27. epoch = int(_vec[1])
  28. sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
  29. mod = mx.mod.Module(symbol=sym,context=ctx,label_names=None)
  30. #注意修改data_shapes,ImageNet使用的shape是(224,224,3)
  31. mod.bind(for_training=False,data_shapes=[("data",data_shape)],
  32. label_shapes=mod._label_shapes)
  33. #加载网络的参数
  34. mod.set_params(arg_params,aux_params,allow_missing=True)
  35. internals = mod.symbol.get_internals() # list all symbol
  36. outputs = []
  37. for i in internals.list_outputs():
  38. if str(i).endswith('_output'):
  39. print(i)
  40. outputs.append(i)
  41. # print(internals.list_outputs())
  42. return sym, mod,arg_params,aux_params, outputs
  43. def predict(model_str,ctx,data_shape,img_path,label_path):
  44. #通过标签文件获取标签名称
  45. Batch = namedtuple("Batch",["data"])
  46. sym, mod,arg_params,aux_params, outputs = get_mod(model_str,ctx,data_shape)
  47. #获取预测的图片
  48. nd_img = preprocess_img(img_path,data_shape,ctx)
  49. #计算网络的预测值
  50. mod.forward(Batch([nd_img]))
  51. prob = mod.get_outputs()[0].asnumpy()
  52. print(prob.shape)
  53. # return prob
  54. args = sym.get_internals().list_outputs() # 获得所有中间输出
  55. internals = mod.symbol.get_internals()
  56. fc1 = internals['fc1_output']
  57. for i,output in enumerate(outputs):
  58. # conv = internals['stage4_unit1_bn1_output']
  59. conv = internals[output]
  60. # group = mx.symbol.Group([fc1, sym, conv]) # 把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
  61. group = mx.symbol.Group([conv]) # 把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
  62. #########################################################################
  63. mod = mx.mod.Module(symbol=group, context=mx.gpu()) # 创建Module
  64. mod.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))]) # 绑定,此代码为预测代码,所以training参数设为False
  65. mod.set_params(arg_params, aux_params)
  66. mod.forward(Batch([nd_img])) # 预测结果
  67. prob = mod.get_outputs()[0].asnumpy()
  68. # y = np.argsort(np.squeeze(prob))[::-1]
  69. print(prob.shape)
  70. def tensor_to_image(tensor):
  71. # tensor = tensor[0].asnumpy()
  72. tensor -= tensor.mean()
  73. tensor /= tensor.std()
  74. tensor *= 64
  75. tensor += 128
  76. img = np.clip(tensor, 0, 255).astype('uint8')
  77. return img
  78. prob = prob[0]
  79. src_img = cv2.resize(input_image, (112,112))
  80. src_img = src_img[:,:,::-1]
  81. # b, g, r = cv2.split(src_img)
  82. # src_img = cv2.merge([r, g, b])
  83. print(src_img.shape)
  84. # 可视化原图
  85. plt.imshow(src_img)
  86. plt.show()
  87. init_img = np.zeros((112,112))
  88. # init_img = np.zeros((7, 7))
  89. for img in prob:
  90. img = cv2.resize(img, (112,112))
  91. init_img += 0.1*tensor_to_image(img)
  92. # 可视化feature map
  93. plt.imshow(init_img, cmap='viridis')
  94. plt.show()
  95. src_img = src_img.transpose(2, 0, 1)
  96. for img in src_img:
  97. init_img += img
  98. plt.savefig('faces/feature_maps/%s_%s.jpg' % (i, output))
  99. # 可视化原图叠加feature map
  100. plt.imshow(init_img, cmap='viridis')
  101. plt.show()
  102. if __name__ == '__main__':
  103. model_str = '/data/user1/model,5'
  104. ctx = mx.cpu()
  105. data_shape = (1, 3, 112, 112)
  106. img_path = 'faces/2.jpg'
  107. input_image = cv2.imread(img_path)
  108. predict(model_str,ctx,data_shape,img_path,'')

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/116739633

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。