在modelarts上部署backend为TensorFlow的keras模型

举报
开源小分舵-山找海味 发表于 2019/09/23 20:32:32 2019/09/23
1w+ 2 2
【摘要】 最近老山在研究在modelarts上部署mask-rcnn,源代码提供的是keras模型。我们可以将keras转化成savedModel模型,在TensorFlow Serving上部署,可参考老山的上篇部署文章。至于输入和输出张量,到已经预先存在model.input和model.output中了。不多说,直接上代码。from keras import backend as Kimport...

最近老山在研究在modelarts上部署mask-rcnn,源代码提供的是keras模型。我们可以将keras转化成savedModel模型,在TensorFlow Serving上部署,可参考老山的上篇部署文章。至于输入和输出张量,到已经预先存在model.input和model.output中了。

不多说,直接上代码。

from keras import backend as K
import tensorflow as tf
# 在此之前,先加载keras模型
# 。。。
# 加载完成

with K.get_session() as sess:
    export_path = './saved_model'
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    signature_inputs = {
        'input_image': tf.saved_model.utils.build_tensor_info(model.input[0]),
        'input_image_meta': tf.saved_model.utils.build_tensor_info(model.input[1]),
        'input_anchors': tf.saved_model.utils.build_tensor_info(model.input[2]),
    }

    signature_outputs = {
        'mrcnn_detection':tf.saved_model.utils.build_tensor_info(model.output[0]),
        'mrcnn_class':tf.saved_model.utils.build_tensor_info(model.output[1]),
        'mrcnn_bbox':tf.saved_model.utils.build_tensor_info(model.output[2]),
        'mrcnn_mask':tf.saved_model.utils.build_tensor_info(model.output[3]),
        'ROI':tf.saved_model.utils.build_tensor_info(model.output[4]),
        'rpn_class':tf.saved_model.utils.build_tensor_info(model.output[5]),
        'rpn_bbox':tf.saved_model.utils.build_tensor_info(model.output[6]),        
    }

    classification_signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=signature_inputs,
        outputs=signature_outputs,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'root': classification_signature_def
        },
    )

    builder.save()

如果您觉得老山的文章不错,不妨点击下关注。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

作者其他文章

评论(2

抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。