tensorflow 推理场景dump脚本使用方法

举报
Ascend_helper 发表于 2020/07/04 15:52:04 2020/07/04
【摘要】 1. 功能说明用于获取tensorflow的运行结果,可以获取每一层的结果,输入需要是bin文件。需要注意的是,tensorflow一般需要的输入是NHWC格式的(具体网络要具体确认),需要的输入bin文件也需要是NHWC格式的。2. 主要参数说明参数名称参数 描述protobufpb文件路径-i, --input_bins模型推理输入bin文件路径,多个以;分隔,如'./a.bin;./c...

1. 功能说明

用于获取tensorflow的运行结果,可以获取每一层的结果,输入需要是bin文件。

需要注意的是,tensorflow一般需要的输入是NHWC格式的(具体网络要具体确认),需要的输入bin文件也需要是NHWC格式的。

2. 主要参数说明

参数名称

参数 描述

protobuf

pb文件路径

-i, --input_bins

模型推理输入bin文件路径,多个以;分隔,如'./a.bin;./c.bin'

-n, --input_names

模型推理输入节点名称,多个以;分隔,如'graph_input_0:0; graph_input_1:0'

-a, --dump_all

dump模型推理时的每个节点输出结果

-s, --shapes

模型推理输入节点的shape,多个以;分隔,如'[1,2,3,4];[2,3,4,5]'


3. 使用示例


示例1:dump模型每个节点的bin结果文件,并输出npy结果文件;

python3.7 tools_tensorflow_dump.py resnet.pb -i test.bin -n 'Inputs:0' -s [1,224,224,3] -a


4. 文件生成结果展示


示例1:npy格式的dump文件会生成到当前操作目录下名为{pb文件名}_dump 的文件夹中,如下图:


tensorflow_dump.py

# !/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import math
import sys
import os
import argparse
from logging import *
import logging
import tensorflow as tf
import shutil
import time


def errorExit(msg, *args, **kwargs):
    error(msg, *args, **kwargs)
    exit()


def checkConditionExit(condition, msg,  *args, **kwargs):
    if not condition:
        errorExit(msg, *args, **kwargs)


def convertToShape(shapeStr):
    try:
        shape = eval(shapeStr)
    except:
        errorExit("%s shape is invalid", shapeStr)
    return shape


def load_graph(filename):
    with tf.gfile.GFile(filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph


def load_inputs(input_bins, input_names, shapes, graph):
    input_bins = input_bins.split(";")
    input_names = input_names.split(";")
    inputs_map = {}
    input_shapes = []
    if shapes != None:
        shapes = shapes.split(";")
        for shape in shapes:
            input_shapes.append(convertToShape(shape))
    checkConditionExit(len(input_names) == len(input_bins), "input_names must have same length with input_bins")
    for i in range(len(input_bins)):
        input_name = input_names[i]
        input = graph.get_tensor_by_name(input_name)
        if len(input_shapes) == 0:
            input_bin = np.fromfile(input_bins[i], np.float32).reshape(input.shape)
        else:
            input_bin = np.fromfile(input_bins[i], np.float32).reshape(input_shapes[i])
        inputs_map[input] = input_bin
    return inputs_map


def load_outputs(dump_all, dump_nodes, graph):
    outputs = []
    if dump_all:
        ops = graph.get_operations()
        output_names = []
        for op in ops:
            op_outputs = op.inputs
            for op_output in op_outputs:
                output_names.append(op_output.name)
        for output_name in output_names:
            node = graph.get_tensor_by_name(output_name)
            outputs.append(node)
    else:
        checkConditionExit(dump_nodes != None, "no dump_nodes provides, %s", dump_nodes)
        dump_nodes = dump_nodes.split(";")
        for dump_node in dump_nodes:
            node = graph.get_tensor_by_name(dump_node)
            outputs.append(node)
    return outputs


def NHWC2NCHW(input):
    result = input.transpose([0, 3, 1, 2])
    return result

def main(args):
    print(args)
    protobuf = args.protobuf
    pb_path = protobuf.name
    graph = load_graph(pb_path)
    input_bins = args.input_bins
    input_names = args.input_names
    shapes = args.shapes
    inputs_map = load_inputs(input_bins, input_names, shapes, graph)
    dump_all = args.dump_all
    outputs = load_outputs(dump_all, None, graph)
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    with tf.Session(graph=graph, config=config) as sess:
        dump_bins = sess.run(outputs, feed_dict=inputs_map)
    pathSep = os.path.sep
    dir_name = os.path.dirname(os.path.abspath(pb_path))
    output_floder = dir_name + pathSep + os.path.basename(pb_path).split('.')[0] + "_dump"
    # print(output_floder)
    if os.path.exists(output_floder):
        info("remove dir %s", output_floder)
        shutil.rmtree(output_floder)
    info("create dir %s", output_floder)
    os.mkdir(output_floder, 755)

    for i in range(len(outputs)):
        output = outputs[i].name
        output = output.replace("/", "_")
        output = output.replace(":", ".")
        prefix = output + "." + str(round(time.time() * 1000000))
        dump_bin = dump_bins[i]
        dump_path = output_floder + pathSep + prefix + ".npy"
        np.save(dump_path, dump_bin)


if __name__ == '__main__':
    logging.basicConfig(
        format="%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s", level=logging.DEBUG)
    argsParse = argparse.ArgumentParser(
        prog=sys.argv[0], description="This script is dump tensorflow output.only support float32", epilog="Enjoy it.")
    argsParse.add_argument("protobuf", help="protobuf file path",
                           type=argparse.FileType('r', encoding="utf-8"))
    dump_cate_group = argsParse.add_mutually_exclusive_group(required=True)
    argsParse.add_argument("-i", "--input_bins", help="input_bins bins. e.g. './a.bin;./c.bin'", required=True)
    argsParse.add_argument("-n", "--input_names", help="input nodes name. e.g. 'graph_input_0:0;graph_input_0:1'")
    dump_cate_group.add_argument(
        "-a", "--dump_all", help="dump all nodes' outputs. don't use this option for large network \
        if you don't mind when will get all out. in this mode will not dump last layer outputs", action="store_true", default=False)
    argsParse.add_argument(
        "-s", "--shapes", help="input shapes. e.g. [1,2,3,4];[2,3,4,5]. if input is placeholder set input shapes by this.", default=None)
    args = argsParse.parse_args()
    main(args)




【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。