flask框架下多线程引发的tensorflow加载模型并预测的隐藏bug

举报
jjjjjjjjjj 发表于 2021/06/17 22:16:44 2021/06/17
【摘要】 最近部署推理服务的时候遇到了一个tensorflow在多线程中的隐藏bug。 tensorflow的模型图层的命名在多线程下是不安全的,多线程下图层的命名空间会变,导致调用predict找不到我们想要的图层。

最近部署推理服务的时候遇到了一个tensorflow在多线程中的隐藏bug

原始诉求

是希望在服务启动的时候就加载模型,并提取模型的中间层输出,之后推理不再重复加载模型,将加载好的模型传给新的线程调用。

报错

ValueError: Tensor Tensor(“dense_final/LeakyRelu:0”, shape=(None, 1), dtype=float32) is not an element of this graph.

W tensorflow/core/framework/op_kernel.cc:1502] OP_REQUIRES failed at resource_variable_ops.cc:619 : Not found: Container localhost does not exist. (Could not find resource: localhost/Embedding-Segment/embeddings)

一个是最终层的错误,一个是中间层的错误,这些都是表象。

根本原因

tensorflow的模型图层的命名在多线程下是不安全的,多线程下图层的命名空间会变,导致调用predict找不到我们想要的图层。

解决办法

定义全局字典存储初始sess和graph,每个初始模型定义一套,也就是每load一个模型,就要定义一套初始sess和graph,固定住图层。

import tensorflow as tf

sess_graph_dict={}
for global_task in task_list:
    task_sess_graph = {
        "sess": tf.compat.v1.Session(),
        "graph": tf.compat.v1.get_default_graph()
    }
    sess_graph_dict[global_task] = task_sess_graph

在加载模型前加如下代码

try:
    tmp = sess_graph_dict.get(self.task_name)
    sess = tmp.get("sess")
    set_session(sess)
    model=tf.keras.models.load_model(file_path, custom_objects=dependencies)
except Exception as e:
    ....


在切分后的模型前加如下代码(由于我们把模型做了拆分,需要在每个模型predict之前都加上如下的代码,从哪个原始模型上拆分出来的,就用哪个模型的sess和graph)

from tensorflow.python.keras.backend import set_session

try:
    global sess_graph_dict
    tmp = sess_graph_dict.get(self.task_name)
    sess = tmp.get("sess")
    graph = tmp.get("graph")
    with graph.as_default():
        set_session(sess)
        self.model.predict(self.input)
except Exception as e:
    ....

拆分模型的代码

def build_bottleneck_model(model, layer_name):
    """
    从头输入,获取中间层输出
    """
    output = None
    for layer in model.layers:
        if layer.name == layer_name:
            output = layer.output

    if output is None:
        raise Exception(...)
    bottleneck_model = Model(model.input, output)
    return bottleneck_model
def get_outputs_of(model, start_tensors, input_layers=None):
    """
    获取从中间层输入到最终输出的模型
    """

    # 为此操作建立新模型
    model = Model(inputs=model.input,
                  outputs=model.output,
                  name='outputs_of_' + model.name)
    # 适配工作,方便使用
    if not isinstance(start_tensors, list):
        start_tensors = [start_tensors]
    if input_layers is None:
        input_layers = [
            Input(shape=keras_backend.int_shape(x)[1:], dtype=keras_backend.dtype(x))
            for x in start_tensors
        ]
    elif not isinstance(input_layers, list):
        input_layers = [input_layers]
    # 核心:覆盖模型的输入
    model.inputs = start_tensors
    model._input_layers = [x._keras_history[0] for x in input_layers]
    # 适配工作,方便使用
    if len(input_layers) == 1:
        input_layers = input_layers[0]
    # 整理层,参考自 Model 的 run_internal_graph 函数
    layers, tensor_map = [], set()
    for x in model.inputs:
        tensor_map.add(str(id(x)))
    depth_keys = list(model._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
        nodes = model._nodes_by_depth[depth]
        for node in nodes:
            n = 0
            input_tensors = node.input_tensors
            if not isinstance(input_tensors, list):
                input_tensors = [input_tensors]
            for x in input_tensors:
                if str(id(x)) in tensor_map:
                    n += 1
            if n == len(input_tensors):
                if node.outbound_layer not in layers:
                    layers.append(node.outbound_layer)
                output_tensors = node.output_tensors
                if not isinstance(output_tensors, list):
                    output_tensors = [output_tensors]
                for x in output_tensors:
                    tensor_map.add(str(id(x)))
    model._layers = layers  # 只保留用到的层
    # 计算输出
    outputs = model(input_layers)
    cut_model = Model(input_layers, outputs)

    return cut_model

参考:

keras/tensorflow 使用flask部署服务的常见错误及部署多个模型

“让Keras更酷一些!”:层与模型的重用技巧

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。