使用 MXNet 测试网络模型的推理速度
【摘要】
# -*- coding: utf-8 -*- """File Name: calculate_flops.py Author: liangdepeng mail: liangdepeng@gmail.com""" import time import cv2 import mxnet as mx import argparse import numpy as np impo...
-
# -*- coding: utf-8 -*-
-
"""
-
File Name: calculate_flops.py
-
Author: liangdepeng
-
mail: liangdepeng@gmail.com
-
"""
-
import time
-
-
import cv2
-
import mxnet as mx
-
import argparse
-
import numpy as np
-
import json
-
import re
-
-
-
def parse_args(argv=None):
    """Parse the command-line options of the MXNet speed test.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace with three list/str attributes:
        ``data_shapes`` and ``label_shapes`` (lists of ``"name,d1,...,dn"``
        strings) and ``symbol_path`` (path of the symbol JSON to load).
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-ds', '--data_shapes', default=["data,1,3,112,112"], type=str, nargs='+',
                        help='data_shapes, format: arg_name,s1,s2,...,sn, example: data,1,3,224,224')
    parser.add_argument('-ls', '--label_shapes', default=["label,1,512"], type=str, nargs='+',
                        help='label_shapes, format: arg_name,s1,s2,...,sn, example: label,1,1,224,224')
    parser.add_argument('-s', '--symbol_path', type=str, default=r'softmax_label-symbol.json', help='')
    return parser.parse_args(argv)
-
-
def single_input(path, size=(112, 112)):
    """Load one image file and wrap it as a batched MXNet NDArray.

    The pixels are converted BGR -> RGB, resized, moved to channel-first
    layout and given a leading batch axis. No mean/std normalization is
    applied; values are the raw pixel intensities returned by OpenCV.

    Args:
        path: Image file to read with ``cv2.imread``.
        size: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``. Defaults to ``(112, 112)``, the size used
            previously.

    Returns:
        ``mx.nd.NDArray`` of shape ``(1, C, height, width)``.

    Raises:
        IOError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of deep inside cvtColor.
        raise IOError("could not read image: {}".format(path))
    # The original author notes the MXNet model expects RGB input, while
    # cv2.imread returns BGR, so the channel order is swapped.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    # HWC -> CHW.
    img = img.transpose(2, 0, 1)
    # Add a leading batch dimension and build the NDArray.
    return mx.nd.array(img[np.newaxis, :])
-
-
-
def _parse_shape_specs(specs):
    """Split ``"name,d1,...,dn"`` strings into input names and (name, shape) pairs.

    Single and double quotes are stripped before splitting, so specs that
    were quoted on the command line still parse. ``None`` or an empty list
    yields two empty lists.

    Returns:
        ``(names, shapes)``: ``names`` lists the leading field of each spec;
        ``shapes`` lists ``(name, tuple_of_ints)`` pairs in the same order.

    Raises:
        ValueError: if any dimension is not an integer.
    """
    names = []
    shapes = []
    for spec in specs or ():
        items = spec.replace('\'', '').replace('"', '').split(',')
        names.append(items[0])
        shapes.append((items[0], tuple(int(s) for s in items[1:])))
    return names, shapes


def main(img_path=r'1054086.jpg', iterations=40, warmup=1):
    """Time CPU forward passes of an MXNet symbol on a single image.

    Loads the symbol named by ``--symbol_path`` and binds it on ``mx.cpu()``
    for inference with the ``--data_shapes`` inputs. It then writes a
    checkpoint (prefix ``"0412"``, epoch 1) and prints the latency of each
    of ``iterations`` forward passes, plus the total and average.

    Args:
        img_path: Image fed to every forward pass (loaded via ``single_input``).
        iterations: Number of timed forward passes.
        warmup: Number of untimed forward passes run before timing starts,
            so first-call overhead does not skew the measurements.
    """
    from collections import namedtuple

    args = parse_args()
    sym = mx.sym.load(args.symbol_path)

    data_names, data_shapes = _parse_shape_specs(args.data_shapes)

    # --label_shapes is accepted by parse_args() but deliberately unused: the
    # module declares no label inputs (label_names=None) and is bound for
    # inference only (for_training=False).
    model = mx.mod.Module(context=[mx.cpu()], symbol=sym,
                          data_names=data_names, label_names=None)
    model.bind(data_shapes=data_shapes, for_training=False)

    # No parameter file is loaded and init_params() is never called. These
    # flags mark the parameters as initialized anyway (presumably to get past
    # MXNet's "params not initialized" checks in forward/save_checkpoint), so
    # the weights are NOT trained values: outputs are meaningless and only
    # the timings matter.
    model._params_dirty = True
    model.params_initialized = True

    # Write a checkpoint with prefix "0412" and epoch 1.
    model.save_checkpoint("0412", 1)

    Batch = namedtuple("batch", ['data'])
    batch = Batch([single_input(img_path)])

    # Untimed warm-up passes; the first forward call typically carries
    # one-off setup cost that would otherwise inflate the first sample.
    for _ in range(warmup):
        model.forward(batch)
        model.get_outputs()[0].asnumpy()

    start_total = time.time()
    for _ in range(iterations):
        start = time.time()
        model.forward(batch)
        # MXNet executes asynchronously; asnumpy() waits for the output, so
        # keeping it inside the timed region measures the real compute time.
        vector1 = model.get_outputs()[0].asnumpy()
        print("time", time.time() - start, vector1.shape)
    total = time.time() - start_total
    print("total time", total)
    if iterations > 0:
        print("avg time", total / iterations)


if __name__ == '__main__':
    main()
-
-
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/115648268
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)