模型蒸馏简介

举报
maxloop 发表于 2023/04/27 20:11:19 2023/04/27
【摘要】 模型蒸馏是一种模型压缩和加速的技术,它的目的是将一个大型或复杂的模型(称为教师模型)的知识迁移到一个小型或简单的模型(称为学生模型)上,使学生模型能够达到与教师模型相近的性能,但占用更少的计算资源和内存空间。模型蒸馏的基本思想是让学生模型学习教师模型的输出概率分布,而不仅仅是数据标签。这样可以利用教师模型输出的暗知识(dark knowledge),即低概率类别之间的关系,来提高学生模型的泛...

模型蒸馏是一种模型压缩和加速的技术,它的目的是将一个大型或复杂的模型(称为教师模型)的知识迁移到一个小型或简单的模型(称为学生模型)上,使学生模型能够达到与教师模型相近的性能,但占用更少的计算资源和内存空间。模型蒸馏的基本思想是让学生模型学习教师模型的输出概率分布,而不仅仅是数据标签。这样可以利用教师模型输出的暗知识(dark knowledge),即低概率类别之间的关系,来提高学生模型的泛化能力。

模型蒸馏的概念最早由Hinton等人在2015年提出,并在手写数字识别和语音识别等任务上验证了其有效性。许多研究者对模型蒸馏进行了扩展和改进,例如利用教师模型的中间层特征、注意力机制、对比损失等来增强知识迁移的效果。模型蒸馏也被应用于自然语言处理、计算机视觉和语音识别等领域,例如BERT、Faster R-CNN、ResNet等。

我们可以从两个方面去了解模型蒸馏:

  • 对网络的什么地方的特征进行蒸馏?
  • 对选择的特征蒸馏过程中,选择什么损失函数进行训练?

对网络的什么地方的特征进行蒸馏

一般来说,我们可以对网络的以下两个地方的特征进行蒸馏:

  • 输出层:这是最常见和最简单的蒸馏方式,即让学生模型匹配教师模型的输出层特征,也就是logits或者softmax后的概率分布。这种方式可以捕捉教师模型对不同类别之间相似度和区分度的判断,从而提升学生模型的分类能力。
  • 中间层:这是一种更深入和细致的蒸馏方式,即让学生模型匹配教师模型的中间层特征,也就是某些隐藏层或者注意力层的输出。这种方式可以捕捉教师模型对输入数据的抽象和表示能力,从而提升学生模型的表达能力。

模型蒸馏常用的损失函数有哪几种

KL散度损失:这是最基本的模型蒸馏损失函数,目的是让学生模型的输出概率分布尽可能接近教师模型的输出概率分布,从而学习教师模型的暗知识。具体地,假设教师模型的输出概率分布为 y t y_t ,学生模型的输出概率分布为 y s y_s ,则KL散度损失为

L soft = i y t , i log y t , i y s , i L_{\text{soft}} = \sum_i y_{t,i} \log \frac{y_{t,i}}{y_{s,i}}

其中 i i 表示类别索引。为了增加概率分布的平滑性和多样性,通常会在softmax层之前加入一个温度参数 T T ,使得输出概率分布为

y t , i = exp ( z t , i / T ) j exp ( z t , j / T ) , y s , i = exp ( z s , i / T ) j exp ( z s , j / T ) y_{t,i} = \frac{\exp(z_{t,i}/T)}{\sum_j \exp(z_{t,j}/T)}, \quad y_{s,i} = \frac{\exp(z_{s,i}/T)}{\sum_j \exp(z_{s,j}/T)}

其中 z t z_t z s z_s 表示教师模型和学生模型的logits,即softmax层之前的输出。温度参数 T T 可以控制概率分布的熵,当 T = 1 T=1 时,相当于原始的softmax输出,当 T > 1 T>1 时,概率分布变得更加平坦和均匀,当 T < 1 T<1 时,概率分布变得更加尖锐和集中。

均方误差损失:这是一种直接匹配教师模型和学生模型的logits的损失函数,而不经过softmax层。它的优点是可以保留教师模型logits中的细微差异,而不受温度参数的影响。具体地,假设教师模型的logits为 z t z_t ,学生模型的logits为 z s z_s ,则均方误差损失为

L mse = 1 n i ( z t , i z s , i ) 2 L_{\text{mse}} = \frac{1}{n} \sum_i (z_{t,i} - z_{s,i})^2

其中 n n 表示类别数, i i 表示类别索引。

对比损失:这是一种利用对比学习的思想来进行知识蒸馏的损失函数。它的目的是让教师模型和学生模型在特征空间中有相似的表达能力,即对同一类别的样本有较高的相似度,对不同类别的样本有较低的相似度。具体地,假设教师模型和学生模型在某一层输出的特征向量分别为 f t ( x ) f_t(x) f s ( x ) f_s(x) ,其中 x x 表示输入样本,则对比损失为

L crd = 1 n i log exp ( f t ( x i ) , f s ( x i ) / τ ) j exp ( f t ( x i ) , f s ( x j ) / τ ) L_{\text{crd}} = -\frac{1}{n} \sum_i \log \frac{\exp(\langle f_t(x_i), f_s(x_i) \rangle / \tau)}{\sum_j \exp(\langle f_t(x_i), f_s(x_j) \rangle / \tau)}

其中 n n 表示样本数, , \langle \cdot, \cdot \rangle 表示向量内积,

Reference

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。