mxnet加载保存部分参数

举报
风吹稻花香 发表于 2021/06/04 22:37:09 2021/06/04
【摘要】   def get_specify_mod(model_str,ctx,data_shpae,layer_name): _vec = model_str.split(",") prefix = _vec[0] epoch = int(_vec[1]) sym,arg_params,aux_params = mx.model.load_checkpoint(pref...

 


  
  1. def get_specify_mod(model_str,ctx,data_shpae,layer_name):
  2. _vec = model_str.split(",")
  3. prefix = _vec[0]
  4. epoch = int(_vec[1])
  5. sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
  6. #获取神经网络所有的层
  7. all_layers = sym.get_internals()
  8. #获取输出层
  9. sym = all_layers[layer_name+"_output"]
  10. mod = mx.mod.Module(symbol=sym,context=ctx)
  11. mod.bind(data_shapes=[("data",data_shpae)])
  12. mod.set_params(arg_params,aux_params)
  13. return mod

保存部分模型ok了,准确度也是ok的。


  
  1. import mxnet as mx
  2. import numpy as np
  3. import sys
  4. import os
  5. import argparse
  6. import onnx
  7. print('mxnet version:', mx.__version__)
  8. print('onnx version:', onnx.__version__)
  9. from mxnet.contrib import onnx as onnx_mxnet
  10. from onnx import checker
  11. input_shape = (1,3,112,112)
  12. a_sym, arg_params, aux_params = mx.model.load_checkpoint("./model", 13)
  13. all_layers = a_sym.get_internals()
  14. all_layers=model.symbol.get_internals()
  15. param_key=all_layers.list_outputs()
  16. _sym=all_layers["pre_fc1_output"]
  17. model=mx.mod.Module(symbol=_sym,context=[mx.cpu()])
  18. model.bind(data_shapes=[("data", (1, 3, 112, 112))])
  19. model.params_initialized=True
  20. _arg_params, _aux_params=model.get_params()
  21. arg_params_new=dict()
  22. aux_params_new=dict()
  23. for key in _arg_params.keys():
  24. # key1=replace_key(key)
  25. arg_params_new[key]=arg_params[key]
  26. for key in _aux_params.keys():
  27. # key1=replace_key(key)
  28. aux_params_new[key]=aux_params[key]
  29. model.set_params(arg_params_new,aux_params_new,allow_missing=True)
  30. model.save_checkpoint("55_jz",0)

前几层用的是vgg中的前10层,所以需要用到pretrained vgg中的前几层来initialize我的模型。

mxnet中有什么好的办法能像这样只给模型中的部分layer进行load吗?目前我看到的load_parameters()好像是指load整个模型

 

net[0:10].load_parameters( vgg_para_file_name, allow_missing=True, ignore_extra=True)
这样?

或者直接net.load_parameters( vgg_para_file_name, allow_missing=True, ignore_extra=True)不行吗?

因为改了网络,除了前10层能匹配上,后面都匹配不上,mxnet自动就初始化匹配上的那10层了

 

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/115645471

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。