论文阅读 经典剪枝方法《Learning both Weights and Connections for Networks》
阅读大纲
动机
-
通过在训练过程中改变模型结构, 可以减少模型参数量, 降低模型计算量, 加快模型推理速度
-
模型通过训练学习到的权重不仅可以用来推理出最后结果, 也可以当作判断神经元重要性的指标。
-
同时训练权重和要剪的神经元可以降低损失。
现有方法
- SVD和量化。
- 用池化层代替全连接层。
- 其他剪枝方法。 如基于海森矩阵的方法。
本文方法
- 在数据集上对未压缩的模型进行完整的训练, 直至收敛。
- 在经过完整训练的未压缩模型基础上, 剪掉一批权重小于阈值的神经元。
- 在数据集上对剪过后的模型进行权重微调, 使其恢复精度。
研究成果
在Imagenet数据集上, 本文的方法在将AlexNet压缩了9倍参数的情况下, 精度基本保持不变。
本文的研究意义
• 将剪枝引入模型训练过程, 降低了剪枝导致的精度损失
• 将神经元的权重值作为判断神经元重要性的依据, 简化了神经元重要性计算的过程。
• 在实验中讨论了各种正则化对剪枝过程的影响。
论文泛读
Abstract
说明神经网络计算量大以及占内存的问题。 为了解决该问题, 提出了本文的剪枝方法, 并在imagenet上进行了验证。
核心内容
- 要解决的问题:神经网络计算量太大,占用内存太多。
- 解决问题的办法:同时学习权重和连接。
- 证明办法能解决问题:在ImageNet数据集上对AlexNet进行了9倍压缩且基本不损失精度。
Introduction
用更多的数据和例子说明神经网络计算量大以及占内存的问题。 提出本文的方法。
Related Work
模型压缩的相关工作
Learning Connections in Addition to Weights
具体介绍本文的剪枝方法以及用到的Tricks
Experiments
在MNIST,ImageNet数据集上进行实验验证
Conclusion
论文总结
模型参数量和计算量
什么是参数量
参数量就是指, 模型所有带参数的层的权重参数总量。 视觉类网络组件中带参数的层, 主要有: 卷积层
、 BN层
、 全连接层
等。 ( 注意: 激活函数层(relu等)和Maxpooling层、 Upsample层是没有参数的不需要学习, 他们只是提供了一种非线性的变换)
卷积层参数量:
BN层参数量:
全连接层参数量:
什么是FLOPS、FLOPs
FLOPS
(即“每秒浮点运算次数”,“每秒峰值速度”)。是“每秒所执行的浮点运算次数”(floating-point operations per second)的缩写。 它常被用来估算电脑的执行效能,尤其是在使用到大量浮点运算的科学计算领域中。
FLOPs
:floating point of operations 的缩写,即浮点运算次数,用来衡量算法、模型复杂度
卷积层FLOPs:
全连接层FLOPs:
什么是MAC
MAC
(Memory Access Cost,内存访问成本),计算机在进行计算时候要加载到缓存中,然后再计算,这个加载过程是需要时间的。其中,分组卷积(group convolution)是对MAC消耗比较多的操作。
卷积层(输入+输出+权重参数量):
全连接层(输入+输出+权重参数量):
在评估模型性能的时候要综合考虑FLOPs和MAC
神经网络剪枝(论文精读)
剪枝从何而来?
人工神经网络中的剪枝受启发于人脑中的突触修剪( Synaptic Pruning) 。 突触修剪即轴突和树突完全衰退和死亡,是许多哺乳动物幼年期和青春期间发生的突触消失过程。 突触修剪从人出生时就开始了, 一直持续到 20 多岁。
神经网络中的剪枝
神经网络通常如图左所示:下层中的每个神经元与上一层有连接,但这意味着我们必须进行大量浮点相乘操作。完美情况下,我们只需将每个神经元与几个其他神经元连接起来,不用进行其他浮点相乘操作,这叫做稀疏
网络。
非结构化剪枝(实用价值不高)
非结构化剪枝是把每个单一的权重设为0(并没有减少参数量,要搭配CSC/CSR等类似方法)。因此若矩阵含有几乎不重要的行跟列,非结构化剪枝并没有办法完全把它移除掉(矩阵都是几百维的,不太可能刚好都被剪为0),只能把它大部分的权重设为0。
结构化剪枝(实用性更好)
结构化剪枝是把整行整列的权重移除掉(即把一个神经元去掉)。若在一个重要的行跟列里有些许不重要的权重,结构化剪枝没办法把它设为0。
如何判断神经元的重要性
- L1 Norm
- L2 Norm
三步剪枝方法
- 在数据集上从头训练一个网络,获得连接权重。
- 根据规则剪掉一定数量的神经元。
- 将剪完后的神经网络继续在数据集上进行微调,使其恢复正确率。
- 为何要迭代多次?
L1和L2正则化
- 为什么要在训练过程中加入正则化?
设目标方程为:
对c2权重剪枝后:y^' - y = \deltac_2x_2
加入正则化后,使得 本就接近于0,则可大大加速剪枝训练速度。
- L1和L2正则化的区别
L1 正则化特征
下图圆圈为Loss函数的优化区域,正方形为正则化优化区域,要保证优化方向既在圆圈中又在正方形中。
L2 正则化特征
- L1、L2正则化对剪枝结果的影响
- 只剪一次且没有重训练的情况下,L1效果比L2好。
- 只剪一次且有重训练的情况下,L2效果比L1好。
Dropout ratio 自适应
实验结果分析
MINIST数据集上效果对比
由以下数据得出的结论
(1) 靠后的全连接层被保留了更多的参数。
(2) 前面的卷积层比后面 的卷积层更重要
在剪枝后,对应图像中间区域的参数被保留的更多,说明该剪枝方法能保留重要参数
ImageNet数据集在AlexNet上的表现
在ImageNet上,经过9倍压缩的AlexNet错误率 不升反降。 在AlexNet上,同样前面的卷积层比靠后的卷积 层更加重要,除了最后一层全连接层,前面的全 连接层几乎可以被忽略。
下图中,左边是卷积层剪枝敏感性曲线 右边是全连接层剪枝敏感性曲线
结论:在图像任务上,卷积层相较于全连接层非 常非常重要。
ImageNet数据集在Vgg上的表现
在ImageNet上,经过13倍压缩的VGG-16错误率 同样不升反降。 在VGG-16上,总体也符合前面卷积层比靠后的 卷积层重要的结论,除了最后一层全连接层,前 面的全连接层几乎可以被忽略。
整体总结
关键点
- 影响神经网络剪枝后性能的原因是什么。
- 如何更加准确的判断神经网络中神经元的重要性。
创新点
- 将迭代剪枝方法引入神经网络剪枝过程,提升神经网络剪枝后的模型性能。
- 在剪枝训练过程中加入正则化, 帮助提升剪枝速度。
- 在多个网络上完成了剪枝工作并获得了非常好的效果
一些实现细节
非结构化剪枝实现(并未真正减少参数与计算量)
- 构建如下示例网络,重点是对每个卷积层和全连接层加入一个
权重mask
class XXXNet(nn.Module):
def __init__(self, num_classes=xx):
super(XXXNet, self).__init__()
self.conv1 = nn.Conv2d()
self.mask1 = torch.ones_like(self.conv1.weight)
self.conv2 = nn.Conv2d()
self.mask2 = torch.ones_like(self.conv2.weight)
self.fc1 = nn.Linear(400, 120, bias=False)
self.mask3 = torch.ones_like(self.fc1.weight)
self.fc2 = nn.Linear(120, 84, bias=False)
self.mask4 = torch.ones_like(self.fc2.weight)
...
def forward(self, x):
self.conv1.weight.data = torch.mul(self.conv1.weight, self.mask1)
self.conv2.weight.data = torch.mul(self.conv2.weight, self.mask2)
self.fc1.weight.data = torch.mul(self.fc1.weight, self.mask3)
self.fc2.weight.data = torch.mul(self.fc2.weight, self.mask4)
out = self.conv1(x)
out = self.relu(self.bn1(out))
out = self.pool1(out)
out = self.conv2(out)
out = self.relu(self.bn2(out))
out = self.pool2(out)
out = out.reshape(out.size(0), -1)
out = self.relu(self.fc1(out))
out = self.relu(self.fc2(out))
out = self.fc3(out)
return out
- 利用迭代式剪枝方法进行非结构化剪枝
剪枝mask处理方法
def pruner(model, ratio):
weight_list = torch.Tensor()
for m in model.modules():
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
weight_list = torch.cat((weight_list.view(-1), m.weight.view(-1)))
weight_list = torch.abs(weight_list)
threshold = np.percentile(weight_list.detach().numpy(), ratio)
model.mask1 = torch.mul(torch.gt(torch.abs(model.conv1.weight), threshold).float(), model.mask1)
model.mask2 = torch.mul(torch.gt(torch.abs(model.conv2.weight), threshold).float(), model.mask2)
model.mask3 = torch.mul(torch.gt(torch.abs(model.conv3.weight), threshold).float(), model.mask3)
model.mask4 = torch.mul(torch.gt(torch.abs(model.fc1.weight), threshold).float(), model.mask4)
剪枝执行流程(加入l2损失)
def prune(model, train_loader, criterion, optimizer, ratio):
model.train()
for r in range(10, int(ratio*100)+10, 10): # ratio:0.9
pruner(model, r)
for i, (img, label) in enumerate(train_loader):
logits = model(img)
loss = criterion(logits, label)
l2 = 0
for param in model.parameters():
l2 += torch.norm(param, 2)
loss += 0.01*l2 # l = crosentropy + l2
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred = torch.nn.functional.softmax(logits, dim=1).argmax(dim=1)
top1_acc = torch.eq(pred, label).sum().item() / len(img)
return model
def main():
model = XXXNet()
pretrained_model = torch.load(xxx.pt')
model.load_state_dict(pretrained_model)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
test(model,test_loader, criterion)
model = prune(model, train_loader, criterion, optimizer, 0.9)
model = finetune(model, train_loader, criterion, optimizer)
test(model, test_loader, criterion)
结构化剪枝实现
- 利用迭代式剪枝方法进行结构化剪枝(对每个通道求平均作为,权重系数)
def pruner(model, ratio):
importance_list = []
for i, m in enumerate(model.modules()):
if isinstance(m, nn.Linear):
importance_list.extend(torch.abs(m.weight).mean(1).detach().numpy()) #84, 120 ---> 84
elif isinstance(m, nn.Conv2d):
importance_list.extend(torch.abs(m.weight).mean(3).mean(2).mean(1).detach().numpy())
importance_list = np.array(importance_list)
threshold = np.percentile(importance_list, ratio)
mask = torch.gt(torch.from_numpy(np.array(importance_list)), threshold).float()
return mask
剪枝执行流程:
def prune(model, dataloader, criterion, optimizer, ratio):
model.train()
lam = 1e-2
for r in range(10, int(ratio*100), 10):
mask = pruner(model, r)
for i, (img, label) in enumerate(dataloader):
current_index = 0
"""inference"""
# logits = model(img)
"""
由于剪枝后,部分权重可被视为无效,因此为了避免这些权重参与网络的前馈和反馈过程,
我们在网络推理之前让网络各层的权重*对应的mask。本质上还是加mask的前先推导
"""
for m in model.modules():
if isinstance(m, nn.Linear):
img = img.view(img.size(0), -1)
m_mask = mask[current_index:current_index+m.out_features].view(1, -1)
img = m(img) * m_mask
current_index += m.out_features
elif isinstance(m, nn.Conv2d):
m_mask = mask[current_index:current_index + m.out_channels].view(1, -1, 1, 1)
img = m(img) * m_mask
current_index += m.out_channels
elif isinstance(m, nn.ReLU):
img = m(img)
elif isinstance(m, nn.MaxPool2d):
img = m(img)
l2_regularization = 0
loss = criterion(img, label)
for param in model.parameters():
l2_regularization += torch.norm(param, 2)
loss += lam * l2_regularization
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred = nn.functional.softmax(img, dim=1).argmax(dim=1)
top1_acc = torch.eq(pred, label).sum().item() / len(img)
return mask, model
- 利用mask重新构建新的网络
out_list = []
current_index = 0
# mask:上步中的返回值
for m, name in model.state_dict().items():
if 'conv' in name:
out_channels = m.out_channels
current_mask = mask[current_index:current_index+out_channels]
out_list.append(sum(current_mask))
current_index += out_channels
elif 'fc' in name:
out_features = m.out_features
current_mask = mask[current_index:current_index+out_channels]
out_list.append(sum(current_mask))
current_index += out_features
class pruned_XXXNet(nn.Module):
def __init__(self, out_list, num_classes=10):
super(pruned_XXXNet, self).__init__()
self.conv1 = nn.Conv2d(1, out_list[0], kernel_size=5, stride=1, padding=2, bias=False)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(out_list[0], out_list[1], kernel_size=5, bias=False)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
# avgpool 5,5 --> 1,1
self.fc1 = nn.Linear(int(out_list[1]*25), out_list[2], bias=False)
self.fc2 = nn.Linear(out_list[2], out_list[3], bias=False)
self.fc3 = nn.Sequential(nn.Linear(out_list[3], num_classes, bias=False))
self.relu = nn.ReLU()
def forward(self, x):
out = self.conv1(x)
out = self.relu(out)
out = self.pool1(out)
out = self.conv2(out)
out = self.relu(out)
out = self.pool2(out)
out = out.view(out.size(0), -1)
out = self.fc1(out)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
return out
- 点赞
- 收藏
- 关注作者
评论(0)