论文阅读 BN剪枝《Learning Efficient Networks through Network Slimming》
Learning Efficient Convolutional Networks Through Network Slimming
通过网络瘦身学习高效的卷积神经网络
作者: Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang
单位: Intel, Tsinghua University, Fudan University, Cornell University
发表会议及时间: ICCV 2017
论文导读
本文说明神经网络计算量大以及占内存的问题以及非结构化剪枝的问题。 为了解决上述问题,提出了本文的剪枝方法, 并在多个数据集多
种模型上完成了实验说明神经网络被设计的越来越深, 因此带来了计算量大以及占内存的问题。 之前的非结构性剪枝方法无法在大多数设备上进行快速的计算。 而通道剪枝方法又缺乏快速计算重要性的方案, 从而提出了本文的方案。
动机
-
通过在训练过程中剪去部分权重, 可以减少模型参数量, 降低模型计算量, 加快模型推理速度。 但是剪枝后的模型参数由于是非结构的, 因此无法直接部署到通用计算设备上。
-
通道剪枝往往直接剪去整个通道参数, 如何快速而又准确的判断通道的重要性至关重要。
-
迭代式训练可以降低剪枝带来的精度损失。
现有方法
- 低秩分解
- 权重量化
- 权重剪枝
- 通道剪枝
- 网络结构搜索(NAS)
本文方法
- 在数据集上对未压缩的模型进行完整的训练,直至收敛。
- 在经过完整训练的未压缩模型基础上,剪掉一批权重小于阈值的通道。
- 在数据集上对剪过后的模型进行权重微调,使其恢复精度。继续进行通道剪枝,直至满足剪枝要求。
本质上:用Batch Norm层中的参数大小作为通道重要性的判断依据。
研究成果
在Imagenet数据集上, 本文的方法在将VGG剪去82.5%参数的情况下, 精度基本保持不变。并且计算量降低了30.4%。
研究意义
- 将通道剪枝引入模型剪枝,提高了剪枝后的网络在计算设备上的通用性。
- 将BN层可学习的缩放参数作为判断通道重要性的依据,简化了通道重要性计算的过程。
- 在实验中对带有残差通道的网络(resnet,densenet)进行了剪枝,说明了通道剪枝方法的普适性。
论文精读
CNN通道剪枝
1. weight-level(如上篇文章)
2. channel-level
通道剪枝正如其名字channel pruning,核心思想是移除一些冗余的channel,从而简化模型。右图是通道剪枝的示意图,它表示的网络模型中某一层的channel pruning。B表示输入feature map,C表示输出的feature map;c表示输入B的通道数量,n表示输出C的通道数量;W表示卷积核,卷积核的数量是n,每个卷积核的维度是 , 和 表示卷积核的size。通道剪枝的目的就是要把B中的某些通道剪掉,但是剪掉后的B和W的卷积结果能尽可能和C接近。当删减B中的某些通道时,同时也裁剪了W中与这些通道的对应的卷积核,因此通过通过剪枝能减小卷积的运算量。
-
通道剪枝的优点
通道剪枝是对网络结构的(channels,neurons等)一部分做剪枝或稀疏化, 而不是对个别权重, 因此不太需要特别的库来实现推理加速和运行时的内存节省。 -
通道剪枝的数学表达式
通过数学表达式描述了通道剪枝。X表示输入feature map,W表示卷积核,Y表示输出feature map。beta表示通道系数,如果等于0,表示该通道可以被删除。我们期望将输入feature map的channel从c压缩为c’( ),同时要使得构造误差(reconstruction error)尽可能的小。通过下面的优化表达式,就可以选择哪些通道被删除。
上述公式的优化还是比较复杂的,首先beta并没有参与到权重参数的更新过程中,因为其非1即0,非连续值不可导。
3. layer-level
BN层介绍
Batch Normalization是2015年一篇论文中提出的数据归一化方法, 往往用在深度神经网络中激活层之前。 其作用可以加快模型训练时的收敛速度, 使得模型训练过程更加稳定, 避免梯度爆炸或者梯度消失。 并且起到一定的正则化作用, 几乎代替了Dropout。
BN层的作用
- 缓解Internal Covariate Shift(ICS)
所谓ICS即:特征图每经过一次激活函数,其分布就会发生一次变化。 - 缓解梯度消失问题
- 缓解模型过拟合(正则化效果)
- BN如何工作
1.先计算B的均值和方差,之后将B集合的均值、方差变换为0、1。
2.将B中每个元素乘以gama再加beta,输出。gama和beta是可训练参数,参与整个网络的BP。
gama可以决定输出通道的值。
基于BN的通道剪枝
本文的剪枝流程
- 初始化一个目标网络
- 在对通道进行稀疏正则化的约束下,训练目标网络
- 把缩放参数较小的通道整个去除
- 微调剪枝后的网络
基于BN的通道重要性判断
使用bn层中的缩放参数γ判断通道的重要性,当值越小,代表可以裁剪掉。 那么如果同一个bn层中γ值很接近,怎么办。 都很大时, 删除会对网络精度的很大影响。
- 通过正则化进行通道稀疏
论文中提出了使用L1范数来稀疏化γ值。网络训练过程中的损失如下
当γ<0时, 约束项的梯度为-1, 当γ>0时,约束项的梯度为1, 因此随着γ更新, 约束项会将γ拉向0。
随着惩罚系数变大, 越来越多的通道被约束到了0。
由于本文采用的是迭代式的剪枝方案,因此随着训练次数的增多,越来越多的通道被至0,下采样的区域会被保留的更多。
- 对有跨层连接的网络进行剪枝
这种剪枝方法在应对ResNet这类跨层连接的网络不太好, 因为每一层的输出会作为后续多个层的输入。
如何解决这类问题:增加一个channel selection模块,即对每个layer中最会一个Conv不做剪枝处理,而是在BN后增加一个CS模块根据gama进行选择。
实验结果及分析
Datasets
CIFAR: CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。CIFAR-100和CIFAR-10类似,每个类包含600个图像。每类各有500个训练图像和100个测试图像。
SVHN: 街景门牌号数据集,由0-9的32x32的彩色图像组成。其中训练集有60万张,测试集有2.6万张。
ImageNet:ImageNet数据集共有1000类,包含了120万张训练图像和5万张测试图像。
MNIST:黑白的0-9手写字数据集,共有10类,其中包含了6万张训练图像和1万张测试图像。
在各种数据集上的结果
- 在小数据集上, 在剪去适量通道能够缓解模型的过拟合现象。
- 基于BN的通道剪枝方法能在剪去超过一半通道的情况下,保持性能基本不变。
多步迭代式剪枝的结果
- 在CIFAR10上,迭代到第三次时获得最佳的模型结构。在CIFAR100上,迭代到第2次时获得最佳模型结构。说明越大的数据集需要越多的参数。
- 尽管L1正则化能稀疏通道,但当剪枝过多时,仍会引起性能下降。
在ImageNet上的结果
- 基于BN的通道剪枝方法, 在将VGG压缩6倍左右的参数下, 能几乎没有性能损失。
- 本文的通道剪枝方法压缩率不及非结构化剪枝。
剪枝数量对结果的影响
- 剪枝数量在阈值内,则剪枝后的模型性能可以通过微调拉回来。但是超过阈值,即使微调也无法挽救模型性能。
- 由于缓解了过拟合现象,少量的通道剪枝即使不微调也能保持性能。
论文总结
关键点
- 如何进行通道剪枝。
- 如何更加准确的判断CNN中通道的重要性。
- 如何对有跨层连接的网络进行剪枝。
创新点
- 将基于BN的通道剪枝方法引入神经网络剪枝过程, 简化了剪枝中通道重要性的计算过程。
- 在剪枝训练过程中加入正则化, 帮助提升剪枝性能。
- 在多个网络上完成了剪枝工作并获得了非常好的效果。
存在的问题
BN层中除了缩放参数gamma外, 还存在一个平移参数beta, 是否会影响L1正则化的结果?
论文实现细节
模型基本Block结构
- pre-activation ConvBlock结构效果更佳
通道选择模块
class channel_selection(nn.Module):
"""
Select channels from the output of BatchNorm2d layer. It should be put directly after BatchNorm2d layer.
The output shape of this layer is determined by the number of 1 in `self.indexes`.
"""
def __init__(self, num_channels):
"""
Initialize the `indexes` with all one vector with the length same as the number of channels.
During pruning, the places in `indexes` which correpond to the channels to be pruned will be set to 0.
"""
super(channel_selection, self).__init__()
self.indexes = nn.Parameter(torch.ones(num_channels))
def forward(self, input_tensor):
"""
Parameter
---------
input_tensor: (N,C,H,W). It should be the output of BatchNorm2d layer.
"""
selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy()))
if selected_index.size == 1:
selected_index = np.resize(selected_index, (1,))
output = input_tensor[:, selected_index, :, :]
return output
模型构建
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, cfg, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.bn1 = nn.BatchNorm2d(inplanes)
self.select = channel_selection(inplanes)
self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(cfg[1])
self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(cfg[2])
self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.bn1(x)
out = self.select(out)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return out
class XXXNet(nn.Module):
def __init__(self, depth=56, dataset='cifar10', cfg=None):
super(resnet, self).__init__()
assert (depth - 2) % 9 == 0, 'depth should be 9n+2'
n = (depth - 2) // 9 # (56 - 2) // 9 = 6
block = Bottleneck
if cfg is None:
# Construct config variable.
cfg = [[16, 16, 16], [64, 16, 16]*(n-1), [64, 32, 32], [128, 32, 32]*(n-1), [128, 64, 64], [256, 64, 64]*(n-1), [256]]
cfg = [item for sub_list in cfg for item in sub_list]
self.inplanes = 16
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
bias=False)
self.layer1 = self._make_layer(block, 16, n, cfg = cfg[0:3*n])
self.layer2 = self._make_layer(block, 32, n, cfg = cfg[3*n:6*n], stride=2)
self.layer3 = self._make_layer(block, 64, n, cfg = cfg[6*n:9*n], stride=2)
self.bn = nn.BatchNorm2d(64 * block.expansion)
self.select = channel_selection(64 * block.expansion)
self.relu = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(8)
if dataset == 'cifar10':
self.fc = nn.Linear(cfg[-1], 10)
elif dataset == 'cifar100':
self.fc = nn.Linear(cfg[-1], 100)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(0.5)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, cfg, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
)
layers = []
layers.append(block(self.inplanes, planes, cfg[0:3], stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, cfg[3*i: 3*(i+1)]))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.layer1(x) # 32x32
x = self.layer2(x) # 16x16
x = self.layer3(x) # 8x8
x = self.bn(x)
x = self.select(x)
x = self.relu(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
BN层中的gama正则化
def updateBN():
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.weight.grad.data.add_(s*torch.sign(m.weight.data)) # L1
训练方法
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if torch.cuda.is_available():
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
pred = output.data.max(1, keepdim=True)[1]
loss.backward()
if sr:
updateBN()
optimizer.step()
if batch_idx % 500 == 0:
print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
利用BN进行一次通道剪枝
原模型加载以及BN层参数提取
total = 0
percent = 0.5
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
total += m.weight.data.shape[0]
bn = torch.zeros(total)
index = 0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
size = m.weight.data.shape[0]
bn[index:(index+size)] = m.weight.data.abs().clone()
index += size
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
if isinstance(m, nn.BatchNorm2d):
weight_copy = m.weight.data.abs().clone()
mask = weight_copy.gt(thre).float().cuda()
pruned = pruned + mask.shape[0] - torch.sum(mask)
m.weight.data.mul_(mask)
m.bias.data.mul_(mask)
cfg.append(int(torch.sum(mask)))
cfg_mask.append(mask.clone())
print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
format(k, mask.shape[0], int(torch.sum(mask))))
elif isinstance(m, nn.MaxPool2d):
cfg.append('M')
pruned_ratio = pruned/total
新模型加载及BN参数剪枝处理
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0
for layer_id in range(len(old_modules)):
m0 = old_modules[layer_id]
m1 = new_modules[layer_id]
if isinstance(m0, nn.BatchNorm2d):
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
if idx1.size == 1:
idx1 = np.resize(idx1,(1,))
if isinstance(old_modules[layer_id + 1], channel_selection):
# If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
m1.weight.data = m0.weight.data.clone()
m1.bias.data = m0.bias.data.clone()
m1.running_mean = m0.running_mean.clone()
m1.running_var = m0.running_var.clone()
# We need to set the channel selection layer.
m2 = new_modules[layer_id + 1]
m2.indexes.data.zero_()
m2.indexes.data[idx1.tolist()] = 1.0
layer_id_in_cfg += 1
start_mask = end_mask.clone()
if layer_id_in_cfg < len(cfg_mask):
end_mask = cfg_mask[layer_id_in_cfg]
else:
m1.weight.data = m0.weight.data[idx1.tolist()].clone()
m1.bias.data = m0.bias.data[idx1.tolist()].clone()
m1.running_mean = m0.running_mean[idx1.tolist()].clone()
m1.running_var = m0.running_var[idx1.tolist()].clone()
layer_id_in_cfg += 1
start_mask = end_mask.clone()
if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC
end_mask = cfg_mask[layer_id_in_cfg]
elif isinstance(m0, nn.Conv2d):
if conv_count == 0:
m1.weight.data = m0.weight.data.clone()
conv_count += 1
continue
if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
# This convers the convolutions in the residual block.
# The convolutions are either after the channel selection layer or after the batch normalization layer.
conv_count += 1
idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
if idx1.size == 1:
idx1 = np.resize(idx1, (1,))
w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
# If the current convolution is not the last convolution in the residual block, then we can change the
# number of output channels. Currently we use `conv_count` to detect whether it is such convolution.
if conv_count % 3 != 1:
w1 = w1[idx1.tolist(), :, :, :].clone()
m1.weight.data = w1.clone()
continue
# We need to consider the case where there are downsampling convolutions.
# For these convolutions, we just copy the weights.
m1.weight.data = m0.weight.data.clone()
elif isinstance(m0, nn.Linear):
idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
m1.weight.data = m0.weight.data[:, idx0].clone()
m1.bias.data = m0.bias.data.clone()
改进:利用BN进行迭代式通道剪枝
剪枝方法封装
def prune(percent):
bn = torch.zeros(total)
index = 0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
size = m.weight.data.shape[0]
bn[index:(index+size)] = m.weight.data.abs().clone()
index += size
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
if isinstance(m, nn.BatchNorm2d):
weight_copy = m.weight.data.abs().clone()
mask = weight_copy.gt(thre).float().cuda()
pruned = pruned + mask.shape[0] - torch.sum(mask)
#m.weight.data.mul_(mask)
#m.bias.data.mul_(mask)
cfg.append(int(torch.sum(mask)))
cfg_mask.append(mask.clone())
print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
format(k, mask.shape[0], int(torch.sum(mask))))
elif isinstance(m, nn.MaxPool2d):
cfg.append('M')
return cfg_mask, cfg
BN参数更新
def updateBN():
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.weight.grad.data.add_(0.0001*torch.sign(m.weight.data)) # L1
正式训练代码
注:这里为了保证已被剪枝的BN参数不再参与前向、反向传播,需要利用mask将其置为0。
def train(cfg_mask):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if torch.cuda.is_available():
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
mask_index = 0
for k, m in enumerate(model.modules()):
if isinstance(m, nn.BatchNorm2d):
mask = cfg_mask[mask_index]
m.weight.data.mul_(mask)
m.bias.data.mul_(mask)
data = m(data)
elif isinstance(m, nn.Conv2d):
data = m(data)
elif isinstance(m, nn.MaxPool2d) or isinstance(m, nn.AvgPool2d):
data = m(data)
elif isinstance(m, nn.Linear):
data = data.view(data.size(0), -1)
data = m(data)
elif isinstance(m, nn.ReLU):
data = m(data)
#output = model(data)
loss = F.cross_entropy(output, target)
pred = output.data.max(1, keepdim=True)[1]
loss.backward()
updateBN()
optimizer.step()
for percent in range(10, 60):
cfg_mask, cfg = prune(percent/100.0)
train(cfg_mask)
剪枝后新模型生成
newmodel = vgg(cfg=cfg)
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), newmodel.modules()):
if isinstance(m0, nn.BatchNorm2d):
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
if idx1.size == 1:
idx1 = np.resize(idx1,(1,))
m1.weight.data = m0.weight.data[idx1.tolist()].clone()
m1.bias.data = m0.bias.data[idx1.tolist()].clone()
m1.running_mean = m0.running_mean[idx1.tolist()].clone()
m1.running_var = m0.running_var[idx1.tolist()].clone()
layer_id_in_cfg += 1
start_mask = end_mask.clone()
if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC
end_mask = cfg_mask[layer_id_in_cfg]
elif isinstance(m0, nn.Conv2d):
idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
if idx1.size == 1:
idx1 = np.resize(idx1, (1,))
w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
w1 = w1[idx1.tolist(), :, :, :].clone()
m1.weight.data = w1.clone()
elif isinstance(m0, nn.Linear):
idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
if idx0.size == 1:
idx0 = np.resize(idx0, (1,))
m1.weight.data = m0.weight.data[:, idx0].clone()
m1.bias.data = m0.bias.data.clone()
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join('.', 'pruned.pth.tar'))
- 点赞
- 收藏
- 关注作者
评论(0)