Common Loss Functions for Semantic Segmentation (PyTorch)
[Abstract] Common loss functions in semantic segmentation, with PyTorch implementations.
CrossEntropyLoss
Cross entropy, applied independently at every pixel; this is the default classification loss for segmentation.
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
loss = criterion(input, target)
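For reference, a minimal runnable sketch (the shapes and dummy tensors are illustrative assumptions): for segmentation, nn.CrossEntropyLoss takes raw logits of shape (N, C, H, W) and integer class labels of shape (N, H, W), with no softmax applied beforehand.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(2, 3, 4, 4)         # batch of 2, 3 classes, 4x4 "images"
labels = torch.randint(0, 3, (2, 4, 4))  # one class index per pixel
loss = criterion(logits, labels)
print(loss.item())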
Dice loss
Dice loss was proposed for the case where the foreground occupies too small a proportion of the image. The Dice coefficient originates from binary classification and essentially measures the overlap between two samples. The formula is as follows:

$$\mathrm{Dice} = \frac{2\,|X \cap Y|}{|X| + |Y|}, \qquad \mathrm{DiceLoss} = 1 - \mathrm{Dice}$$
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, input, target):
        # input: predicted probabilities, target: binary mask, both of shape (N, ...)
        N = target.size(0)
        smooth = 1
        input_flat = input.view(N, -1)
        target_flat = target.view(N, -1)
        intersection = input_flat * target_flat
        # Per-sample Dice coefficient; the smooth term avoids division by zero
        dice = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
        loss = 1 - dice.sum() / N
        return loss

class MulticlassDiceLoss(nn.Module):
    # Expects input and target both of shape (N, C, H, W), with target one-hot encoded
    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        C = target.shape[1]
        dice = DiceLoss()
        totalLoss = 0
        # Sum the (optionally weighted) binary Dice loss over every class channel
        for i in range(C):
            diceLoss = dice(input[:, i], target[:, i])
            if weights is not None:
                diceLoss *= weights[i]
            totalLoss += diceLoss
        return totalLoss
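A hedged usage sketch (the shapes and the softmax/one-hot preprocessing are assumptions): MulticlassDiceLoss above wants probabilities and a one-hot target, both of shape N x C x H x W.

import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 4, 4
logits = torch.randn(N, C, H, W)
labels = torch.randint(0, C, (N, H, W))

probs = F.softmax(logits, dim=1)                                               # (N, C, H, W)
target_onehot = F.one_hot(labels, C).permute(0, 3, 1, 2).contiguous().float()  # (N, C, H, W)

criterion = MulticlassDiceLoss()
loss = criterion(probs, target_onehot)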
IoU loss
Like Dice loss, the soft IoU loss makes the (normally non-differentiable) IoU metric trainable by replacing the hard intersection and union counts with sums over predicted class probabilities.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftIoULoss(nn.Module):
    def __init__(self, n_classes):
        super(SoftIoULoss, self).__init__()
        self.n_classes = n_classes

    @staticmethod
    def to_one_hot(tensor, n_classes):
        n, h, w = tensor.size()
        # Scatter the integer label map into a one-hot (N, C, H, W) tensor
        one_hot = torch.zeros(n, n_classes, h, w, device=tensor.device).scatter_(1, tensor.view(n, 1, h, w), 1)
        return one_hot

    def forward(self, input, target):
        # input (logits): N x C x H x W
        # target (class indices): N x H x W
        N = len(input)
        pred = F.softmax(input, dim=1)
        target_onehot = self.to_one_hot(target, self.n_classes)

        # Numerator: soft intersection, summed over all pixels, N x C x H x W => N x C
        inter = pred * target_onehot
        inter = inter.view(N, self.n_classes, -1).sum(2)

        # Denominator: soft union, summed over all pixels, N x C x H x W => N x C
        union = pred + target_onehot - (pred * target_onehot)
        union = union.view(N, self.n_classes, -1).sum(2)

        loss = inter / (union + 1e-16)
        # Negate the mean IoU over classes and batch so that minimizing maximizes IoU
        return -loss.mean()
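A minimal usage sketch (the shapes are assumptions; note that target must be an integer label map so that to_one_hot can scatter it):

import torch

criterion = SoftIoULoss(n_classes=3)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)  # raw network outputs
labels = torch.randint(0, 3, (2, 4, 4))               # integer labels, N x H x W
loss = criterion(logits, labels)                      # lower (more negative) = higher IoU
loss.backward()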
OhemCrossEntropy2d
The core idea of the OHEM (online hard example mining) algorithm is:
Filter the input samples by their loss to select the hard examples, i.e., the samples with the greatest influence on classification and detection, and then use only these selected samples in the stochastic gradient descent update.
import torch
import torch.nn as nn
import torch.nn.functional as F

class OhemCrossEntropy2d(nn.Module):
    def __init__(self, ignore_index=-1, thresh=0.7, min_kept=100000, use_weight=True, **kwargs):
        super(OhemCrossEntropy2d, self).__init__()
        self.ignore_index = ignore_index
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)
        if use_weight:
            # Per-class weights (these particular values are commonly used for 19-class Cityscapes training)
            weight = torch.FloatTensor(
                [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
                 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
                 1.0865, 1.1529, 1.0507])
            self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
        else:
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, pred, target):
        n, c, h, w = pred.size()
        target = target.view(-1)
        valid_mask = target.ne(self.ignore_index)
        target = target * valid_mask.long()
        num_valid = valid_mask.sum()

        prob = F.softmax(pred, dim=1)
        prob = prob.transpose(0, 1).reshape(c, -1)

        if self.min_kept > num_valid:
            print("Labels: {}".format(num_valid))
        elif num_valid > 0:
            # Probability each pixel assigns to its ground-truth class
            prob = prob.masked_fill_(~valid_mask, 1)
            mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)]
            threshold = self.thresh
            if self.min_kept > 0:
                # Keep at least min_kept pixels: raise the threshold if too few fall below it
                index = mask_prob.argsort()
                threshold_index = index[min(len(index), self.min_kept) - 1]
                if mask_prob[threshold_index] > self.thresh:
                    threshold = mask_prob[threshold_index]
            # Hard examples are the pixels whose ground-truth probability is below the threshold
            kept_mask = mask_prob.le(threshold)
            valid_mask = valid_mask * kept_mask
            target = target * kept_mask.long()

        # Pixels not kept are set to ignore_index so the criterion skips them
        target = target.masked_fill_(~valid_mask, self.ignore_index)
        target = target.view(n, h, w)
        return self.criterion(pred, target)
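A hedged usage sketch (the shapes, ignore_index=255, and use_weight=False are assumptions; the built-in weight vector only fits a 19-class setting such as Cityscapes):

import torch

criterion = OhemCrossEntropy2d(ignore_index=255, thresh=0.7, min_kept=10, use_weight=False)
pred = torch.randn(2, 5, 8, 8, requires_grad=True)  # logits: N x C x H x W
target = torch.randint(0, 5, (2, 8, 8))             # labels: N x H x W
loss = criterion(pred, target)
loss.backward()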