Caffe Softmax 层的实现原理【细节补充】
本文是看了知乎的这篇文章以后觉得作者写的很好,但是有些细节讲解得不够详细,回复里面大家也多有疑问,特加以补充:
为了对原作者表示尊重和感谢,先注明原作出处:
作者:John Wang
链接:https://www.zhihu.com/question/28927103/answer/78810153
作者原文和我的补充
====================================
设 z 是 softmax loss 层的输入,f(z)是 softmax 的输出,即
y 是输入样本 z 对应的类别,y=0,1,...,N
对于 z ,其损失函数定义为
展开上式:
对上式求导,有
梯度下降方向即为
====================================
增加关于 softmax 层的反向传播说明
设 softmax 的输出为 a ,输入为 z ,损失函数为 l
则
其中
在 caffe 中是 top_diff,a 为 caffe 中得 top_data,需要计算的是
if i!=k
if i==k
【我的补充】
----------------------------------------------------------------
当 i!=k 时,
当 i==k 时,
----------------------------------------------------------------
于是
【我的补充】
----------------------------------------------------------------
把负号提出去,改为点乘,即得到上式。注意,这里的 n 表示 channels,这里的 k 和 caffe 源码中的 k 含义不同。
----------------------------------------------------------------
整理一下得到
其中表示将标量扩展为 n 维向量,表示向量按元素相乘
【我的补充】
----------------------------------------------------------------
这边作者讲解得有误,因为对照代码可以发现,点乘后其实得到的是 1*inner_num 大小的向量,所以为了对应通道相减,需要将其扩展为 channels*inner_num 的矩阵,而不是 n 维向量。
最后矩阵再按元素进行相乘。
对照 caffe 源码
-
-
// top_diff : l 对 a 向量求偏导
-
// top_data :a 向量
-
// 将 top_diff 拷贝到 bottom_diff
-
// dim = channels * inner_num_
-
// inner_num_ = height * width
-
caffe_copy(top[0]->count(), top_diff, bottom_diff);
-
// 遍历一个 batch 中的样本
-
for (int i = 0; i < outer_num_; ++i) {
-
// compute dot(top_diff, top_data) and subtract them from the bottom diff
-
// 此处计算两个向量的点积,注意 top_diff 已经拷贝到 bottom_diff 当中
-
// 步长为 inner_num_(跨通道)构造一个长度为 channels (类别个数)的向量,进行点乘
-
for (int k = 0; k < inner_num_; ++k) {
-
scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
-
bottom_diff + i * dim + k, inner_num_,
-
top_data + i * dim + k, inner_num_);
-
}
-
// subtraction
-
// 此处计算大括号内的减法(即负号)
-
// 将 scale_data 扩展为 channels 个通道(多少个类别),再和 bottom_diff 对应的通道相减
-
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_, 1,
-
-1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
-
}
-
// elementwise multiplication
-
// 元素级的乘法
-
// 此处计算大括号外和 a 向量的乘法
-
caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff)
文章来源: panda1234lee.blog.csdn.net,作者:panda1234lee,版权归原作者所有,如需转载,请联系作者。
原文链接:panda1234lee.blog.csdn.net/article/details/82459595
- 点赞
- 收藏
- 关注作者
评论(0)