如何将Pytorch模型转ONNX格式并使用OnnxRuntime推理
Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。ONNX的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在Github上。本文介绍如何将Pytorch训练好的模型转成ONNX格式进行推理。
Pytorch模型定义和模型权重暂时不支持打包在一起,这在推理时候需要先用模型定义代码构建模型,再加载模型权重,比较麻烦。借助于onnx格式转换可以把模型打包一起,在ONNX Runtime中运行推理,ONNX Runtime 是针对 ONNX 模型的以性能为中心的引擎,可大大提升模型的性能。另外,onnx模型支持在不同框架之间转换,也支持tensorRT加速。
Pytorch模型加载
首先将Pytorch模型加载进来,包括模型结构定义和权重加载。注意需要将模型转换成推理模式,去除dropout和batchnorm层训练和推理不同的影响。
torch_model = model() torch_model.load_state_dict() # set the model to inference mode torch_model.eval()
Pytorch模型转换成ONNX格式
我们调用torch.onnx.export()
函数将Pytorch模型转换成ONNX格式。 这将执行模型,并记录使用什么运算符计算输出的轨迹。 因为export
运行模型,所以我们需要提供输入张量x
。注意,由于pytorch在不断更新来解决转onnx过程中的bug,建议采用最新版本的pytorch。建议采用opset_version=11,对一些层支持性较好。
# Input to the model x = torch.randn(batch_size, 1, 224, 224, requires_grad=True) torch_out = torch_model(x) # Export the model torch.onnx.export(torch_model, # model being run x, # model input (or a tuple for multiple inputs) "test.onnx", # where to save the model (can be a file or file-like object) export_params=True, # store the trained parameter weights inside the model file opset_version=11, # the ONNX version to export the model to do_constant_folding=True, # whether to execute constant folding for optimization input_names = ['input'], # the model's input names output_names = ['output'], # the model's output names dynamic_axes={'input' : {0 : 'batch_size'}, # variable lenght axes 'output' : {0 : 'batch_size'}})
ONNX模型精度验证
我们先用 onnx.load("test.onnx")加载模型, onnx.checker.check_model(onnx_model)验证模型的结构并确认模型具有有效的架构。然后我们可以验证 ONNX Runtime 和 PyTorch 网络输出值是否相同。
import onnx onnx_model = onnx.load("test.onnx") onnx.checker.check_model(onnx_model) import onnxruntime ort_session = onnxruntime.InferenceSession("test.onnx") def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)} ort_outs = ort_session.run(None, ort_inputs) # compare ONNX Runtime and PyTorch results np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) print("Exported model has been tested with ONNXRuntime, and the result looks good!")
ONNX模型使用onnxruntime推理
使用 ONNX Runtime 运行模型,需要使用onnxruntime.InferenceSession("test.onnx")为模型创建一个推理会话。创建会话后,我们将使用 run()API 运行推理模型获得推理输出结果。这样,就完成了Pytorch模型的打包推理。
from PIL import Image import torchvision.transforms as transforms img = Image.open("./_static/img/cat.jpg") image = np.asarray(img ,dtype=np.float32) image = np.transpose(image,(2,0,1))##input in CHW format ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)} ort_outs = ort_session.run(None, ort_inputs)
- 点赞
- 收藏
- 关注作者
评论(0)