pytorch CenterLoss
【摘要】
原文:https://github.com/jxgu1016/MNIST_center_loss_pytorch
c++不知道什么框架的:
https://github.com/BOBrown/SSD-Centerloss
# coding: utf8import torchfrom torch.autograd import Varia...
原文:https://github.com/jxgu1016/MNIST_center_loss_pytorch
c++不知道什么框架的:
https://github.com/BOBrown/SSD-Centerloss
-
# coding: utf8
-
import torch
-
from torch.autograd import Variable
-
-
-
class CenterLoss(torch.nn.Module):
-
def __init__(self, num_classes, feat_dim, loss_weight=1.0):
-
super(CenterLoss, self).__init__()
-
self.num_classes = num_classes
-
self.feat_dim = feat_dim
-
self.loss_weight = loss_weight
-
self.centers = torch.nn.Parameter(torch.randn(num_classes, feat_dim))
-
self.use_cuda = False
-
-
def forward(self, y, feat):
-
if self.use_cuda:
-
hist = Variable(
-
torch.histc(y.cpu().data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1).cuda()
-
else:
-
hist = Variable(torch.histc(y.data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1)
-
-
centers_count = hist.index_select(0, y.long()) # 计算每个类别对应的数目
-
-
batch_size = feat.size()[0]
-
feat = feat.view(batch_size, 1, 1, -1).squeeze()
-
if feat.size()[1] != self.feat_dim:
-
raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim,
-
feat.size()[1]))
-
centers_pred = self.centers.index_select(0, y.long())
-
diff = feat-centers_pred
-
loss = self.loss_weight * 1/2.0 * (diff.pow(2).sum(1) / centers_count).sum()
-
return loss
-
-
def cuda(self, device_id=None):
-
self.use_cuda = True
-
return self._apply(lambda t: t.cuda(device_id))
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/90480434
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)