mxnet可视化模型中间层feature map输出
        【摘要】    
注: 
model输入 112x112保存的图片可能是白色的(这个还没有修复),但是在pycharm中运行时可以通过scientific tool窗口看到
#构造辅助函数做预处理, 注意mxnet中为通道在前格式即BCHW, 输入时要对通道维度调整,#其预训练模型采用减均值除方差的标准化预处理(均值标准差使用imagenet数据集的[0.485, 0.456,...
    
    
    
    
注:
- model输入 112x112
- 保存的图片可能是白色的(这个还没有修复),但是在pycharm中运行时可以通过scientific tool窗口看到
  
   - 
    
     
    
    
     
      #构造辅助函数做预处理, 注意mxnet中为通道在前格式即BCHW, 输入时要对通道维度调整,
     
    
- 
    
     
    
    
     
      #其预训练模型采用减均值除方差的标准化预处理(均值标准差使用imagenet数据集的[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     
    
- 
    
     
    
    
     
      #mxnet使用专有数据类型nd.array
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      import cv2
     
    
- 
    
     
    
    
     
      from mxnet import nd
     
    
- 
    
     
    
    
     
      from mxnet import gluon
     
    
- 
    
     
    
    
     
      import mxnet.gluon.nn as nn
     
    
- 
    
     
    
    
     
      import numpy as np
     
    
- 
    
     
    
    
     
      import mxnet as mx
     
    
- 
    
     
    
    
     
      from collections import namedtuple
     
    
- 
    
     
    
    
     
      import matplotlib.pyplot as plt
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      def preprocess_img(img_path,data_shape,ctx):
     
    
- 
    
     
    
    
      # 读取图片
     
    
- 
    
     
    
    
     
       img = cv2.imread(img_path)
     
    
- 
    
     
    
    
      # 将图片缩放到与bind中shape的width和height一致
     
    
- 
    
     
    
    
     
       img = cv2.resize(img, (data_shape[2], data_shape[3]))
     
    
- 
    
     
    
    
      # 将图片由BGR转为RGB
     
    
- 
    
     
    
    
     
       img = img[:, :, ::-1]
     
    
- 
    
     
    
    
      # 将numpy array转为ndarray
     
    
- 
    
     
    
    
     
       nd_img = mx.nd.array(img,ctx=ctx).transpose((2, 0, 1))
     
    
- 
    
     
    
    
      # 将图片的格式转为NCHW
     
    
- 
    
     
    
    
     
       nd_img = mx.nd.expand_dims(nd_img, axis=0)
     
    
- 
    
     
    
    
      return nd_img
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      def get_mod(model_str,ctx,data_shape):
     
    
- 
    
     
    
    
     
       _vec = model_str.split(",")
     
    
- 
    
     
    
    
     
       prefix = _vec[0]
     
    
- 
    
     
    
    
     
       epoch = int(_vec[1])
     
    
- 
    
     
    
    
     
       sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
     
    
- 
    
     
    
    
     
       mod = mx.mod.Module(symbol=sym,context=ctx,label_names=None)
     
    
- 
    
     
    
    
      #注意修改data_shapes,ImageNet使用的shape是(224,224,3)
     
    
- 
    
     
    
    
     
       mod.bind(for_training=False,data_shapes=[("data",data_shape)],
     
    
- 
    
     
    
    
     
       label_shapes=mod._label_shapes)
     
    
- 
    
     
    
    
      #加载网络的参数
     
    
- 
    
     
    
    
     
       mod.set_params(arg_params,aux_params,allow_missing=True)
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       internals = mod.symbol.get_internals()  # list all symbol
     
    
- 
    
     
    
    
     
       outputs = []
     
    
- 
    
     
    
    
      for i in internals.list_outputs():
     
    
- 
    
     
    
    
      if str(i).endswith('_output'):
     
    
- 
    
     
    
    
     
       print(i)
     
    
- 
    
     
    
    
     
       outputs.append(i)
     
    
- 
    
     
    
    
      # print(internals.list_outputs())
     
    
- 
    
     
    
    
      return sym, mod,arg_params,aux_params, outputs
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      def predict(model_str,ctx,data_shape,img_path,label_path):
     
    
- 
    
     
    
    
      #通过标签文件获取标签名称
     
    
- 
    
     
    
    
     
       Batch = namedtuple("Batch",["data"])
     
    
- 
    
     
    
    
     
       sym, mod,arg_params,aux_params, outputs = get_mod(model_str,ctx,data_shape)
     
    
- 
    
     
    
    
      #获取预测的图片
     
    
- 
    
     
    
    
     
       nd_img = preprocess_img(img_path,data_shape,ctx)
     
    
- 
    
     
    
    
      #计算网络的预测值
     
    
- 
    
     
    
    
     
       mod.forward(Batch([nd_img]))
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       prob = mod.get_outputs()[0].asnumpy()
     
    
- 
    
     
    
    
     
       print(prob.shape)
     
    
- 
    
     
    
    
      # return prob
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       args = sym.get_internals().list_outputs()  # 获得所有中间输出
     
    
- 
    
     
    
    
     
       internals = mod.symbol.get_internals()
     
    
- 
    
     
    
    
     
       fc1 = internals['fc1_output']
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
      for i,output in enumerate(outputs):
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
      # conv = internals['stage4_unit1_bn1_output']
     
    
- 
    
     
    
    
     
       conv = internals[output]
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
      # group = mx.symbol.Group([fc1, sym, conv]) # 把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       group = mx.symbol.Group([conv])  # 把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
     
    
- 
    
     
    
    
      #########################################################################
     
    
- 
    
     
    
    
     
       mod = mx.mod.Module(symbol=group, context=mx.gpu())  # 创建Module
     
    
- 
    
     
    
    
     
       mod.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])  # 绑定,此代码为预测代码,所以training参数设为False
     
    
- 
    
     
    
    
     
       mod.set_params(arg_params, aux_params)
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       mod.forward(Batch([nd_img]))  # 预测结果
     
    
- 
    
     
    
    
     
       prob = mod.get_outputs()[0].asnumpy()
     
    
- 
    
     
    
    
      # y = np.argsort(np.squeeze(prob))[::-1]
     
    
- 
    
     
    
    
     
       print(prob.shape)
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
      def tensor_to_image(tensor):
     
    
- 
    
     
    
    
      # tensor = tensor[0].asnumpy()
     
    
- 
    
     
    
    
     
       tensor -= tensor.mean()
     
    
- 
    
     
    
    
     
       tensor /= tensor.std()
     
    
- 
    
     
    
    
     
       tensor *= 64
     
    
- 
    
     
    
    
     
       tensor += 128
     
    
- 
    
     
    
    
     
       img = np.clip(tensor, 0, 255).astype('uint8')
     
    
- 
    
     
    
    
      return img
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       prob = prob[0]
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       src_img = cv2.resize(input_image, (112,112))
     
    
- 
    
     
    
    
     
       src_img = src_img[:,:,::-1]
     
    
- 
    
     
    
    
      # b, g, r = cv2.split(src_img)
     
    
- 
    
     
    
    
      # src_img = cv2.merge([r, g, b])
     
    
- 
    
     
    
    
     
       print(src_img.shape)
     
    
- 
    
     
    
    
     		
     
    
- 
    
     
    
    
     		# 可视化原图
     
    
- 
    
     
    
    
     
       plt.imshow(src_img)
     
    
- 
    
     
    
    
     
       plt.show()
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       init_img = np.zeros((112,112))
     
    
- 
    
     
    
    
      # init_img = np.zeros((7, 7))
     
    
- 
    
     
    
    
      for img in prob:
     
    
- 
    
     
    
    
     
       img = cv2.resize(img, (112,112))
     
    
- 
    
     
    
    
     
       init_img += 0.1*tensor_to_image(img)
     
    
- 
    
     
    
    
     		
     
    
- 
    
     
    
    
     		# 可视化feature map
     
    
- 
    
     
    
    
     
       plt.imshow(init_img, cmap='viridis')
     
    
- 
    
     
    
    
     
       plt.show()
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       src_img = src_img.transpose(2, 0, 1)
     
    
- 
    
     
    
    
      for img in src_img:
     
    
- 
    
     
    
    
     
       init_img += img
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       plt.savefig('faces/feature_maps/%s_%s.jpg' % (i, output))
     
    
- 
    
     
    
    
     		
     
    
- 
    
     
    
    
     		# 可视化原图叠加feature map
     
    
- 
    
     
    
    
     
       plt.imshow(init_img, cmap='viridis')
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
       plt.show()
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      if __name__ == '__main__':
     
    
- 
    
     
    
    
     
       model_str = '/data/user1/model,5'
     
    
- 
    
     
    
    
     
       ctx = mx.cpu()
     
    
- 
    
     
    
    
     
       data_shape = (1, 3, 112, 112)
     
    
- 
    
     
    
    
     
       img_path = 'faces/2.jpg'
     
    
- 
    
     
    
    
     
       input_image = cv2.imread(img_path)
     
    
- 
    
     
    
    
     
       predict(model_str,ctx,data_shape,img_path,'')
     
    
 文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/116739633
        【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
            cloudbbs@huaweicloud.com
        
        
        
        
        
        
        - 点赞
- 收藏
- 关注作者
 
             
           
评论(0)