pytorch CenterLoss

举报
风吹稻花香 发表于 2021/06/05 22:13:02 2021/06/05
【摘要】   原文:https://github.com/jxgu1016/MNIST_center_loss_pytorch   c++不知道什么框架的: https://github.com/BOBrown/SSD-Centerloss # coding: utf8import torchfrom torch.autograd import Varia...

 

原文:https://github.com/jxgu1016/MNIST_center_loss_pytorch

 

c++不知道什么框架的:

https://github.com/BOBrown/SSD-Centerloss


  
  1. # coding: utf8
  2. import torch
  3. from torch.autograd import Variable
  4. class CenterLoss(torch.nn.Module):
  5. def __init__(self, num_classes, feat_dim, loss_weight=1.0):
  6. super(CenterLoss, self).__init__()
  7. self.num_classes = num_classes
  8. self.feat_dim = feat_dim
  9. self.loss_weight = loss_weight
  10. self.centers = torch.nn.Parameter(torch.randn(num_classes, feat_dim))
  11. self.use_cuda = False
  12. def forward(self, y, feat):
  13. if self.use_cuda:
  14. hist = Variable(
  15. torch.histc(y.cpu().data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1).cuda()
  16. else:
  17. hist = Variable(torch.histc(y.data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1)
  18. centers_count = hist.index_select(0, y.long()) # 计算每个类别对应的数目
  19. batch_size = feat.size()[0]
  20. feat = feat.view(batch_size, 1, 1, -1).squeeze()
  21. if feat.size()[1] != self.feat_dim:
  22. raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim,
  23. feat.size()[1]))
  24. centers_pred = self.centers.index_select(0, y.long())
  25. diff = feat-centers_pred
  26. loss = self.loss_weight * 1/2.0 * (diff.pow(2).sum(1) / centers_count).sum()
  27. return loss
  28. def cuda(self, device_id=None):
  29. self.use_cuda = True
  30. return self._apply(lambda t: t.cuda(device_id))

 

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/90480434

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。