论文阅读 经典剪枝方法《Learning both Weights and Connections for Networks》

举报
lutianfei 发表于 2021/06/21 23:43:38 2021/06/21
【摘要】 动机:通过在训练过程中改变模型结构, 可以减少模型参数量, 降低模型计算量, 加快模型推理速度模型通过训练学习到的权重不仅可以用来推理出最后结果, 也可以当作判断神经元重要性的指标。同时训练权重和要剪的神经元可以降低损失。现有方法SVD和量化。用池化层代替全连接层。其他剪枝方法。 如基于海森矩阵的方法。本文方法在数据集上对未压缩的模型进行完整的训练, 直至收敛。在经过完整训练的未压缩模型基础...

阅读大纲

image.png

动机

  1. 通过在训练过程中改变模型结构, 可以减少模型参数量, 降低模型计算量, 加快模型推理速度
    image.png

  2. 模型通过训练学习到的权重不仅可以用来推理出最后结果, 也可以当作判断神经元重要性的指标。
    image.png

  3. 同时训练权重和要剪的神经元可以降低损失。
    image.png

现有方法

  1. SVD和量化。
  2. 用池化层代替全连接层。
  3. 其他剪枝方法。 如基于海森矩阵的方法。

image.png

本文方法

  1. 在数据集上对未压缩的模型进行完整的训练, 直至收敛。
  2. 在经过完整训练的未压缩模型基础上, 剪掉一批权重小于阈值的神经元。
  3. 在数据集上对剪过后的模型进行权重微调, 使其恢复精度。

研究成果
在Imagenet数据集上, 本文的方法在将AlexNet压缩了9倍参数的情况下, 精度基本保持不变。
image.png

本文的研究意义

• 将剪枝引入模型训练过程, 降低了剪枝导致的精度损失
• 将神经元的权重值作为判断神经元重要性的依据, 简化了神经元重要性计算的过程。
• 在实验中讨论了各种正则化对剪枝过程的影响。

论文泛读

Abstract

说明神经网络计算量大以及占内存的问题。 为了解决该问题, 提出了本文的剪枝方法, 并在imagenet上进行了验证。

核心内容

  1. 要解决的问题:神经网络计算量太大,占用内存太多。
  2. 解决问题的办法:同时学习权重和连接。
  3. 证明办法能解决问题:在ImageNet数据集上对AlexNet进行了9倍压缩且基本不损失精度。

Introduction

用更多的数据和例子说明神经网络计算量大以及占内存的问题。 提出本文的方法。

Related Work

模型压缩的相关工作

Learning Connections in Addition to Weights

具体介绍本文的剪枝方法以及用到的Tricks

Experiments

在MNIST,ImageNet数据集上进行实验验证

Conclusion

论文总结

模型参数量和计算量

什么是参数量

参数量就是指, 模型所有带参数的层的权重参数总量。 视觉类网络组件中带参数的层, 主要有: 卷积层BN层全连接层等。 ( 注意: 激活函数层(relu等)和Maxpooling层、 Upsample层是没有参数的不需要学习, 他们只是提供了一种非线性的变换)

卷积层参数量: K K C i C o + C o K*K*Ci*Co+Co

BN层参数量: 2 C i 2*Ci

全连接层参数量: C i C o + C o Ci*Co + Co

什么是FLOPS、FLOPs

FLOPS(即“每秒浮点运算次数”,“每秒峰值速度”)。是“每秒所执行的浮点运算次数”(floating-point operations per second)的缩写。 它常被用来估算电脑的执行效能,尤其是在使用到大量浮点运算的科学计算领域中。

FLOPs:floating point of operations 的缩写,即浮点运算次数,用来衡量算法、模型复杂度

卷积层FLOPs: ( 2 C i K K 1 ) H W C o (2*C_i*K*K-1)*H*W*C_o

全连接层FLOPs: ( 2 C i 1 ) C o (2*C_i -1)*C_o

什么是MAC

MAC(Memory Access Cost,内存访问成本),计算机在进行计算时候要加载到缓存中,然后再计算,这个加载过程是需要时间的。其中,分组卷积(group convolution)是对MAC消耗比较多的操作。

卷积层(输入+输出+权重参数量): H i W i C i + H o W o C o + C o C i K K Hi*Wi*Ci + Ho*Wo*Co + Co*Ci*K*K

全连接层(输入+输出+权重参数量): X i C i + X o C o + C i C o Xi*Ci + Xo*Co + Ci*Co

在评估模型性能的时候要综合考虑FLOPs和MAC

神经网络剪枝(论文精读)

剪枝从何而来?
人工神经网络中的剪枝受启发于人脑中的突触修剪( Synaptic Pruning) 。 突触修剪即轴突和树突完全衰退和死亡,是许多哺乳动物幼年期和青春期间发生的突触消失过程。 突触修剪从人出生时就开始了, 一直持续到 20 多岁。

image.png

神经网络中的剪枝
神经网络通常如图左所示:下层中的每个神经元与上一层有连接,但这意味着我们必须进行大量浮点相乘操作。完美情况下,我们只需将每个神经元与几个其他神经元连接起来,不用进行其他浮点相乘操作,这叫做稀疏网络。

image.png

非结构化剪枝(实用价值不高)

非结构化剪枝是把每个单一的权重设为0(并没有减少参数量,要搭配CSC/CSR等类似方法)。因此若矩阵含有几乎不重要的行跟列,非结构化剪枝并没有办法完全把它移除掉(矩阵都是几百维的,不太可能刚好都被剪为0),只能把它大部分的权重设为0。

结构化剪枝(实用性更好)

结构化剪枝是把整行整列的权重移除掉(即把一个神经元去掉)。若在一个重要的行跟列里有些许不重要的权重,结构化剪枝没办法把它设为0。

如何判断神经元的重要性

  • L1 Norm

x 1 = i = 1 x i ||x||_1 = \sum_{i=1}^{\infty}|x_i|

image.png

  • L2 Norm

x 2 = i = 1 x i 2 ||x||_2 = \sqrt{\sum_{i=1}^{\infty}|x^2_i|}

三步剪枝方法

  1. 在数据集上从头训练一个网络,获得连接权重。
  2. 根据规则剪掉一定数量的神经元。
  3. 将剪完后的神经网络继续在数据集上进行微调,使其恢复正确率。
  • 为何要迭代多次?
    image.png

L1和L2正则化

  • 为什么要在训练过程中加入正则化?
    设目标方程为: y = c 1 x 1 + c 2 x 2 + c 3 x 3 y=c_1x_1 + c_2x_2 + c_3x_3
    对c2权重剪枝后:y^' - y = \deltac_2x_2
    加入正则化后,使得 c 2 c_2 本就接近于0,则可大大加速剪枝训练速度。

image.png

  • L1和L2正则化的区别

L1 正则化特征
image.png
下图圆圈为Loss函数的优化区域,正方形为正则化优化区域,要保证优化方向既在圆圈中又在正方形中。
image.png

L2 正则化特征
image.png

  • L1、L2正则化对剪枝结果的影响
  1. 只剪一次且没有重训练的情况下,L1效果比L2好。
  2. 只剪一次且有重训练的情况下,L2效果比L1好。

image.png

Dropout ratio 自适应

image.png

实验结果分析

MINIST数据集上效果对比

image.png

由以下数据得出的结论
(1) 靠后的全连接层被保留了更多的参数。
(2) 前面的卷积层比后面 的卷积层更重要

image.png

在剪枝后,对应图像中间区域的参数被保留的更多,说明该剪枝方法能保留重要参数

image.png

ImageNet数据集在AlexNet上的表现

在ImageNet上,经过9倍压缩的AlexNet错误率 不升反降。 在AlexNet上,同样前面的卷积层比靠后的卷积 层更加重要,除了最后一层全连接层,前面的全 连接层几乎可以被忽略。
image.png

下图中,左边是卷积层剪枝敏感性曲线 右边是全连接层剪枝敏感性曲线
结论:在图像任务上,卷积层相较于全连接层非 常非常重要。
image.png

ImageNet数据集在Vgg上的表现

image.png

在ImageNet上,经过13倍压缩的VGG-16错误率 同样不升反降。 在VGG-16上,总体也符合前面卷积层比靠后的 卷积层重要的结论,除了最后一层全连接层,前 面的全连接层几乎可以被忽略。

整体总结

关键点

  • 影响神经网络剪枝后性能的原因是什么。
  • 如何更加准确的判断神经网络中神经元的重要性。

创新点

  • 将迭代剪枝方法引入神经网络剪枝过程,提升神经网络剪枝后的模型性能。
  • 在剪枝训练过程中加入正则化, 帮助提升剪枝速度。
  • 在多个网络上完成了剪枝工作并获得了非常好的效果

一些实现细节

非结构化剪枝实现(并未真正减少参数与计算量)

image.png

  1. 构建如下示例网络,重点是对每个卷积层和全连接层加入一个权重mask
class XXXNet(nn.Module):
    def __init__(self, num_classes=xx):
        super(XXXNet, self).__init__()

        self.conv1 = nn.Conv2d()
        self.mask1 = torch.ones_like(self.conv1.weight)

        self.conv2 = nn.Conv2d()
        self.mask2 = torch.ones_like(self.conv2.weight)

        self.fc1 = nn.Linear(400, 120, bias=False)
        self.mask3 = torch.ones_like(self.fc1.weight)

        self.fc2 = nn.Linear(120, 84, bias=False)
        self.mask4 = torch.ones_like(self.fc2.weight)
        ...

    def forward(self, x):
        self.conv1.weight.data = torch.mul(self.conv1.weight, self.mask1)
        self.conv2.weight.data = torch.mul(self.conv2.weight, self.mask2)
        self.fc1.weight.data = torch.mul(self.fc1.weight, self.mask3)
        self.fc2.weight.data = torch.mul(self.fc2.weight, self.mask4)

        out = self.conv1(x)
        out = self.relu(self.bn1(out))
        out = self.pool1(out)

        out = self.conv2(out)
        out = self.relu(self.bn2(out))
        out = self.pool2(out)
        out = out.reshape(out.size(0), -1)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        out = self.fc3(out)
        return out
  1. 利用迭代式剪枝方法进行非结构化剪枝

剪枝mask处理方法

def pruner(model, ratio):
    weight_list = torch.Tensor()
    for m in model.modules():
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            weight_list = torch.cat((weight_list.view(-1), m.weight.view(-1)))
    weight_list = torch.abs(weight_list)
    threshold = np.percentile(weight_list.detach().numpy(), ratio)
    model.mask1 = torch.mul(torch.gt(torch.abs(model.conv1.weight), threshold).float(), model.mask1)
    model.mask2 = torch.mul(torch.gt(torch.abs(model.conv2.weight), threshold).float(), model.mask2)
    model.mask3 = torch.mul(torch.gt(torch.abs(model.conv3.weight), threshold).float(), model.mask3)
    model.mask4 = torch.mul(torch.gt(torch.abs(model.fc1.weight), threshold).float(), model.mask4)

剪枝执行流程(加入l2损失)

def prune(model, train_loader, criterion, optimizer, ratio):
    model.train()
    for r in range(10, int(ratio*100)+10, 10):    # ratio:0.9
        pruner(model, r)
        for i, (img, label) in enumerate(train_loader):
            logits = model(img)
            loss = criterion(logits, label)
            l2 = 0
            for param in model.parameters():
                l2 += torch.norm(param, 2)
            loss += 0.01*l2              # l = crosentropy + l2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pred = torch.nn.functional.softmax(logits, dim=1).argmax(dim=1)
            top1_acc = torch.eq(pred, label).sum().item() / len(img)
    return model

def main():
    model = XXXNet()
    pretrained_model = torch.load(xxx.pt')
    model.load_state_dict(pretrained_model)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    test(model,test_loader, criterion)
    model = prune(model, train_loader, criterion, optimizer, 0.9)
    model = finetune(model, train_loader, criterion, optimizer)
    test(model, test_loader, criterion)

结构化剪枝实现

  1. 利用迭代式剪枝方法进行结构化剪枝(对每个通道求平均作为,权重系数)
def pruner(model, ratio):
    importance_list = []

    for i, m in enumerate(model.modules()):
        if isinstance(m, nn.Linear):
            importance_list.extend(torch.abs(m.weight).mean(1).detach().numpy())    #84, 120 ---> 84
        elif isinstance(m, nn.Conv2d):
            importance_list.extend(torch.abs(m.weight).mean(3).mean(2).mean(1).detach().numpy())
    importance_list = np.array(importance_list)
    threshold = np.percentile(importance_list, ratio)
    mask = torch.gt(torch.from_numpy(np.array(importance_list)), threshold).float()
    return mask

剪枝执行流程:

def prune(model, dataloader, criterion, optimizer, ratio):
    model.train()
    lam = 1e-2
    for r in range(10, int(ratio*100), 10):
        mask = pruner(model, r)
        for i, (img, label) in enumerate(dataloader):
            current_index = 0
            """inference"""
            # logits = model(img)
            
            """
            由于剪枝后,部分权重可被视为无效,因此为了避免这些权重参与网络的前馈和反馈过程,
            我们在网络推理之前让网络各层的权重*对应的mask。本质上还是加mask的前先推导
            """
            for m in model.modules():
                if isinstance(m, nn.Linear):
                    img = img.view(img.size(0), -1)
                    m_mask = mask[current_index:current_index+m.out_features].view(1, -1)
                    img = m(img) * m_mask
                    current_index += m.out_features
                elif isinstance(m, nn.Conv2d):
                    m_mask = mask[current_index:current_index + m.out_channels].view(1, -1, 1, 1)
                    img = m(img) * m_mask
                    current_index += m.out_channels
                elif isinstance(m, nn.ReLU):
                    img = m(img)
                elif isinstance(m, nn.MaxPool2d):
                    img = m(img)
            l2_regularization = 0
            loss = criterion(img, label)
            for param in model.parameters():
                l2_regularization += torch.norm(param, 2)
            loss += lam * l2_regularization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pred = nn.functional.softmax(img, dim=1).argmax(dim=1)
            top1_acc = torch.eq(pred, label).sum().item() / len(img)
    return mask, model
  1. 利用mask重新构建新的网络
out_list = []
current_index = 0
# mask:上步中的返回值
for m, name in model.state_dict().items():
    if 'conv' in name:
        out_channels = m.out_channels
        current_mask = mask[current_index:current_index+out_channels]
        out_list.append(sum(current_mask))
        current_index += out_channels
    elif 'fc' in name:
        out_features = m.out_features
        current_mask = mask[current_index:current_index+out_channels]
        out_list.append(sum(current_mask))
        current_index += out_features
        
class pruned_XXXNet(nn.Module):
    def __init__(self, out_list, num_classes=10):
        super(pruned_XXXNet, self).__init__()

        self.conv1 = nn.Conv2d(1, out_list[0], kernel_size=5, stride=1, padding=2, bias=False)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(out_list[0], out_list[1], kernel_size=5, bias=False)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # avgpool 5,5 --> 1,1
        self.fc1 = nn.Linear(int(out_list[1]*25), out_list[2], bias=False)

        self.fc2 = nn.Linear(out_list[2], out_list[3], bias=False)

        self.fc3 = nn.Sequential(nn.Linear(out_list[3], num_classes, bias=False))
        self.relu = nn.ReLU()

    def forward(self, x):

        out = self.conv1(x)
        out = self.relu(out)
        out = self.pool1(out)

        out = self.conv2(out)
        out = self.relu(out)
        out = self.pool2(out)

        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.relu(out)

        out = self.fc2(out)
        out = self.relu(out)

        out = self.fc3(out)
        return out
【版权声明】本文为华为云社区用户翻译文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容, 举报邮箱:cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。