How to Use the TensorFlow Inference Dump Script
1. Function Description
This script runs a TensorFlow model and dumps its inference results; it can capture the output of every layer. The model inputs must be provided as bin files.
Note that TensorFlow models generally take NHWC-format inputs (confirm this for your specific network), so the input bin files must also be laid out in NHWC format.
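For instance, a minimal sketch of producing such a bin file with NumPy (the file name and data are illustrative; the shape matches the ResNet example below):

import numpy as np

# Illustrative example: pack preprocessed NHWC input data into a float32
# bin file that the dump script can read back with np.fromfile.
data = np.random.rand(1, 224, 224, 3).astype(np.float32)  # stand-in for real preprocessed data
data.tofile("test.bin")  # raw little-endian float32 bytes, no header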
2. Main Parameters

| Parameter | Description |
| --- | --- |
| protobuf | Path to the pb file. |
| -i, --input_bins | Paths to the model's inference input bin files; separate multiple files with ';', e.g. './a.bin;./c.bin'. |
| -n, --input_names | Names of the input nodes; separate multiple names with ';', e.g. 'graph_input_0:0;graph_input_0:1'. |
| -a, --dump_all | Dump the outputs of every node. Avoid this option for large networks; note that in this mode the last layer's outputs are not dumped. |
| -s, --shapes | Input shapes, e.g. [1,2,3,4];[2,3,4,5]. Set this when an input is a placeholder without a static shape. |
3. Usage Example
Example 1: dump the results of every node in the model and save them as npy files:
python3.7 tools_tensorflow_dump.py resnet.pb -i test.bin -n 'Inputs:0' -s [1,224,224,3] -a
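Since the script reads each bin file as raw float32 and reshapes it, a size mismatch only surfaces at reshape time. A quick sanity check you can run beforehand (file name and shape taken from the example above):

import os
import numpy as np

shape = (1, 224, 224, 3)
expected = int(np.prod(shape)) * np.dtype(np.float32).itemsize  # 602112 bytes
actual = os.path.getsize("test.bin")
assert actual == expected, "test.bin is %d bytes, expected %d" % (actual, expected)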
4. Generated Files
Example 1: the npy-format dump files are generated in a folder named {pb file name}_dump in the directory containing the pb file.
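The dump files are ordinary NumPy arrays, so they can be inspected directly. A small sketch, assuming the ResNet example above (folder name resnet_dump):

import glob
import numpy as np

# Print the path, shape, and dtype of every per-layer dump file.
for path in sorted(glob.glob("resnet_dump/*.npy")):
    arr = np.load(path)
    print(path, arr.shape, arr.dtype)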
tensorflow_dump.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import sys
import os
import argparse
import logging
from logging import error, info
import tensorflow as tf
import shutil
import time


def errorExit(msg, *args, **kwargs):
    error(msg, *args, **kwargs)
    exit()


def checkConditionExit(condition, msg, *args, **kwargs):
    if not condition:
        errorExit(msg, *args, **kwargs)


def convertToShape(shapeStr):
    # Parse a shape string such as "[1,224,224,3]" into a Python list.
    try:
        shape = eval(shapeStr)
    except Exception:
        errorExit("%s shape is invalid", shapeStr)
    return shape


def load_graph(filename):
    # Load a frozen TF1.x GraphDef from a pb file into a fresh Graph.
    with tf.gfile.GFile(filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph


def load_inputs(input_bins, input_names, shapes, graph):
    # Build the feed_dict: read each bin file as raw float32 and reshape it,
    # either to the tensor's static shape or to a user-supplied shape.
    input_bins = input_bins.split(";")
    input_names = input_names.split(";")
    inputs_map = {}
    input_shapes = []
    if shapes is not None:
        for shape in shapes.split(";"):
            input_shapes.append(convertToShape(shape))
    checkConditionExit(len(input_names) == len(input_bins),
                       "input_names must have same length with input_bins")
    for i in range(len(input_bins)):
        input_tensor = graph.get_tensor_by_name(input_names[i])
        if len(input_shapes) == 0:
            input_bin = np.fromfile(input_bins[i], np.float32).reshape(input_tensor.shape)
        else:
            input_bin = np.fromfile(input_bins[i], np.float32).reshape(input_shapes[i])
        inputs_map[input_tensor] = input_bin
    return inputs_map


def load_outputs(dump_all, dump_nodes, graph):
    # Collect the tensors to dump. In dump_all mode, every tensor that feeds
    # some op is dumped; the final layer's outputs never feed another op, so
    # they are not included.
    outputs = []
    if dump_all:
        output_names = []
        for op in graph.get_operations():
            for op_input in op.inputs:
                output_names.append(op_input.name)
        for output_name in output_names:
            outputs.append(graph.get_tensor_by_name(output_name))
    else:
        checkConditionExit(dump_nodes is not None,
                           "no dump_nodes provided, %s", dump_nodes)
        for dump_node in dump_nodes.split(";"):
            outputs.append(graph.get_tensor_by_name(dump_node))
    return outputs


def NHWC2NCHW(input_data):
    # Helper to transpose an NHWC array to NCHW (not used by the CLI flow).
    return input_data.transpose([0, 3, 1, 2])


def main(args):
    print(args)
    pb_path = args.protobuf.name
    graph = load_graph(pb_path)
    inputs_map = load_inputs(args.input_bins, args.input_names, args.shapes, graph)
    # dump_nodes is not exposed on the command line, so -a/--dump_all is
    # effectively mandatory (the argparse group below also requires it).
    outputs = load_outputs(args.dump_all, None, graph)
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    with tf.Session(graph=graph, config=config) as sess:
        dump_bins = sess.run(outputs, feed_dict=inputs_map)
    dir_name = os.path.dirname(os.path.abspath(pb_path))
    output_folder = os.path.join(dir_name,
                                 os.path.basename(pb_path).split('.')[0] + "_dump")
    if os.path.exists(output_folder):
        info("remove dir %s", output_folder)
        shutil.rmtree(output_folder)
    info("create dir %s", output_folder)
    os.mkdir(output_folder, 0o755)
    for i in range(len(outputs)):
        # File name: tensor name with '/' -> '_' and ':' -> '.', plus a
        # microsecond timestamp to keep names unique.
        output = outputs[i].name.replace("/", "_").replace(":", ".")
        prefix = output + "." + str(round(time.time() * 1000000))
        np.save(os.path.join(output_folder, prefix + ".npy"), dump_bins[i])


if __name__ == '__main__':
    logging.basicConfig(
        format="%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s",
        level=logging.DEBUG)
    argsParse = argparse.ArgumentParser(
        prog=sys.argv[0],
        description="This script dumps tensorflow outputs. Only float32 is supported.",
        epilog="Enjoy it.")
    argsParse.add_argument("protobuf", help="protobuf file path",
                           type=argparse.FileType('r', encoding="utf-8"))
    dump_cate_group = argsParse.add_mutually_exclusive_group(required=True)
    argsParse.add_argument("-i", "--input_bins", required=True,
                           help="input bin files, e.g. './a.bin;./c.bin'")
    argsParse.add_argument("-n", "--input_names",
                           help="input node names, e.g. 'graph_input_0:0;graph_input_0:1'")
    dump_cate_group.add_argument(
        "-a", "--dump_all", action="store_true", default=False,
        help="dump all nodes' outputs. Avoid this option for large networks; "
             "in this mode the last layer's outputs are not dumped.")
    argsParse.add_argument(
        "-s", "--shapes", default=None,
        help="input shapes, e.g. '[1,2,3,4];[2,3,4,5]'. "
             "Set this when an input is a placeholder.")
    args = argsParse.parse_args()
    main(args)
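Note that the script defines an NHWC2NCHW helper that is not wired to any command-line option. If you need to compare a dump against NCHW reference data, the same transpose can be applied after loading; a hedged sketch (both file names are hypothetical):

import numpy as np

dump = np.load("resnet_dump/conv1_Relu.0.1600000000000000.npy")  # NHWC dump, name illustrative
golden = np.load("golden_nchw.npy")                              # assumed NCHW reference data
dump_nchw = dump.transpose([0, 3, 1, 2])                         # same permutation as NHWC2NCHW
print("max abs diff:", np.abs(dump_nchw - golden).max())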