mxnet测试网络速度

举报
风吹稻花香 发表于 2021/06/05 00:13:36 2021/06/05
【摘要】   # -*- coding: utf-8 -*-"""File Name: calculate_flops.pyAuthor: liangdepengmail: liangdepeng@gmail.com"""import time import cv2import mxnet as mximport argparseimport numpy as npimpo...

 


      # -*- coding: utf-8 -*-
      """
      File Name: calculate_flops.py
      Author: liangdepeng
      mail: liangdepeng@gmail.com
      """
      import time
      import cv2
      import mxnet as mx
      import argparse
      import numpy as np
      import json
      import re
      def parse_args():
       parser = argparse.ArgumentParser(description='')
       parser.add_argument('-ds', '--data_shapes', default=["data,1,3,112,112"], type=str, nargs='+',
       help='data_shapes, format: arg_name,s1,s2,...,sn, example: data,1,3,224,224')
       parser.add_argument('-ls', '--label_shapes', default=["label,1,512"], type=str, nargs='+',
       help='label_shapes, format: arg_name,s1,s2,...,sn, example: label,1,1,224,224')
       parser.add_argument('-s', '--symbol_path', type=str, default=r'softmax_label-symbol.json', help='')
      #
      return parser.parse_args()
      def single_input(path):
       img = cv2.imread(path)
      # mxnet三通道输入是严格的RGB格式,而cv2.imread的默认是BGR格式,因此需要做一个转换
       img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
       img = cv2.resize(img, (112, 112))
       img = img.transpose(2, 0, 1)
      # 添加一个第四维度并构建NDArray
       img = img[np.newaxis, :]
       array = mx.nd.array(img)
      return array
      if __name__ == '__main__':
      from collections import namedtuple
       args = parse_args()
       sym = mx.sym.load(args.symbol_path)
       data_shapes = list()
       data_names = list()
      if args.data_shapes is not None and len(args.data_shapes) > 0:
      for shape in args.data_shapes:
       items = shape.replace('\'', '').replace('"', '').split(',')
       data_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
       data_names.append(items[0])
       label_shapes = None
       label_names = list()
      if args.label_shapes is not None and len(args.label_shapes) > 0:
       label_shapes = list()
      for shape in args.label_shapes:
       items = shape.replace('\'', '').replace('"', '').split(',')
       label_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
       label_names.append(items[0])
       devs = [mx.cpu()]
      if len(label_names) == 0:
       label_names = None
       model = mx.mod.Module(context=devs, symbol=sym, data_names=data_names, label_names=None)
       model.bind(data_shapes=data_shapes,  for_training=False)
       model._params_dirty = True
       model.params_initialized = True
       model.save_checkpoint("0412", 1)
       time1 = time.time()
      # print("模型加载和重建时间:{0}".format(time1 - time0))
       Batch = namedtuple("batch", ['data'])
       img1_path = r'1054086.jpg'
       array1 = single_input(img1_path)
       start1 = time.time()
      for i in range(40):
       start = time.time()
       model.forward(Batch([array1]))
       vector1 = model.get_outputs()[0].asnumpy()
       print("time", time.time() - start, vector1.shape)
       print("total time", time.time() - start1)
  
 

 

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/115648268

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。