import warnings
from typing import List, Optional
import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from ...utils import compute_default_dice_loss
class CLDiceLoss(_Loss):
    """A loss function for segmentation that combines a base loss and a CLDice component.

    The loss has been defined in:
    Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
    for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)

    By default the CLDice component is combined with a (weighted) default dice loss.
    For more flexibility a custom base loss function can be passed.
    """

    def __init__(
        self,
        iter_: int = 3,
        smooth: float = 1e-5,
        batch: bool = False,
        include_background: bool = False,
        alpha: float = 0.5,
        sigmoid: bool = False,
        softmax: bool = False,
        use_base_loss: bool = True,
        base_loss: Optional[_Loss] = None,
    ) -> None:
        """
        Args:
            iter_ (int): Number of iterations for soft skeleton computation. Higher values refine
                the skeleton but increase computation time. Defaults to 3.
            smooth (float): Smoothing factor to avoid division by zero in CLDice and the default
                base dice calculations. Defaults to 1e-5.
            batch (bool): If `True`, reduces the loss across the batch dimension by summing
                intersection and union areas before division. Defaults to `False`, where the loss
                is computed independently for each item for the CLDice and default base dice
                component calculation.
            include_background (bool): If `True`, includes the background class in CLDice
                computation. Defaults to `False`.
            alpha (float): Weighting factor for combining the CLDice component
                (i.e.: base_loss + alpha * cldice_loss). Defaults to 0.5.
                NOTE: the base component is only computed when alpha < 1.
            sigmoid (bool): If `True`, applies a sigmoid activation to the input before computing
                the CLDice and the default dice component. Sigmoid is not applied before passing
                it to a custom base loss function. Defaults to `False`.
            softmax (bool): If `True`, applies a softmax activation to the input before computing
                the CLDice loss. Softmax is not applied before passing it to a custom base loss
                function. Defaults to `False`.
            use_base_loss (bool): If `False` the loss only consists of the CLDice component and a
                forward call will return the full CLDice component. `base_loss` and `alpha` are
                ignored if this flag is set to `False`. Defaults to `True`.
            base_loss (_Loss, optional): The base loss function to be used alongside the CLDice
                loss. Defaults to `None`, meaning a Dice component with default parameters will
                be used.

        Raises:
            ValueError: If more than one of [sigmoid, softmax] is set to True.
        """
        if sum([sigmoid, softmax]) > 1:
            raise ValueError(
                "At most one of [sigmoid, softmax] can be set to True. "
                "You can only choose one of these options at a time or none if you already pass probabilities."
            )
        super().__init__()
        self.iter_ = iter_
        self.smooth = smooth
        self.batch = batch
        self.include_background = include_background
        self.alpha = alpha
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.use_base_loss = use_base_loss
        self.base_loss = base_loss
        if not self.use_base_loss:
            # Without a base component, `base_loss` and `alpha` have no effect:
            # forward() returns the plain CLDice component.
            if base_loss is not None:
                warnings.warn("base_loss is ignored because use_base_loss is set to false")
            if self.alpha != 1:
                warnings.warn(
                    "alpha has no effect when no base component is used. The full CLDice loss will be returned."
                )

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Computes the CLDice loss and base loss for the given input and target.

        Args:
            input (torch.Tensor): Predicted segmentation map of shape BC[spatial dimensions],
                where C is the number of classes, and [spatial dimensions] represent height,
                width, and optionally depth.
            target (torch.Tensor): Ground truth segmentation map of shape BC[spatial dimensions].

        Returns:
            Tensor: The calculated CLDice loss.

        Raises:
            ValueError: If the shape of the ground truth is different from the input shape.
            ValueError: If the input has fewer than 4 dimensions (batch, channel, H, W).
            ValueError: If softmax=True and the number of channels for the prediction is 1.
        """
        if target.shape != input.shape:
            raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
        if len(input.shape) < 4:
            raise ValueError(
                "Invalid input tensor shape. Expected at least 4 dimensions in the format (batch, channel, [spatial dims]), "
                "where 'spatial dims' must be at least 2D (height, width). "
                f"Received shape: {input.shape}."
            )
        starting_class = 0 if self.include_background else 1
        if input.shape[1] == 1:
            if self.softmax:
                raise ValueError(
                    "softmax=True requires multiple channels for class probabilities, but received a single-channel input."
                )
            if not self.include_background:
                warnings.warn(
                    "Single-channel prediction detected. The `include_background=False` setting will be ignored."
                )
            # A single channel cannot be split into background/foreground.
            starting_class = 0
        # Avoid applying transformations like sigmoid, softmax, or one-vs-rest before passing the
        # input to a custom base loss function. These settings have to be controlled by the user
        # when initializing the base loss function.
        base_loss = torch.tensor(0.0)
        if self.alpha < 1 and self.use_base_loss and self.base_loss is not None:
            base_loss = self.base_loss(input, target)
        if self.sigmoid:
            input = torch.sigmoid(input)
        elif self.softmax:
            input = torch.softmax(input, 1)
        # Sum over all spatial axes; additionally over the batch axis when self.batch is set.
        reduce_axis: List[int] = ([0] if self.batch else []) + list(range(2, len(input.shape)))
        if self.alpha < 1 and self.use_base_loss and self.base_loss is None:
            base_loss = compute_default_dice_loss(
                input,
                target,
                reduce_axis,
                self.smooth,
            )
        cl_dice = torch.tensor(0.0)
        if self.alpha > 0:
            cl_dice = self.compute_cldice_loss(
                input[:, starting_class:].float(),
                target[:, starting_class:].float(),
                reduce_axis,
            )
        return cl_dice if not self.use_base_loss else base_loss + self.alpha * cl_dice

    def compute_cldice_loss(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        reduce_axis: List[int],
    ) -> torch.Tensor:
        """Computes the CLDice component from soft skeletons of prediction and target.

        Uses `self.iter_` iterations for the soft skeleton computation and `self.smooth`
        to avoid division by zero.

        Args:
            input (torch.Tensor): The predicted segmentation map with shape (N, C, ...),
                where N is batch size, C is the number of classes.
            target (torch.Tensor): The ground truth segmentation map with the same shape as `input`.
            reduce_axis (List[int]): The axes along which to reduce the loss computation.
                Including axis 0 sums the intersection and union areas over the batch
                dimension before the division.

        Returns:
            torch.Tensor: The CLDice loss as a scalar tensor.
        """
        pred_skeletons = soft_skel(input, self.iter_)
        target_skeletons = soft_skel(target, self.iter_)
        # Topology precision: fraction of the predicted skeleton lying inside the target mask.
        tprec = (
            torch.sum(
                torch.multiply(pred_skeletons, target),
                dim=reduce_axis,
            )
            + self.smooth
        ) / (torch.sum(pred_skeletons, dim=reduce_axis) + self.smooth)
        # Topology sensitivity: fraction of the target skeleton covered by the prediction.
        tsens = (
            torch.sum(
                torch.multiply(target_skeletons, input),
                dim=reduce_axis,
            )
            + self.smooth
        ) / (torch.sum(target_skeletons, dim=reduce_axis) + self.smooth)
        # Loss = 1 - harmonic mean of tprec and tsens, averaged over remaining axes.
        return torch.mean(1.0 - 2.0 * (tprec * tsens) / (tprec + tsens))
def soft_erode(img: torch.Tensor) -> torch.Tensor:
    """Soft (differentiable) erosion: shrink objects via min-pooling.

    Implemented as the minimum over directional min-pools (max-pool of the negated
    image), following the clDice reference implementation. Supports both 2D inputs
    (4D tensors) and 3D inputs (5D tensors); the latter matches the class docstring's
    promise of optional depth support.

    Args:
        img (torch.Tensor): Tensor of shape (batch, channel, height, width)
            or (batch, channel, depth, height, width).

    Returns:
        torch.Tensor: Eroded tensor with the same shape as `img`.

    Raises:
        ValueError: If the input tensor is neither 4D nor 5D.
    """
    if len(img.shape) == 4:
        p1 = -F.max_pool2d(-img, (3, 1), (1, 1), (1, 0))
        p2 = -F.max_pool2d(-img, (1, 3), (1, 1), (0, 1))
        return torch.min(p1, p2)
    if len(img.shape) == 5:
        p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0))
        p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0))
        p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1))
        return torch.min(torch.min(p1, p2), p3)
    raise ValueError(
        "input tensor must be 4D with shape (batch, channel, height, width) "
        "or 5D with shape (batch, channel, depth, height, width)"
    )
def soft_dilate(img: torch.Tensor) -> torch.Tensor:
    """Soft (differentiable) dilation: grow objects via a 3x3(x3) max-pool.

    Supports both 2D inputs (4D tensors) and 3D inputs (5D tensors), following
    the clDice reference implementation.

    Args:
        img (torch.Tensor): Tensor of shape (batch, channel, height, width)
            or (batch, channel, depth, height, width).

    Returns:
        torch.Tensor: Dilated tensor with the same shape as `img`.

    Raises:
        ValueError: If the input tensor is neither 4D nor 5D.
    """
    if len(img.shape) == 4:
        return F.max_pool2d(img, (3, 3), (1, 1), (1, 1))
    if len(img.shape) == 5:
        return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1))
    raise ValueError(
        "input tensor must be 4D with shape (batch, channel, height, width) "
        "or 5D with shape (batch, channel, depth, height, width)"
    )
def soft_open(img: torch.Tensor) -> torch.Tensor:
    """Morphological soft opening: an erosion followed by a dilation."""
    eroded = soft_erode(img)
    return soft_dilate(eroded)
def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
    """Extract a soft (differentiable) skeleton by iterative erosion and opening.

    At each step the residue between the progressively eroded image and its
    opening is folded into the skeleton with a soft union, so thin structures
    surviving erosion accumulate into the skeleton estimate.

    Args:
        img (torch.Tensor): Soft segmentation map to skeletonize.
        iter_ (int): Number of erosion iterations; more iterations refine the skeleton.

    Returns:
        torch.Tensor: Soft skeleton with the same shape as `img`.
    """
    skeleton = F.relu(img - soft_open(img))
    eroded = img
    for _ in range(iter_):
        eroded = soft_erode(eroded)
        residue = F.relu(eroded - soft_open(eroded))
        # Soft union: skeleton ∪ residue, clamped to stay non-negative.
        skeleton = skeleton + F.relu(residue - skeleton * residue)
    return skeleton