mxnet加载保存部分参数
【摘要】
def get_specify_mod(model_str,ctx,data_shpae,layer_name): _vec = model_str.split(",") prefix = _vec[0] epoch = int(_vec[1]) sym,arg_params,aux_params = mx.model.load_checkpoint(pref...
-
def get_specify_mod(model_str,ctx,data_shpae,layer_name):
-
_vec = model_str.split(",")
-
prefix = _vec[0]
-
epoch = int(_vec[1])
-
sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)
-
#获取神经网络所有的层
-
all_layers = sym.get_internals()
-
#获取输出层
-
sym = all_layers[layer_name+"_output"]
-
mod = mx.mod.Module(symbol=sym,context=ctx)
-
mod.bind(data_shapes=[("data",data_shpae)])
-
mod.set_params(arg_params,aux_params)
-
return mod
保存部分模型ok了,准确度也是ok的。
-
-
import mxnet as mx
-
import numpy as np
-
import sys
-
import os
-
import argparse
-
import onnx
-
-
print('mxnet version:', mx.__version__)
-
print('onnx version:', onnx.__version__)
-
-
from mxnet.contrib import onnx as onnx_mxnet
-
-
from onnx import checker
-
-
-
input_shape = (1,3,112,112)
-
-
a_sym, arg_params, aux_params = mx.model.load_checkpoint("./model", 13)
-
-
all_layers = a_sym.get_internals()
-
-
-
all_layers=model.symbol.get_internals()
-
-
param_key=all_layers.list_outputs()
-
-
_sym=all_layers["pre_fc1_output"]
-
-
model=mx.mod.Module(symbol=_sym,context=[mx.cpu()])
-
model.bind(data_shapes=[("data", (1, 3, 112, 112))])
-
-
model.params_initialized=True
-
_arg_params, _aux_params=model.get_params()
-
arg_params_new=dict()
-
aux_params_new=dict()
-
for key in _arg_params.keys():
-
# key1=replace_key(key)
-
arg_params_new[key]=arg_params[key]
-
for key in _aux_params.keys():
-
# key1=replace_key(key)
-
aux_params_new[key]=aux_params[key]
-
-
model.set_params(arg_params_new,aux_params_new,allow_missing=True)
-
model.save_checkpoint("55_jz",0)
前几层用的是vgg中的前10层,所以需要用到pretrained vgg中的前几层来initialize我的模型。
mxnet中有什么好的办法能像这样只给模型中的部分layer进行load吗?目前我看到的load_parameters()好像是指load整个模型
net[0:10].load_parameters( vgg_para_file_name, allow_missing=True, ignore_extra=True)
这样?
或者直接net.load_parameters( vgg_para_file_name, allow_missing=True, ignore_extra=True)不行吗?
因为改了网络,除了前10层能匹配上,后面都匹配不上,mxnet自动就初始化匹配上的那10层了
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/115645471
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)