如何将Pytorch模型转ONNX格式并使用OnnxRuntime推理

布拉德皮特痒 发表于 2020/06/30 16:00:33 2020/06/30
【摘要】 Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。ONNX的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在Github上。本文介绍如何将Pytorch训练好的模型转成ONNX格式进行推理。Pytorch模型定义和模型权重暂时不支持...

Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。ONNX的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在Github上。本文介绍如何将Pytorch训练好的模型转成ONNX格式进行推理。


Pytorch模型定义和模型权重暂时不支持打包在一起,这在推理时候需要先用模型定义代码构建模型,再加载模型权重,比较麻烦。借助于onnx格式转换可以把模型打包一起,在ONNX Runtime中运行推理,ONNX Runtime 是针对 ONNX 模型的以性能为中心的引擎,可大大提升模型的性能。另外,onnx模型支持在不同框架之间转换,也支持tensorRT加速。


Pytorch模型加载

首先将Pytorch模型加载进来,包括模型结构定义和权重加载。注意需要将模型转换成推理模式,去除dropout和batchnorm层训练和推理不同的影响。

torch_model = model()
torch_model.load_state_dict()

# set the model to inference mode
torch_model.eval()

Pytorch模型转换成ONNX格式

我们调用torch.onnx.export()函数将Pytorch模型转换成ONNX格式。 这将执行模型,并记录使用什么运算符计算输出的轨迹。 因为export运行模型,所以我们需要提供输入张量x注意,由于pytorch在不断更新来解决转onnx过程中的bug,建议采用最新版本的pytorch。建议采用opset_version=11,对一些层支持性较好。

# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)
# Export the model
torch.onnx.export(torch_model,                       # model being run
                  x,                                 # model input (or a tuple for multiple inputs)
                  "test.onnx",                       # where to save the model (can be a file or file-like object)
                  export_params=True,                # store the trained parameter weights inside the model file
                  opset_version=11,                  # the ONNX version to export the model to
                  do_constant_folding=True,          # whether to execute constant folding for optimization
                  input_names = ['input'],           # the model's input names
                  output_names = ['output'],         # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},       # variable lenght axes
                                'output' : {0 : 'batch_size'}})

ONNX模型精度验证

我们先用 onnx.load("test.onnx")加载模型, onnx.checker.check_model(onnx_model)验证模型的结构并确认模型具有有效的架构。然后我们可以验证 ONNX Runtime 和 PyTorch 网络输出值是否相同。

import onnx
onnx_model = onnx.load("test.onnx")
onnx.checker.check_model(onnx_model)

import onnxruntime
ort_session = onnxruntime.InferenceSession("test.onnx")
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

ONNX模型使用onnxruntime推理

使用 ONNX Runtime 运行模型,需要使用onnxruntime.InferenceSession("test.onnx")为模型创建一个推理会话。创建会话后,我们将使用 run()API 运行推理模型获得推理输出结果。这样,就完成了Pytorch模型的打包推理。

from PIL import Image
import torchvision.transforms as transforms

img = Image.open("./_static/img/cat.jpg")
image = np.asarray(img ,dtype=np.float32)
image = np.transpose(image,(2,0,1))##input in CHW format

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区),文章链接,文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件至:cloudbbs@huaweicloud.com进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容。
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。