ArcLoss
【摘要】 https://github.com/yanchangqin/Centrloss-ArcLoss/blob/master/ArcLoss.py
class Arc_Loss(nn.Module): def __init__(self,feature_num,cls_num,m =0.1,s=64): super().__init__() self.s = s ...
https://github.com/yanchangqin/Centrloss-ArcLoss/blob/master/ArcLoss.py
-
class Arc_Loss(nn.Module):
-
def __init__(self,feature_num,cls_num,m =0.1,s=64):
-
super().__init__()
-
self.s = s
-
self.m = m
-
self.W=nn.Parameter(torch.randn(feature_num,cls_num))
-
-
def forward(self, feature):
-
_w = F.normalize(self.W,dim=0)
-
_x = F.normalize(feature,dim=1)
-
cosa = (torch.matmul(_x,_w)/10)
-
a = torch.acos(cosa)
-
-
top = torch.exp(torch.cos(a+self.m)*self.s)
-
_top = torch.exp(torch.cos(a)*self.s)
-
bottom = torch.sum(_top,dim=1,keepdim=True)
-
-
# sina = torch.sqrt(1-torch.pow(cosa,2))
-
# cosm = torch.cos(torch.tensor(self.m)).cuda()
-
# sinm = torch.cos(torch.tensor(self.m)).cuda()
-
# cosa_m =cosa*cosm-sina*sinm
-
# top =torch.exp(cosa_m*sel
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/103811042
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)