#| export
class CustomLoss:
    """Wrapper that gives a loss callable the `activation` and `decodes` hooks.

    Wrapping a loss function in this class lets it be used together with
    the 'show_results' method in fastai.
    """

    def __init__(self, loss_func):
        """Keep a reference to the wrapped loss.

        Args:
            loss_func: Callable invoked as `loss_func(pred, targ)`.
        """
        self.loss_func = loss_func

    def __call__(self, pred, targ):
        """Evaluate the wrapped loss on a prediction/target pair.

        Args:
            pred: Model prediction.
            targ: Ground-truth target.

        Returns:
            Whatever the wrapped `loss_func` returns.
        """
        # Anything that is not a MedBase goes to the wrapped loss untouched.
        if not isinstance(pred, MedBase):
            return self.loss_func(pred, targ)

        # MedBase inputs are copied to the CPU and passed through
        # `torch.Tensor(...)`; the target is additionally cast to float.
        cpu_pred = torch.Tensor(pred.cpu())
        cpu_targ = torch.Tensor(targ.cpu().float())
        return self.loss_func(cpu_pred, cpu_targ)

    def activation(self, x):
        """Identity activation: `x` is returned unchanged."""
        return x

    def decodes(self, x) -> torch.Tensor:
        """Turn model output into the target (mask) format.

        Args:
            x: Activations for each class with dimensions [B, C, W, H, D].

        Returns:
            The predicted mask.
        """
        # A single channel is decoded as a binary mask.
        if x.shape[1] == 1:
            return pred_to_binary_mask(x)

        # Otherwise keep only the first element of the multiclass helper's result.
        multiclass_mask, _ = batch_pred_to_multiclass_mask(x)
        return multiclass_mask


#| export
class TverskyFocalLoss(_Loss):
    """
    Tversky loss combined with a focal exponent, gamma.

    The underlying Tversky loss is ``monai.losses.TverskyLoss``; see that
    class for the details of the Tversky term.
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        gamma: float = 2,
        alpha: float = 0.5,
        beta: float = 0.99):
        """
        Args:
            include_background: if to calculate loss for the background class.
            to_onehot_y: whether to convert `y` into one-hot format.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            gamma: the focal parameter, it modulates the loss with regards to
                how far the prediction is from target.
            alpha: the weight of false positive in Tversky loss calculation.
            beta: the weight of false negative in Tversky loss calculation.
        """
        super().__init__()

        # Every option except `gamma` is forwarded to the MONAI Tversky loss.
        tversky_options = dict(
            to_onehot_y=to_onehot_y,
            include_background=include_background,
            sigmoid=sigmoid,
            softmax=softmax,
            alpha=alpha,
            beta=beta,
        )
        self.tversky = TverskyLoss(**tversky_options)
        self.gamma = gamma

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be [B, C, W, H, D]. The input should be the original logits.
            target: the shape should be [B, C, W, H, D].

        Returns:
            The focal Tversky loss, ``1 - (1 - tversky_loss) ** gamma``.

        Raises:
            ValueError: When number of dimensions for input and target are different.
        """
        input_rank, target_rank = len(input.shape), len(target.shape)
        if input_rank != target_rank:
            raise ValueError("The number of dimensions for input and target should be the same.")

        # Complement of the Tversky loss, raised to the focal exponent.
        tversky_complement = 1 - self.tversky(input, target)
        total_loss: torch.Tensor = 1 - (tversky_complement**self.gamma)
        return total_loss