mxnet统计运算量

举报
风吹稻花香 发表于 2021/06/04 23:15:30 2021/06/04
【摘要】 参考: https://github.com/Ldpe2G/DeepLearningForFun/blob/master/MXNet-Python/CalculateFlopsTool/calculateFlops.py   python calculateFlops.py -s symbols/caffenet-symbol.json -ds data,1,...

参考:

https://github.com/Ldpe2G/DeepLearningForFun/blob/master/MXNet-Python/CalculateFlopsTool/calculateFlops.py

 

python calculateFlops.py -s symbols/caffenet-symbol.json -ds data,1,3,224,224 -ls prob_label,1,1000 
 

('flops: ', '723.007176', ' MFLOPS')

('model size: ', '232.563873291', ' MB') 

python calculateFlops.py -s E:\2/model-symbol.json -ds data,1,3,112,112 -ls prob_label,1,256

如果key_error prob_label,可以改下代码,不加载这个label_shapes参数


  
  1. # -*- coding: utf-8 -*-
  2. """
  3. File Name: calculate_flops.py
  4. Author: liangdepeng
  5. mail: liangdepeng@gmail.com
  6. """
  7. import mxnet as mx
  8. import argparse
  9. import numpy as np
  10. import json
  11. import re
  12. def parse_args():
  13. parser = argparse.ArgumentParser(description='')
  14. parser.add_argument('-ds', '--data_shapes',default=["data,1,3,112,112"], type=str, nargs='+',
  15. help='data_shapes, format: arg_name,s1,s2,...,sn, example: data,1,3,224,224')
  16. parser.add_argument('-ls', '--label_shapes',default=["label,1,512"], type=str, nargs='+',
  17. help='label_shapes, format: arg_name,s1,s2,...,sn, example: label,1,1,224,224')
  18. # parser.add_argument('-s', '--symbol_path', type=str, default=r'model-symbol.json', help='')
  19. return parser.parse_args()
  20. def product(tu):
  21. """Calculates the product of a tuple"""
  22. prod = 1
  23. for x in tu:
  24. prod = prod * x
  25. return prod
  26. def get_internal_label_info(internal_sym, label_shapes):
  27. if label_shapes:
  28. internal_label_shapes = filter(lambda shape: shape[0] in internal_sym.list_arguments(), label_shapes)
  29. if internal_label_shapes:
  30. internal_label_names = [shape[0] for shape in internal_label_shapes]
  31. return internal_label_names, internal_label_shapes
  32. return None, None
  33. if __name__ == '__main__':
  34. args = parse_args()
  35. sym = mx.sym.load(args.symbol_path)
  36. data_shapes = list()
  37. data_names = list()
  38. if args.data_shapes is not None and len(args.data_shapes) > 0:
  39. for shape in args.data_shapes:
  40. items = shape.replace('\'', '').replace('"', '').split(',')
  41. data_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
  42. data_names.append(items[0])
  43. label_shapes = None
  44. label_names = list()
  45. if args.label_shapes is not None and len(args.label_shapes) > 0:
  46. label_shapes = list()
  47. for shape in args.label_shapes:
  48. items = shape.replace('\'', '').replace('"', '').split(',')
  49. label_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
  50. label_names.append(items[0])
  51. devs = [mx.cpu()]
  52. if len(label_names) == 0:
  53. label_names = None
  54. model = mx.mod.Module(context=devs, symbol=sym, data_names=data_names,label_names=None)
  55. model.bind(data_shapes=data_shapes, label_shapes=label_shapes, for_training=False)
  56. arg_params = model._exec_group.execs[0].arg_dict
  57. conf = json.loads(sym.tojson())
  58. nodes = conf["nodes"]
  59. total_flops=0.
  60. for node in nodes:
  61. op = node["op"]
  62. layer_name = node["name"]
  63. attrs = None
  64. if "param" in node:
  65. attrs = node["param"]
  66. elif "attrs" in node:
  67. attrs = node["attrs"]
  68. else:
  69. attrs = {}
  70. if op == 'Convolution':
  71. internal_sym = sym.get_internals()[layer_name + '_output']
  72. internal_label_names, internal_label_shapes = get_internal_label_info(internal_sym, label_shapes)
  73. shape_dict = {}
  74. for k,v in data_shapes:
  75. shape_dict[k] = v
  76. if internal_label_shapes != None:
  77. for k,v in internal_label_shapes:
  78. shape_dict[k] = v
  79. _, out_shapes, _ = internal_sym.infer_shape(**shape_dict)
  80. out_shape = out_shapes[0]
  81. # num_group = 1
  82. # if "num_group" in attrs:
  83. # num_group = int(attrs['num_group'])
  84. # support conv1d NCW and conv2d NCHW layout
  85. out_shape_produt = out_shape[2] if len(out_shape) == 3 else out_shape[2] * out_shape[3]
  86. # the weight shape already consider the 'group', so no need to divide 'group'
  87. total_flops += out_shape_produt * product(arg_params[layer_name + '_weight'].shape) * data_shapes[0][1][0]
  88. if layer_name + "_bias" in arg_params:
  89. total_flops += product(out_shape)
  90. del shape_dict
  91. if op == 'Deconvolution':
  92. input_layer_name = nodes[node["inputs"][0][0]]["name"]
  93. internal_sym = sym.get_internals()[input_layer_name + '_output']
  94. internal_label_names, internal_label_shapes = get_internal_label_info(internal_sym, label_shapes)
  95. shape_dict = {}
  96. for k,v in data_shapes:
  97. shape_dict[k] = v
  98. if internal_label_shapes != None:
  99. for k,v in internal_label_shapes:
  100. shape_dict[k] = v
  101. _, out_shapes, _ = internal_sym.infer_shape(**shape_dict)
  102. input_shape = out_shapes[0]
  103. # num_group = 1
  104. # if "num_group" in attrs:
  105. # num_group = int(attrs['num_group'])
  106. # the weight shape already consider the 'group', so no need to divide 'group'
  107. total_flops += input_shape[2] * input_shape[3] * product(arg_params[layer_name + '_weight'].shape) * data_shapes[0][1][0]
  108. del shape_dict
  109. if layer_name + "_bias" in arg_params:
  110. internal_sym = sym.get_internals()[layer_name + '_output']
  111. internal_label_names, internal_label_shapes = get_internal_label_info(internal_sym, internal_label_shapes)
  112. shape_dict = {}
  113. for k,v in data_shapes:
  114. shape_dict[k] = v
  115. if internal_label_shapes != None:
  116. for k,v in internal_label_shapes:
  117. shape_dict[k] = v
  118. _, out_shapes, _ = internal_sym.infer_shape(**shape_dict)
  119. out_shapes = out_shapes[0]
  120. total_flops += product(out_shape)
  121. del shape_dict
  122. if op == 'FullyConnected':
  123. total_flops += product(arg_params[layer_name + '_weight'].shape) * data_shapes[0][1][0]
  124. if layer_name + '_bias' in arg_params:
  125. num_hidden = int(attrs['num_hidden'])
  126. total_flops += num_hidden * data_shapes[0][1][0]
  127. if op == 'Pooling':
  128. if "global_pool" in attrs and attrs['global_pool'] == 'True':
  129. input_layer_name = nodes[node["inputs"][0][0]]["name"]
  130. internal_sym = sym.get_internals()[input_layer_name + '_output']
  131. internal_label_names, internal_label_shapes = get_internal_label_info(internal_sym, label_shapes)
  132. shape_dict = {}
  133. for k,v in data_shapes:
  134. shape_dict[k] = v
  135. if internal_label_shapes != None:
  136. for k,v in internal_label_shapes:
  137. shape_dict[k] = v
  138. _, out_shapes, _ = internal_sym.infer_shape(**shape_dict)
  139. input_shape = out_shapes[0]
  140. total_flops += product(input_shape)
  141. else:
  142. internal_sym = sym.get_internals()[layer_name + '_output']
  143. internal_label_names, internal_label_shapes = get_internal_label_info(internal_sym, label_shapes)
  144. shape_dict = {}
  145. for k,v in data_shapes:
  146. shape_dict[k] = v
  147. if internal_label_shapes != None:
  148. for k,v in internal_label_shapes:
  149. shape_dict[k] = v
  150. _, out_shapes, _ = internal_sym.infer_shape(**shape_dict)
  151. out_shape = out_shapes[0]
  152. n = '\d+'
  153. kernel = [int(i) for i in re.findall(n, attrs['kernel'])]
  154. total_flops += product(out_shape) * product(kernel)
  155. del shape_dict
  156. if op == 'Activation':
  157. if attrs['act_type'] == 'relu':
  158. internal_sym = sym.get_internals()[layer_name + '_output']
  159. internal_label_names, internal_label_shapes = get_internal_label_info(internal_sym, label_shapes)
  160. shape_dict = {}
  161. for k,v in data_shapes:
  162. shape_dict[k] = v
  163. if internal_label_shapes != None:
  164. for k,v in internal_label_shapes:
  165. shape_dict[k] = v
  166. _, out_shapes, _ = internal_sym.infer_shape(**shape_dict)
  167. out_shape = out_shapes[0]
  168. total_flops += product(out_shape)
  169. del shape_dict
  170. model_size = 0.0
  171. if label_names == None:
  172. label_names = list()
  173. for k,v in arg_params.items():
  174. if k not in data_names and k not in label_names:
  175. model_size += product(v.shape) * np.dtype(v.dtype()).itemsize
  176. print('flops: ', str(total_flops / 1000000), ' MFLOPS')
  177. print('model size: ', str(model_size / 1024 / 1024), ' MB')

 

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/115634504

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0)

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。