ModelArts AI Gallery推荐预置算法torch model转为script module
【摘要】 ModelArts AI Gallery推荐预置算法(https://developer.huaweicloud.com/develop/aigallery/algorithm/detail?id=214dcb6c-9d58-40e2-b7f6-9091d22c8d36)通过ModelArts训练作业得到如下目录|- model |- dsf (推理依赖的代码库) |- ... ...
ModelArts AI Gallery推荐预置算法(https://developer.huaweicloud.com/develop/aigallery/algorithm/detail?id=214dcb6c-9d58-40e2-b7f6-9091d22c8d36)通过ModelArts训练作业得到如下目录
|- model
|- dsf (推理依赖的代码库)
|- ...
|- customize_service.py (推理脚本)
|- config.json (ModelArts推理服务配置文件)
|- configs.yaml (算法训练和推理的配置文件)
|- best.pth (训练过程中AUC最高的模型)
|- epoch0_rank0.pth (第0个epoch的模型文件)
|- ...
在best.pth的同级目录下新建并运行 transform.py(内容如下),将torch model转换为script module。
特别注意:将num_inputs替换为实际连续特征、离散特征、多值离散特征的特征总数。
# Copyright 2022 ModelArts Authors from Huawei Cloud. All Rights Reserved.
# https://www.huaweicloud.com/product/modelarts.html
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torch
from dsf.python.data_v2.encoder import IDEncoder
from dsf.python.utils.config import load_config
from dsf.python.utils.model_select_v2 import build_model
########################################################################
# 特别注意: 将num_inputs替换为实际连续特征、离散特征、多值离散特征的特征总数
num_inputs = 212
########################################################################
batch_size = 64
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
CWD = os.getcwd()
sys.path.insert(0, CUR_DIR)
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
torch.set_flush_denormal(True)
def model_load(weights, encoder, train_config, model_config, map_location=None):
# set device
model_config.embedding_device = device
model_config.dnn_device = device
# loads model
model = build_model(train_config, model_config, encoder)
if hasattr(model, 'module'):
model = torch.nn.DataParallel(model)
checkpoint = torch.load(weights, map_location=map_location)
model.load_state_dict(checkpoint)
model.to(map_location)
model.eval()
return model
config = load_config("./configs.yaml")
id_encoder = IDEncoder(config.encoder_config, input_reader=None)
deep_model = model_load(os.path.join(CUR_DIR, 'best.pth'), id_encoder, config.train_config,
config.model_config, map_location=device)
# 特征对应的id
ids = torch.randint(0, num_inputs, (batch_size, num_inputs), dtype=torch.long)
wts = torch.randn(batch_size, num_inputs)
# turn an existing module into a TorchScript
traced_script_module = torch.jit.trace(deep_model, (ids, wts))
ids = torch.randint(0, num_inputs, (32, num_inputs), dtype=torch.long)
wts = torch.randn(32, num_inputs)
output = traced_script_module(ids, wts)
output1 = deep_model(ids, wts)
# save script module
traced_script_module.save('./best_script_module.pth')
# compare outputs
print(torch.equal(output1, output))
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)