使用ONNX将PyTorch模型转为TensorFlow的尝试

举报
nineteens 发表于 2021/03/09 16:17:44 2021/03/09
【摘要】 使用ONNX将PyTorch模型转为TensorFlow的尝试

  定义模型的代码(resolution_model/resolution.py,测试代码通过 from resolution_model import resolution 导入):

  import torch.nn as nn

  import torch.nn.init as init

  class SuperResolutionNet(nn.Module):
      """Sub-pixel convolution super-resolution network for 1-channel input.

      Three ReLU-activated convolutions are followed by a fourth convolution
      with ``upscale_factor ** 2`` output channels, which ``nn.PixelShuffle``
      rearranges into a single channel that is ``upscale_factor`` times larger
      in each spatial dimension: ``(N, 1, H, W) -> (N, 1, H*r, W*r)``.

      Args:
          upscale_factor: Spatial upscaling factor ``r``.
          inplace: Forwarded to ``nn.ReLU``.
      """

      def __init__(self, upscale_factor, inplace=False):
          super().__init__()
          # Creation order is kept fixed: every Conv2d draws its default
          # weights from the global RNG, and the attribute names conv1..conv4
          # are the state_dict keys the pretrained checkpoint expects.
          self.relu = nn.ReLU(inplace=inplace)
          self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          self.conv4 = nn.Conv2d(
              32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          )
          self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
          self._initialize_weights()

      def forward(self, x):
          """Upscale ``x`` by ``upscale_factor`` in both spatial dimensions."""
          for layer in (self.conv1, self.conv2, self.conv3):
              x = self.relu(layer(x))
          # No activation on the last conv: its channels become sub-pixels.
          return self.pixel_shuffle(self.conv4(x))

      def _initialize_weights(self):
          """Orthogonally initialise conv weights (ReLU gain for hidden layers)."""
          relu_gain = init.calculate_gain('relu')
          for layer in (self.conv1, self.conv2, self.conv3):
              init.orthogonal_(layer.weight, relu_gain)
          init.orthogonal_(self.conv4.weight)

  调用测试的代码:

  from resolution_model import resolution

  import numpy as np

  import os

  from torch import nn

  import torch.utils.model_zoo as model_zoo

  import torch.onnx

  from onnx_tf.backend import prepare

  import onnx

  import onnxruntime

  import tensorflow as tf

  from PIL import Image

  import torchvision.transforms as transforms

  # 将torch模型保存为ONNX

  def save_onnx(onnx_path="super_resolution.onnx"):
      """Export the pretrained SuperResolutionNet (upscale x3) to ONNX.

      Downloads the pretrained weights (``load_url`` reuses its local cache
      when the file was fetched before), switches the model to inference
      mode and exports it with a dummy ``(1, 1, 224, 224)`` input. The batch
      axis of ``input``/``output`` is exported as dynamic.

      Args:
          onnx_path: Destination of the ONNX file. Defaults to
              ``"super_resolution.onnx"``, the file the prediction functions read.
      """
      torch_model = resolution.SuperResolutionNet(upscale_factor=3)

      # Load pretrained model weights.
      model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
      # The model lives on the CPU, so map the checkpoint straight to CPU
      # instead of materialising it on a GPU first.
      state_dict = model_zoo.load_url(model_url, map_location="cpu")
      torch_model.load_state_dict(state_dict)

      # Set the model to inference mode before exporting.
      torch_model.eval()

      batch_size = 1  # arbitrary: the batch axis is exported as dynamic
      # Dummy input used only to trace the graph during export.
      x = torch.randn(batch_size, 1, 224, 224)

      torch.onnx.export(
          torch_model,                  # model being run
          x,                            # model input (or a tuple for multiple inputs)
          onnx_path,                    # where to save the model (file or file-like object)
          export_params=True,           # store the trained weights inside the model file
          opset_version=10,             # the ONNX opset version to export to
          do_constant_folding=True,     # execute constant folding for optimization
          input_names=['input'],        # the model's input names
          output_names=['output'],      # the model's output names
          dynamic_axes={'input': {0: 'batch_size'},    # variable-length axes
                        'output': {0: 'batch_size'}},
      )

  # 输入的图像进行数据处理

  def data_process(image_path="img/cat_224x224.jpg"):
      """Load an image and split it into YCbCr channels for the model.

      Only the luminance (Y) channel is fed to the network; Cb and Cr are
      returned so the super-resolved Y can be recombined into a colour image.

      Args:
          image_path: Image to load; defaults to ``"img/cat_224x224.jpg"``.

      Returns:
          ``(img_cb, img_cr, img_y)``: the Cb and Cr channels as 224x224 PIL
          images, and the Y channel as a float tensor of shape
          ``(1, 1, 224, 224)`` produced by ``transforms.ToTensor``.
      """
      resize = transforms.Resize([224, 224])
      # Open the file in a context manager so the handle is released. Every
      # image derived below (resize/convert/split) is a new in-memory image,
      # so it stays valid after the file is closed.
      with Image.open(image_path) as src:
          img = resize(src)
          img_ycbcr = img.convert('YCbCr')
          img_y, img_cb, img_cr = img_ycbcr.split()

      # ToTensor yields (1, H, W); add the leading batch dimension.
      img_y = transforms.ToTensor()(img_y).unsqueeze(0)

      return img_cb, img_cr, img_y

  # 将修改分辨率的图像结果保存为jpg

  def save_gen_img(model_output, img_cb, img_cr, new_img_path):
      """Merge the super-resolved Y channel with upscaled Cb/Cr and save as RGB.

      Args:
          model_output: Model output for a batch. Accepted forms are a
              ``torch.Tensor``, a NumPy array, or a list/tuple whose first
              element is such an output. onnxruntime returns a list; onnx-tf's
              ``run`` presumably returns a namedtuple (TODO confirm). Only the
              first sample's first channel is used. Values are scaled by 255
              and clipped to ``[0, 255]``.
          img_cb: Cb channel (PIL image) at the input resolution.
          img_cr: Cr channel (PIL image) at the input resolution.
          new_img_path: Destination file; PIL infers the format from the extension.
      """
      # Multi-output containers: take the first (and only) model output.
      if isinstance(model_output, (list, tuple)):
          model_output = model_output[0]
      # Normalise to NumPy; .cpu() also makes this work for CUDA tensors.
      if isinstance(model_output, torch.Tensor):
          model_output = model_output.detach().cpu().numpy()

      # First sample, first channel -> (H, W), scaled from [0, 1] to [0, 255].
      out_y = np.asarray(model_output)[0, 0] * 255.0
      img_out_y = Image.fromarray(np.uint8(np.clip(out_y, 0, 255)), mode="L")

      # Chroma channels are upsampled bicubically to the new resolution.
      final_img = Image.merge(
          "YCbCr",
          [
              img_out_y,
              img_cb.resize(img_out_y.size, Image.BICUBIC),
              img_cr.resize(img_out_y.size, Image.BICUBIC),
          ],
      ).convert("RGB")
      final_img.save(new_img_path)

  # 使用pytorch进行预测

  def torch_predict(path):
      """Run the pretrained model with PyTorch and save the result image.

      Args:
          path: Output image path passed to ``save_gen_img``.
      """
      torch_model = resolution.SuperResolutionNet(upscale_factor=3)
      # load_url needs the full URL: with a bare file name it only works when
      # the checkpoint is already in the torch hub cache, and otherwise fails
      # trying to download from that name. A cached copy is still reused.
      model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
      torch_model.load_state_dict(model_zoo.load_url(model_url, map_location="cpu"))
      torch_model.eval()

      img_cb, img_cr, img_y = data_process()
      # Inference only: skip autograd bookkeeping.
      with torch.no_grad():
          out = torch_model(img_y)
      save_gen_img(out, img_cb, img_cr, path)

  # 使用tensorflow进行预测

  def tf_predict(path="img/tf_super_resolution.jpg"):
      """Run the model through TensorFlow (via onnx-tf) and save the result image.

      If ``tf_model/resolution.pb`` does not exist yet, the ONNX model is
      converted with onnx-tf, exported to that path and run directly.
      Otherwise the exported GraphDef is loaded with the TF1 compatibility
      API (it cannot be read with ``tf.saved_model.load``) and executed in a
      ``tf.compat.v1.Session``.

      Args:
          path: Output image path passed to ``save_gen_img``. The default used
              to be ``"super_resolution.onnx"`` (the model file, which is not an
              image path), so it could never produce a valid output; it now
              matches the path used under ``__main__``.
      """
      img_cb, img_cr, img_y = data_process()
      img_y = img_y.detach().numpy()  # both TF paths are fed NumPy input
      tf_model_path = "tf_model/resolution.pb"

      if os.path.exists(tf_model_path):
          # NOTE(review): newer onnx-tf releases export a SavedModel directory
          # at this path instead of a GraphDef file - confirm the installed version.
          with tf.io.gfile.GFile(tf_model_path, "rb") as f:
              graph_def = tf.compat.v1.GraphDef()
              graph_def.ParseFromString(f.read())
          # Import into an explicit graph so the Session runs in graph mode
          # without switching the whole process out of eager execution.
          graph = tf.Graph()
          with graph.as_default():
              tf.compat.v1.import_graph_def(graph_def, name='')
          # The context manager closes the session when done.
          with tf.compat.v1.Session(graph=graph) as sess:
              input_tensor = graph.get_tensor_by_name('input:0')
              output_tensor = graph.get_tensor_by_name('output:0')
              out = sess.run(output_tensor, feed_dict={input_tensor: img_y})
      else:
          onnx_model = onnx.load("super_resolution.onnx")
          onnx.checker.check_model(onnx_model)
          tf_model = prepare(onnx_model)
          # The export target directory must exist before writing into it.
          os.makedirs(os.path.dirname(tf_model_path), exist_ok=True)
          tf_model.export_graph(tf_model_path)
          out = tf_model.run(img_y)

      save_gen_img(out, img_cb, img_cr, path)

  # 使用onnxruntime进行预测

  def onnx_predict(path):
      """Run the exported ONNX model with onnxruntime and save the result image.

      Args:
          path: Output image path passed to ``save_gen_img``.
      """
      onnx_model = onnx.load("super_resolution.onnx")
      onnx.checker.check_model(onnx_model)

      img_cb, img_cr, img_y = data_process()

      # Since onnxruntime 1.9, builds shipping several execution providers
      # (e.g. onnxruntime-gpu) require an explicit provider list; pinning the
      # CPU provider makes the call work the same on every build.
      ort_session = onnxruntime.InferenceSession(
          "super_resolution.onnx", providers=["CPUExecutionProvider"]
      )
      input_name = ort_session.get_inputs()[0].name
      ort_outs = ort_session.run(None, {input_name: img_y.numpy()})
      save_gen_img(ort_outs, img_cb, img_cr, path)

  if __name__ == '__main__':
      # Enable one stage at a time. save_onnx() must have been run once so
      # that super_resolution.onnx exists for tf_predict() / onnx_predict().
      # save_onnx()
      # torch_predict("img/cat_superres_with_ort.jpg")
      tf_output_path = 'img/tf_super_resolution.jpg'
      tf_predict(tf_output_path)
      # onnx_predict('img/onnx_super_resolution.jpg')

  另外,以ONNX作为中间格式尝试将其转为TensorFlow 2.0模型时,用tf.saved_model.load加载导出的模型会报错:

  ValueError: Importing a SavedModel with tf.saved_model.load requires a 'tags=' argument if there is more than one MetaGraph. Got 'tags=None', but there are 0 MetaGraphs in the SavedModel with tag sets []. Pass a 'tags=' argument to load this SavedModel.

  有人在onnx-tensorflow的GitHub官方仓库就此提了issue,作者回复“The pb file cannot be loaded as a TF SavedModel. A pending PR, #603, is supposed to fix it. Feel free to try it out before it is merged”(导出的pb文件不能作为TF SavedModel加载,尚未合并的PR #603将修复此问题)。因此改用TensorFlow 1.x兼容API(tf.compat.v1)加载pb文件,具体见代码中的tf_predict()函数。

  最终,torch_predict、onnx_predict、tf_predict三个函数的输出结果一致。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。