模型蒸馏简介
模型蒸馏是一种模型压缩和加速的技术,它的目的是将一个大型或复杂的模型(称为教师模型)的知识迁移到一个小型或简单的模型(称为学生模型)上,使学生模型能够达到与教师模型相近的性能,但占用更少的计算资源和内存空间。模型蒸馏的基本思想是让学生模型学习教师模型的输出概率分布,而不仅仅是数据标签。这样可以利用教师模型输出的暗知识(dark knowledge),即低概率类别之间的关系,来提高学生模型的泛化能力。
模型蒸馏的概念最早由Hinton等人在2015年提出,并在手写数字识别和语音识别等任务上验证了其有效性。许多研究者对模型蒸馏进行了扩展和改进,例如利用教师模型的中间层特征、注意力机制、对比损失等来增强知识迁移的效果。模型蒸馏也被应用于自然语言处理、计算机视觉和语音识别等领域,例如BERT、Faster R-CNN、ResNet等。
我们可以从两个方面去了解模型蒸馏:
- 对网络的什么地方的特征进行蒸馏?
- 对选择的特征蒸馏过程中,选择什么损失函数进行训练?
对网络的什么地方的特征进行蒸馏
一般来说,我们可以对网络的以下两个地方的特征进行蒸馏:
- 输出层:这是最常见和最简单的蒸馏方式,即让学生模型匹配教师模型的输出层特征,也就是logits或者softmax后的概率分布。这种方式可以捕捉教师模型对不同类别之间相似度和区分度的判断,从而提升学生模型的分类能力。
- 中间层:这是一种更深入和细致的蒸馏方式,即让学生模型匹配教师模型的中间层特征,也就是某些隐藏层或者注意力层的输出。这种方式可以捕捉教师模型对输入数据的抽象和表示能力,从而提升学生模型的表达能力。
模型蒸馏常用的损失函数有哪几种
KL散度损失:这是最基本的模型蒸馏损失函数,目的是让学生模型的输出概率分布尽可能接近教师模型的输出概率分布,从而学习教师模型的暗知识。具体地,假设教师模型的输出概率分布为 ,学生模型的输出概率分布为 ,则KL散度损失为
其中 表示类别索引。为了增加概率分布的平滑性和多样性,通常会在softmax层之前加入一个温度参数 ,使得输出概率分布为
其中 和 表示教师模型和学生模型的logits,即softmax层之前的输出。温度参数 可以控制概率分布的熵,当 时,相当于原始的softmax输出,当 时,概率分布变得更加平坦和均匀,当 时,概率分布变得更加尖锐和集中。
均方误差损失:这是一种直接匹配教师模型和学生模型的logits的损失函数,而不经过softmax层。它的优点是可以保留教师模型logits中的细微差异,而不受温度参数的影响。具体地,假设教师模型的logits为 ,学生模型的logits为 ,则均方误差损失为
其中 表示类别数, 表示类别索引。
对比损失:这是一种利用对比学习的思想来进行知识蒸馏的损失函数。它的目的是让教师模型和学生模型在特征空间中有相似的表达能力,即对同一类别的样本有较高的相似度,对不同类别的样本有较低的相似度。具体地,假设教师模型和学生模型在某一层输出的特征向量分别为 和 ,其中 表示输入样本,则对比损失为
其中 表示样本数, 表示向量内积,
Reference
- 点赞
- 收藏
- 关注作者
评论(0)