mxnet测试网络速度

举报
风吹稻花香 发表于 2021/06/05 00:13:36 2021/06/05
【摘要】   # -*- coding: utf-8 -*-"""File Name: calculate_flops.pyAuthor: liangdepengmail: liangdepeng@gmail.com"""import time import cv2import mxnet as mximport argparseimport numpy as npimpo...

 


  
  1. # -*- coding: utf-8 -*-
  2. """
  3. File Name: calculate_flops.py
  4. Author: liangdepeng
  5. mail: liangdepeng@gmail.com
  6. """
  7. import time
  8. import cv2
  9. import mxnet as mx
  10. import argparse
  11. import numpy as np
  12. import json
  13. import re
  14. def parse_args():
  15. parser = argparse.ArgumentParser(description='')
  16. parser.add_argument('-ds', '--data_shapes', default=["data,1,3,112,112"], type=str, nargs='+',
  17. help='data_shapes, format: arg_name,s1,s2,...,sn, example: data,1,3,224,224')
  18. parser.add_argument('-ls', '--label_shapes', default=["label,1,512"], type=str, nargs='+',
  19. help='label_shapes, format: arg_name,s1,s2,...,sn, example: label,1,1,224,224')
  20. parser.add_argument('-s', '--symbol_path', type=str, default=r'softmax_label-symbol.json', help='')
  21. #
  22. return parser.parse_args()
  23. def single_input(path):
  24. img = cv2.imread(path)
  25. # mxnet三通道输入是严格的RGB格式,而cv2.imread的默认是BGR格式,因此需要做一个转换
  26. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  27. img = cv2.resize(img, (112, 112))
  28. img = img.transpose(2, 0, 1)
  29. # 添加一个第四维度并构建NDArray
  30. img = img[np.newaxis, :]
  31. array = mx.nd.array(img)
  32. return array
  33. if __name__ == '__main__':
  34. from collections import namedtuple
  35. args = parse_args()
  36. sym = mx.sym.load(args.symbol_path)
  37. data_shapes = list()
  38. data_names = list()
  39. if args.data_shapes is not None and len(args.data_shapes) > 0:
  40. for shape in args.data_shapes:
  41. items = shape.replace('\'', '').replace('"', '').split(',')
  42. data_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
  43. data_names.append(items[0])
  44. label_shapes = None
  45. label_names = list()
  46. if args.label_shapes is not None and len(args.label_shapes) > 0:
  47. label_shapes = list()
  48. for shape in args.label_shapes:
  49. items = shape.replace('\'', '').replace('"', '').split(',')
  50. label_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
  51. label_names.append(items[0])
  52. devs = [mx.cpu()]
  53. if len(label_names) == 0:
  54. label_names = None
  55. model = mx.mod.Module(context=devs, symbol=sym, data_names=data_names, label_names=None)
  56. model.bind(data_shapes=data_shapes, for_training=False)
  57. model._params_dirty = True
  58. model.params_initialized = True
  59. model.save_checkpoint("0412", 1)
  60. time1 = time.time()
  61. # print("模型加载和重建时间:{0}".format(time1 - time0))
  62. Batch = namedtuple("batch", ['data'])
  63. img1_path = r'1054086.jpg'
  64. array1 = single_input(img1_path)
  65. start1 = time.time()
  66. for i in range(40):
  67. start = time.time()
  68. model.forward(Batch([array1]))
  69. vector1 = model.get_outputs()[0].asnumpy()
  70. print("time", time.time() - start, vector1.shape)
  71. print("total time", time.time() - start1)

 

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/115648268

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。