Source code for topolosses.losses.dice.src.dice_loss

import warnings
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss


class DiceLoss(_Loss):
    """Computes the Dice loss between a prediction and a ground-truth tensor."""

    def __init__(
        self,
        smooth: float = 1e-5,
        sigmoid: bool = False,
        softmax: bool = False,
        batch: bool = False,
        include_background: bool = True,
        weights: Optional[Tensor] = None,
    ) -> None:
        """
        Args:
            smooth (float): Smoothing factor added to numerator and denominator to
                avoid division by zero. Defaults to 1e-5.
            sigmoid (bool): If `True`, applies a sigmoid activation to the input before
                computing the loss. Defaults to `False`.
            softmax (bool): If `True`, applies a softmax activation to the input before
                computing the loss. Defaults to `False`.
            batch (bool): If `True`, reduces the loss across the batch dimension by summing
                intersection and union areas before division. Defaults to `False`, where the
                loss is computed independently for each item and reduced afterwards.
            include_background (bool): If `False`, channel index 0 (background class) is
                excluded from the calculation. Defaults to `True`.
            weights (Tensor, optional): A 1D tensor of class-wise weights, with length equal
                to the number of classes (adjusted for background inclusion). It allows
                emphasizing or ignoring classes. Defaults to `None` (unweighted).

        Raises:
            ValueError: If both `sigmoid` and `softmax` are set to `True`.
        """
        if sum([sigmoid, softmax]) > 1:
            raise ValueError(
                "At most one of [sigmoid, softmax] can be set to True. "
                "You can only choose one of these options at a time or none if you already pass probabilities."
            )

        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.batch = batch
        self.include_background = include_background
        # Registered as a buffer so the weights move with the module (`.to(device)`)
        # and are included in the state dict without being trainable parameters.
        self.register_buffer("weights", weights)
        self.weights: Optional[Tensor]

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Computes the Dice loss between two tensors.

        Args:
            input (torch.Tensor): Predicted segmentation map of shape BC[spatial dimensions],
                where C is the number of classes, and [spatial dimensions] represent height,
                width, and optionally depth.
            target (torch.Tensor): Ground truth segmentation map of shape BC[spatial dimensions].

        Returns:
            torch.Tensor: The Dice loss as a scalar.

        Raises:
            ValueError: If the shape of the ground truth is different from the input shape.
            ValueError: If softmax=True and the number of channels for the prediction is 1.
            ValueError: If the weight vector is not 1D or its length does not match the
                number of participating classes.
        """
        if target.shape != input.shape:
            raise ValueError(f"Ground truth has different shape ({target.shape}) from input ({input.shape})")

        n_classes = input.shape[1]

        # Validate the weight vector against the number of classes that will actually
        # participate in the loss (background excluded unless included or single-channel).
        non_zero_weights_mask = None
        if self.weights is not None:
            if len(self.weights.shape) != 1:
                raise ValueError("weights must be a 1-dimensional tensor (vector).")
            if len(self.weights) != (n_classes - (0 if self.include_background or n_classes == 1 else 1)):
                raise ValueError(
                    f"Wrong shape of weight vector: Number of class weights ({len(self.weights)}) must match the number of classes."
                    f"({'including' if self.include_background else 'excluding'} background) ({n_classes})."
                )
            non_zero_weights_mask = self.weights != 0

        starting_class = 0 if self.include_background else 1
        if n_classes == 1:
            if self.softmax:
                raise ValueError(
                    "softmax=True requires multiple channels for class probabilities, but received a single-channel input."
                )
            if not self.include_background:
                warnings.warn(
                    "Single-channel prediction detected. The `include_background=False` setting will be ignored."
                )
                starting_class = 0

        # Drop the background channel BEFORE applying the zero-weight mask: the weight
        # vector is defined over the post-background classes, so masking first would
        # mis-align the channels (and raise a shape error) when background is excluded.
        input = input[:, starting_class:]
        target = target[:, starting_class:]
        if non_zero_weights_mask is not None:
            # Channels with weight 0 contribute nothing; drop them up front.
            input = input[:, non_zero_weights_mask]
            target = target[:, non_zero_weights_mask]

        if self.sigmoid:
            input = torch.sigmoid(input)
        elif self.softmax:
            input = torch.softmax(input, 1)

        # Sum over spatial dims; additionally over the batch dim when self.batch is set.
        reduce_axis: List[int] = ([0] if self.batch else []) + list(range(2, len(input.shape)))
        intersection = torch.sum(target * input, dim=reduce_axis)
        ground_o = torch.sum(target, dim=reduce_axis)
        pred_o = torch.sum(input, dim=reduce_axis)
        denominator = ground_o + pred_o
        dice = 1.0 - (2.0 * intersection + self.smooth) / (denominator + self.smooth)

        # Weights are normalized to keep scales consistent.
        # This is different to the monai implementation of weighted dice loss.
        if self.weights is not None:
            active_weights = self.weights[non_zero_weights_mask]
            weighted_dice = dice * (active_weights / active_weights.sum())
            dice = torch.mean(weighted_dice.sum(dim=1)) if not self.batch else weighted_dice.sum()
        else:
            dice = torch.mean(dice)

        return dice