关于使用omg转换tensorflow模型失败的一个解决方案
pre:
最近在tf2中复现了一个resnet34_CenterNet网络,在tf2上测试性能良好,所以想部署到Atlas 200 DK上去试试效果。由于tf2保存save_model的格式是网络结构与权重分开的,故需要调用tf.compact.v1里面的内容进行序列化保存,即生成华为omg程序需要的.pb文件。这一步进行得比较顺利,很快就得到了pb文件,但当我在om权重的时候,竟然报出了里面一些层不支持的错误(FusedBatchNorm AddV2等等)。原因分析了下,应该是tf.compact.v1里面的内容是比较高版本(>=1.14.x)的tensorflow,而华为官方的推荐是1.12.0,更高版本的tensorflow用了更新的计算算子,导致了转换错误。
ok,问题找到了,解决思路也有了:
A.使用华为提供的算子开发toolkit进行对应的算子开发。这个方法成本太高,除了要熟悉这套toolkit还得去看tensorflow2中关于这些层的具体实现或者对应的论文;
B.将此h5文件转换成caffe model,可以尝试下使用mmdnn进行转换,不过好像有些模型支持度也不是很好,备用策略;
C.将此h5由低版本的tf读入,再由低版本的tf进行网络冻结,由于h5权重是由更高一级的keras_api生成的,tensorflow只是充当背后的计算支持(个人浅显理解),所以只要低版本的tf中的keras能够成功读入模型,转换应该就比较容易了。
step:
1.在tf2中训练好,保存为h5权重;
2.在低版本中使用tf.keras.models.load_model('xxxx.h5')读入模型,如果没有报错,就可以进行下一步;
3.进行网络冻结,实质就是将网络结构与权重合并在一起,便于部署;
在这个过程中可能存在比较多的报错,具体视模型中使用的层的类型是否在高版本进行更新了:
3.1 错误:
ValueError: ('Unrecognized keyword arguments:', dict_keys(['ragged']))
因为低版本的Input不支持ragged参数,所以点一下最后报错的py文件,进入到tf1.x的源码,屏蔽如下代码
# if kwargs: # raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
3.2 错误
ValueError: Unknown initializer:XXXX
由于低版本中没有关于XXXX的实现,可以直接在新版本的tf2中拷贝关于XXXX实现的代码,贴到load_model之前;
3.3 对于
TypeError: ('Keyword argument not understood:', 'threshold') TypeError: ('Keyword argument not understood:', 'interpolation')
此类错误,是由于两个版本的关于报出这个错误的层的实现不一致导致的,所以解决方法是将低版本的相关层实现替换为高版本中的实现,层的实现当中可能还调用到其他一些工具类函数,这块的修改量可能比较大
可参考https://zhen8838.github.io/2020/03/18/h5-to-pb/
修改完应该就可以成功转换了。
4.转换代码:
import tensorflow as tf import os os.environ['CUDA_VISIBLE_DEVICES'] = '-1' tf.keras.backend.set_learning_phase(0) # 载入模型 model = tf.keras.models.load_model('checkpoints/2020-04-11/debug_model.h5') def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): from tensorflow.python.framework.graph_util import convert_variables_to_constants graph = session.graph with graph.as_default(): freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) output_names = output_names or [] output_names += [v.op.name for v in tf.global_variables()] # Graph -> GraphDef ProtoBuf input_graph_def = graph.as_graph_def(add_shapes=True) if clear_devices: for node in input_graph_def.node: node.device = "" frozen_graph = convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names) return frozen_graph frozen_graph = freeze_session(tf.keras.backend.get_session(), output_names=[out.op.name for out in model.outputs]) tf.train.write_graph(frozen_graph, "pb", "tf_model.pb", as_text=False)
参考:https://zhen8838.github.io/2020/03/18/h5-to-pb/
- 点赞
- 收藏
- 关注作者
评论(0)