Knowledge Distillation 知识蒸馏学习总结
Knowledge Distillation 知识蒸馏
研究背景
动机
- 网络训练与应用之间的矛盾:在训练阶段,网络可以通过对同一数据集进行多次学习得到多个不同的模型,并将各模型的预测结果加权作为最终输出这一集成的方式来提高任务性能。然而这一方法在网络实际应用中非常耗时耗力,不利于部署。
2. 已经有工作证明,将集成模型中的知识(knowledge)压缩到方便部署的单个模型中是可行的。可以从大模型所学习到的知识信息(knowledge)中学习有用信息来训练小模型,在保证性能相近的情况下进行模型压缩。
现有方法
- 模型压缩:在已经训练好的模型上进行压缩,使得网络携带更少的网络参数。
- 直接训练一个小型网络:从改变网络结构出发,设计出更高效的网络计算方式,从而使网络参数减少 的同时,不损失网络的性能。
本文方法
- 提出蒸馏(distilling)的思想,将大模型中学习到的知识迁移到单一小模型中,在保证精度的基础上压缩模型。
- 利用大模型生成的类别概率作为softtargets,待压缩模型自身的类别作为hardtargets,两者结合共同训练待压缩模型
知识蒸馏的意义
- 知识蒸馏开创了模型压缩中的一个新方向,是蒸馏领域的开山之作
- 蒸馏思想在NLP,CV等领域均有成功应用,证明了该方法的有效性和普适性,能够将大模型的知识迁移到小模型中,使得小模型拥有大模型的能力
- 作为轻量化方向的一个重要分支,推动了轻量化网络的理论研究和应用落地
摘要核心
- 集成模型在训练和部署 应用间存在矛盾
- 已有工作证明模型压缩的可行性
- 在模型压缩领域提出知识蒸馏的方法,将大的集成模型的知识蒸馏到单一小模型中
- 介绍了一种针对大数据集的集成专家模型
知识蒸馏思想
知识蒸馏的主要思想是训练一个小的网络模型来模仿一个预先训练好的大型网络或者集成的网络。这种训练模式又被称为 “teacher-student”,大型的网络是“教师网络”,小型的网络是“学生网络”。
知识蒸馏期望让学生网络在拥有更少参数量,更小规模的情况下,达到与教师网络相似甚至超越教师网络的精度。
什么是Knowledge
模型的参数信息保留了模型学到的知识。学习如何从输入向量映射到输出向量
[cow, dog, cat, car]
Labels:[0,1,0,0]
Predictions:[0.05,0.3,0.2,0.005]
可以看出,将图像分类为母牛的概率是将其分类为汽车的概率的10倍。 Hinton在其论文中首先描述的正是这种knowledge,需要从教师网络向学生中蒸馏。
知识蒸馏方法
神经网络预测过程:
输入图像送入多层卷积神经网络,提取特征
拉伸卷积层,送入全连接层
多层全连接层得到logits Zi
Logits经过softmax得到预测概率
在知识蒸馏中,教师网络将知识传授给学生网络的方法是:
- 教师网络经过训练输出一个类别概率分布
- 学生网络以教师网络的输出预测为指导,输出一个类别概率分布
- 设计学生网络的损失函数,最小化以上两个概率分布之间的差距
训练一个学生网络模型来模仿一个预先训练好的教师网络模型预测输出概率分布
问题:教师网络softmax层的输出结果,正确的分类的概率值非常大,而其他类别的概率值几乎接近于0。 这种结果会忽视其他类别概率中包含的有用信息, 没有利用到老师强大的泛化性能。因此需要引入温度参数。
温度参数T(Temperature)
为了从教师网络中蒸馏出更多,更丰富的信息,引入温度参数T的概念, T越大,网络输出类别概率分布越“ soft”,学生网络越能从教师网络中学到更丰富的knowledge。即拉小概率分布之间的变化。
T越高,概率分布越平坦。
横坐标:Zi,网络送入softmax层的输入值
纵坐标:经升温蒸馏后的预测概率。
- 当温度T越高的时候,soft targets越平坦,信息不会全部集中在正确类别上,
这样增大温度参数T相当于放大(蒸馏出来)这些非正确类别所携带的信息,这些信息同样非常重要。
- 无论温度T取值如何,Soft targets都有忽略小的zi所携带信息的倾向。
引入温度参数T,将原始softmax层输出的概率预测分布软化,得到soft targets
类别标签作为hard targets
Soft targets可以在保证预测正确的情况下,尽可能从非正确预测中提取有用信息, 供学生网络学习。
知识蒸馏方法
蒸馏过程
Input:输入数据
Teacher model:教师网络模型
Student model:学生网络模型 (待蒸馏模型)
Soft targets:教师网络经升温后 的softmax预测类别输出
soft predictions:学生网络经升 温后softmax预测类别输出
hard prediction:温度系数为1, 即学生网络原始softmax预测输出
hard targets:输入数据真实标签。
具体流程:
- 教师网络训练。首先利用数据训练 一个层数更深,提取能力更强的教 师网络,得到logits后,利用升温T 的softmax得到预测类别概率分布 soft targets
- 蒸馏教师网络知识到学生网络。构造distillation loss和student loss,加权相加作为最终的损失函数。
distillation loss: 将教师网络输出 logits温度为T的蒸馏,经过softmax层 之后得到类别预测概率分布,作为soft targets,同时,学生网络输出logits经 过相同温度T进行蒸馏,经过softmax层 之后得到类别预测概率,作为soft predictions。蒸馏的目的是让学生网络 的类别输出预测分布尽可能拟合教师网 络输出预测分布
student loss: 教师网络也有一定的错误 率,使用真实标签作为hard targets, 可以有效阻止教师网络中的错误信息被 蒸馏到学生网络中。
蒸馏的特殊形式,直接利用logits
前提:
- 所有logits对每个样本都是零均值化即 = = 0
- 温度T趋近于无穷大
- 点赞
- 收藏
- 关注作者
评论(0)