ModelArts AI Gallery recommendation preset algorithm: converting the torch model to a script module

逸尘2022 · Posted on 2022/03/07 19:58:08

The ModelArts AI Gallery recommendation preset algorithm (https://developer.huaweicloud.com/develop/aigallery/algorithm/detail?id=214dcb6c-9d58-40e2-b7f6-9091d22c8d36) produces the following output directory when run as a ModelArts training job:

|- model
  |- dsf  (code library required for inference)
    |- ...
  |- customize_service.py  (inference script)
  |- config.json  (ModelArts inference service configuration file)
  |- configs.yaml  (configuration file for algorithm training and inference)
  |- best.pth  (the model with the highest AUC during training)
|- epoch0_rank0.pth  (model file from epoch 0)
|- ...

In the same directory as best.pth, create and run transform.py (contents below) to convert the torch model into a script module.

Note: replace num_inputs with the actual total number of features, i.e. the combined count of continuous features, categorical features, and multi-value categorical features.

# Copyright 2022 ModelArts Authors from Huawei Cloud. All Rights Reserved.
# https://www.huaweicloud.com/product/modelarts.html
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import torch

from dsf.python.data_v2.encoder import IDEncoder
from dsf.python.utils.config import load_config
from dsf.python.utils.model_select_v2 import build_model

########################################################################
# NOTE: replace num_inputs with the actual total number of continuous,
# categorical, and multi-value categorical features
num_inputs = 212
########################################################################
batch_size = 64

CUR_DIR = os.path.dirname(os.path.abspath(__file__))
CWD = os.getcwd()
sys.path.insert(0, CUR_DIR)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    torch.set_flush_denormal(True)


def model_load(weights, encoder, train_config, model_config, map_location=None):
    # set device
    model_config.embedding_device = device
    model_config.dnn_device = device

    # loads model
    model = build_model(train_config, model_config, encoder)
    if hasattr(model, 'module'):
        model = torch.nn.DataParallel(model)
    checkpoint = torch.load(weights, map_location=map_location)
    model.load_state_dict(checkpoint)
    model.to(map_location)
    model.eval()
    return model


config = load_config("./configs.yaml")
id_encoder = IDEncoder(config.encoder_config, input_reader=None)
deep_model = model_load(os.path.join(CUR_DIR, 'best.pth'), id_encoder, config.train_config,
                        config.model_config, map_location=device)

# ids corresponding to the features (random example inputs are sufficient for tracing)
ids = torch.randint(0, num_inputs, (batch_size, num_inputs), dtype=torch.long)
wts = torch.randn(batch_size, num_inputs)

# trace the existing module to produce a TorchScript module
traced_script_module = torch.jit.trace(deep_model, (ids, wts))

# re-run with a different batch size to check the traced module against the original
ids = torch.randint(0, num_inputs, (32, num_inputs), dtype=torch.long)
wts = torch.randn(32, num_inputs)

output = traced_script_module(ids, wts)
output1 = deep_model(ids, wts)
# save script module
traced_script_module.save('./best_script_module.pth')
# compare the outputs of the traced module and the original model
print(torch.equal(output1, output))
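
Once the conversion completes, best_script_module.pth can be loaded without the dsf code library. The following is a minimal sketch of loading the saved script module for inference; the num_inputs value must match the one used when tracing, and the random tensors are hypothetical placeholders that only illustrate the expected input shapes:

import torch

num_inputs = 212  # must match the value used when tracing

# load the serialized TorchScript module; the dsf source code is no longer needed
script_module = torch.jit.load('./best_script_module.pth', map_location='cpu')
script_module.eval()

# placeholder inputs with the same layout as the tracing inputs
ids = torch.randint(0, num_inputs, (1, num_inputs), dtype=torch.long)
wts = torch.randn(1, num_inputs)

with torch.no_grad():
    scores = script_module(ids, wts)
print(scores.shape)

For real inference, ids and wts would have to be produced by the same IDEncoder feature encoding used in customize_service.py; the random tensors here only demonstrate the call. The same file can also be loaded from C++ with torch::jit::load, which is the usual motivation for exporting a TorchScript module.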