mxnet 卷积层和BN 合并
【摘要】 https://github.com/jac578/merge_bn_mxnet/blob/master/merge_bn_mxnet.py
import osimport os.path as ospimport mxnet as mximport jsonimport sysimport numpy as npimport copy import fr...
https://github.com/jac578/merge_bn_mxnet/blob/master/merge_bn_mxnet.py
-
import os
-
import os.path as osp
-
import mxnet as mx
-
import json
-
import sys
-
import numpy as np
-
import copy
-
-
import fresnet
-
-
def merge_bn(args, auxs, conv_name, bn_prefix):
-
conv_weights = args[conv_name+"_weight"].asnumpy()
-
gamma = args[bn_prefix+"_gamma"].asnumpy()
-
beta = args[bn_prefix+"_beta"].asnumpy()
-
# print('conv_weights.shape={}'.format(conv_weights.shape))
-
mean = auxs[bn_prefix+"_moving_mean"].asnumpy()
-
variance = auxs[bn_prefix+"_moving_var"].asnumpy()
-
channels = conv_weights.shape[0]
-
epsilon = 2e-5
-
rstd = 1. / np.sqrt(variance + epsilon)
-
rstd = rstd.reshape((channels, 1, 1, 1))
-
gamma = gamma.reshape((channels, 1, 1, 1))
-
beta = beta.reshape((channels, 1, 1, 1))
-
# bias = bias.reshape((channels, 1, 1, 1))
-
mean = mean.reshape((channels, 1, 1, 1))
-
-
new_weights = conv_weights * gamma * rstd
-
# new_bias = (bias - mean) * rstd * gamma + beta
-
new_bias = ( - mean) * rstd * gamma + beta
-
-
new_bias = new_bias.reshape((channels,))
-
-
args[conv_name+"_weight"] = mx.nd.array(new_weights)
-
args[conv_name+"_bias"] = mx.nd.array(new_bias)
-
-
# delete
-
args.pop(bn_prefix+"_gamma")
-
args.pop(bn_prefix+"_beta")
-
auxs.pop(bn_prefix+"_moving_mean")
-
auxs.pop(bn_prefix+"_moving_var")
-
-
if __name__ == '__main__':
-
prefix='r50_128d/model-r50-128d-slim'
-
epoch=0
-
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix,epoch)
-
-
# mxnet_symbol=json.loads(sym.tojson())
-
-
conv_names=['conv0','stage1_unit1_conv1','stage1_unit1_conv2','stage1_unit1_conv1sc','stage1_unit2_conv1','stage1_unit2_conv2','stage1_unit3_conv1','stage1_unit3_conv2','stage2_unit1_conv1','stage2_unit1_conv2','stage2_unit1_conv1sc','stage2_unit2_conv1','stage2_unit2_conv2','stage2_unit3_conv1','stage2_unit3_conv2','stage2_unit4_conv1','stage2_unit4_conv2','stage3_unit1_conv1','stage3_unit1_conv2','stage3_unit1_conv1sc','stage3_unit2_conv1','stage3_unit2_conv2','stage3_unit3_conv1','stage3_unit3_conv2','stage3_unit4_conv1','stage3_unit4_conv2','stage3_unit5_conv1','stage3_unit5_conv2','stage3_unit6_conv1','stage3_unit6_conv2','stage3_unit7_conv1','stage3_unit7_conv2','stage3_unit8_conv1','stage3_unit8_conv2','stage3_unit9_conv1','stage3_unit9_conv2','stage3_unit10_conv1','stage3_unit10_conv2','stage3_unit11_conv1','stage3_unit11_conv2','stage3_unit12_conv1','stage3_unit12_conv2','stage3_unit13_conv1','stage3_unit13_conv2','stage3_unit14_conv1','stage3_unit14_conv2','stage4_unit1_conv1','stage4_unit1_conv2','stage4_unit1_conv1sc','stage4_unit2_conv1','stage4_unit2_conv2','stage4_unit3_conv1','stage4_unit3_conv2']
-
-
bn_prefixes=['bn0','stage1_unit1_bn2','stage1_unit1_bn3','stage1_unit1_sc','stage1_unit2_bn2','stage1_unit2_bn3','stage1_unit3_bn2','stage1_unit3_bn3','stage2_unit1_bn2','stage2_unit1_bn3','stage2_unit1_sc','stage2_unit2_bn2','stage2_unit2_bn3','stage2_unit3_bn2','stage2_unit3_bn3','stage2_unit4_bn2','stage2_unit4_bn3','stage3_unit1_bn2','stage3_unit1_bn3','stage3_unit1_sc','stage3_unit2_bn2','stage3_unit2_bn3','stage3_unit3_bn2','stage3_unit3_bn3','stage3_unit4_bn2','stage3_unit4_bn3','stage3_unit5_bn2','stage3_unit5_bn3','stage3_unit6_bn2','stage3_unit6_bn3','stage3_unit7_bn2','stage3_unit7_bn3','stage3_unit8_bn2','stage3_unit8_bn3','stage3_unit9_bn2','stage3_unit9_bn3','stage3_unit10_bn2','stage3_unit10_bn3','stage3_unit11_bn2','stage3_unit11_bn3','stage3_unit12_bn2','stage3_unit12_bn3','stage3_unit13_bn2','stage3_unit13_bn3','stage3_unit14_bn2','stage3_unit14_bn3','stage4_unit1_bn2','stage4_unit1_bn3','stage4_unit1_sc','stage4_unit2_bn2','stage4_unit2_bn3','stage4_unit3_bn2','stage4_unit3_bn3']
-
-
assert(len(conv_names)==len(bn_prefixes))
-
for i in xrange(len(conv_names)):
-
conv_name = conv_names[i]
-
bn_prefix = bn_prefixes[i]
-
merge_bn(arg_params, aux_params, conv_name, bn_prefix)
-
-
emb_size=128
-
num_layers=50
-
version_se=0
-
version_input=1
-
version_output='E'
-
version_unit=3
-
version_act='prelu'
-
nobn_sym=fresnet.get_symbol(emb_size, num_layers,
-
version_se=version_se, version_input=version_input,
-
version_output=version_output, version_unit=version_unit,
-
version_act=version_act)
-
-
mx.model.save_checkpoint('mergebn_test',0,nobn_sym,arg_params, aux_params)
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/117050087
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)