pytorch focalloss
【摘要】
import torch gamma = torch.ones_like(focal_weight).cuda() gamma[focal_weight > 0.5] = 0.4 gamma[focal_weight < 0.5] = 2.2 focal_weight = alpha_factor * torch.pow(focal_wei...
-
import torch
-
-
gamma = torch.ones_like(focal_weight).cuda()
-
gamma[focal_weight > 0.5] = 0.4
-
gamma[focal_weight < 0.5] = 2.2
-
-
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
FocalLoss_cls 可以替代交叉熵 nn.CrossEntropyLoss()
-
import torch
-
import torch.nn as nn
-
import torch.nn.functional as torch_F
-
-
-
class FocalLoss(nn.Module):
-
-
def __init__(self, gamma=0):
-
super(FocalLoss, self).__init__()
-
self.gamma = gamma
-
# self.ce = torch.nn.CrossEntropyLoss(reduction='sum')
-
self.ce = torch.nn.BCELoss(reduction='sum')
-
-
-
def forward(self, input, target):
-
logp = self.ce(input, target)
-
p = torch.exp(-logp)
-
loss = (1 - p) ** self.gamma * logp
-
re
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/105578908
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)