语义分割常用Loss Pytorch版

举报
Stephen1998 发表于 2020/07/16 18:17:01 2020/07/16
【摘要】 语义分割中常见的Loss,Pytorch版本实现。

CrossEntropyLoss

交叉熵

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
loss = criterion(input, target)

Dice loss

Dice loss是针对前景比例太小的问题提出的,dice系数源于二分类,本质上是衡量两个样本的重叠部分。公式如下:

image.png

import torch
import torch.nn as nn
 
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()
 
    def forward(self, input, target):
        N = target.size(0)
        smooth = 1
	input_flat = input.view(N, -1)
	target_flat = target.view(N, -1)
	intersection = input_flat * target_flat
	loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
	loss = 1 - loss.sum() / N
	return loss
 
class MulticlassDiceLoss(nn.Module):
    def __init__(self):
	super(MulticlassDiceLoss, self).__init__()
 
    def forward(self, input, target, weights=None):
        C = target.shape[1]
        dice = DiceLoss()
        totalLoss = 0
        for i in range(C):
	    diceLoss = dice(input[:,i], target[:,i])
	    if weights is not None:
	        diceLoss *= weights[i]
	        
	    totalLoss += diceLoss
 
        return totalLoss

IOU loss

image.png

import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftIoULoss(nn.Module):
    def __init__(self, n_classes):
        super(SoftIoULoss, self).__init__()
        self.n_classes = n_classes

    @staticmethod
    def to_one_hot(tensor, n_classes):
        n, h, w = tensor.size()
        one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1)
        return one_hot

    def forward(self, input, target):
        # logit => N x Classes x H x W
        # target => N x H x W
        N = len(input)
        pred = F.softmax(input, dim=1)
        target_onehot = self.to_one_hot(target, self.n_classes)
        # Numerator Product
        inter = pred * target_onehot
        # Sum over all pixels N x C x H x W => N x C
        inter = inter.view(N, self.n_classes, -1).sum(2)
        # Denominator
        union = pred + target_onehot - (pred * target_onehot)
        # Sum over all pixels N x C x H x W => N x C
        union = union.view(N, self.n_classes, -1).sum(2)
        loss = inter / (union + 1e-16)
        # Return average loss over classes and batch
        return -loss.mean()

OhemCrossEntropy2d

OHEM(online hard example miniing)算法的核心思想是: 

根据输入样本的损失进行筛选,筛选出hard example,表示对分类和检测影响较大的样本,然后将筛选得到的这些样本应用在随机梯度下降中训练。

class OhemCrossEntropy2d(nn.Module):    
    def __init__(self, ignore_index=-1, thresh=0.7, min_kept=100000, use_weight=True, **kwargs):    
        super(OhemCrossEntropy2d, self).__init__()    
        self.ignore_index = ignore_index    
        self.thresh = float(thresh)    
        self.min_kept = int(min_kept)    
        if use_weight:    
            weight = torch.FloatTensor(
            [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,    
                        1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 
                        1.0865, 1.0955,  1.0865, 1.1529, 1.0507])    
        self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)    
        else:    
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
            
    def forward(self, pred, target):
        n, c, h, w = pred.size()
        target = target.view(-1)
        valid_mask = target.ne(self.ignore_index)
        target = target * valid_mask.long()
        num_valid = valid_mask.sum()
        prob = F.softmax(pred, dim=1)
        prob = prob.transpose(0, 1).reshape(c, -1)
        if self.min_kept > num_valid:
            print("Lables: {}".format(num_valid))
        elif num_valid > 0:
            prob = prob.masked_fill_(1 - valid_mask, 1)
            mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)]
            threshold = self.thresh    
            if self.min_kept > 0:    
                index = mask_prob.argsort()    
                threshold_index = index[min(len(index), self.min_kept) - 1]
                if mask_prob[threshold_index] > self.thresh:
                    threshold = mask_prob[threshold_index]
            kept_mask = mask_prob.le(threshold)
            valid_mask = valid_mask * kept_mask
            target = target * kept_mask.long()
        target = target.masked_fill_(1 - valid_mask, self.ignore_index)
        target = target.view(n, h, w)    
    return self.criterion(pred, target)


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。