#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def make_one_hot(input, num_classes):
    """Convert a class-index tensor to a one-hot encoded tensor.

    Args:
        input: An integer tensor of class indices with shape [N, 1, *].
        num_classes: Number of classes (size of the one-hot dimension).

    Returns:
        A float32 CPU tensor of shape [N, num_classes, *] holding 1 at the
        position of each class index along dim 1 and 0 elsewhere.
    """
    shape = list(input.shape)
    shape[1] = num_classes
    result = torch.zeros(shape)
    # scatter_ requires an int64 index living on the same device as `result`
    # (CPU), hence the explicit move and cast.
    index = input.cpu().long()
    return result.scatter_(1, index, 1)


class BinaryDiceLoss(nn.Module):
    """Dice loss for a single (binary) class.

    The loss is computed over the whole batch at once::

        dice = (2 * sum(x * y) + smooth) / (sum(x^p) + sum(y^p) + smooth)
        loss = 1 - dice

    Args:
        smooth: Value added to numerator and denominator to smooth the loss
            and avoid division by zero, default: 1
        p: Exponent applied to predict and target in the denominator,
            default: 2

    Shape:
        predict: A tensor of shape [N, *]
        target: A tensor with the same number of elements as predict and the
            same batch size

    Returns:
        A scalar (0-dim) loss tensor.
    """

    def __init__(self, smooth=1, p=2):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Flatten everything but the batch dimension.
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        # Sums run over every element, batch dimension included, so a single
        # Dice score is produced for the whole batch.
        num = torch.sum(torch.mul(predict, target)) * 2 + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p)) + self.smooth

        return 1 - num / den


class DiceLoss(nn.Module):
    """Multi-class Dice loss; the target must be one-hot encoded.

    Softmax is applied to ``predict`` over dim 1, then a BinaryDiceLoss is
    computed per class and the (optionally weighted) losses are averaged.

    Args:
        weight: An array or tensor of shape [num_classes,] with per-class
            weights, or None for no weighting
        ignore_index: Class index whose loss is skipped, or None
        **kwargs: Passed through to BinaryDiceLoss (smooth, p)

    Shape:
        predict: A tensor of raw scores (logits) of shape [N, C, *]
        target: A one-hot tensor with the same shape as predict

    Returns:
        The sum of the per-class losses divided by C. Note that the divisor
        is the total number of classes even when one class is ignored.
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        assert predict.shape == target.shape, 'predict & target shape do not match'
        num_classes = target.shape[1]
        if self.weight is not None:
            # Validate once up front instead of on every loop iteration.
            assert self.weight.shape[0] == num_classes, \
                'Expect weight shape [{}], get[{}]'.format(num_classes, self.weight.shape[0])

        dice = BinaryDiceLoss(**self.kwargs)
        predict = F.softmax(predict, dim=1)

        total_loss = 0
        for i in range(num_classes):
            if i == self.ignore_index:
                continue
            dice_loss = dice(predict[:, i], target[:, i])
            if self.weight is not None:
                # Out-of-place multiply keeps the autograd graph free of
                # in-place modifications.
                dice_loss = dice_loss * self.weight[i]
            total_loss += dice_loss

        return total_loss / num_classes