# Source code for topolosses.losses.mosin.src.mosin_loss

from __future__ import annotations
import warnings
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
import torchvision

from ...utils import compute_default_dice_loss


[docs] class MosinLoss(_Loss): """A topology-aware loss function for curvilinear structure delineation using perceptual features. The loss has been defined in: Mosinska et al. (2018) Beyond the Pixel-Wise Loss for Topology-Aware Delineation. This loss uses a pre-trained VGG19 network to extract multi-level features from predictions and targets, comparing them to enforce topological consistency. By default, it combines with a pixel-wise base loss. """ def __init__( self, include_background: bool = False, alpha: float = 0.5, softmax: bool = False, sigmoid: bool = False, use_base_loss: bool = True, base_loss: Optional[_Loss] = None, ) -> None: """ Args: include_background (bool): If `True`, includes the background class in feature extraction. Defaults to `False`. alpha (float): Weighting factor for combining the base loss and the topology loss (i.e.: base_loss + alpha*topology_loss). Defaults to 0.5. sigmoid (bool): If `True`, applies a sigmoid activation to the input before computing the loss. Sigmoid is not applied before passing it to a custom base loss function. Defaults to `False`. softmax (bool): If `True`, applies a softmax activation to the input before computing the loss. Softmax is not applied before passing it to a custom base loss function. Defaults to `False`. use_base_loss (bool): If `False`, the loss only consists of the topology component. The base_loss and alpha will be ignored if this flag is set to false. Defaults to `True`. base_loss (_Loss, optional): The base loss function to be used alongside the topology loss. Defaults to `None`, meaning a standard cross-entropy loss will be used. Raises: ValueError: If more than one of [sigmoid, softmax] is set to True. """ if sum([sigmoid, softmax]) > 1: raise ValueError( "At most one of [sigmoid, softmax] can be set to True. " "You can only choose one of these options at a time or none if you already pass probabilites." 
) super(MosinLoss, self).__init__() # TODO ask authors if it makes sense to allow the users to use other models # TODO think about adding this self.vgg = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1).features self.feature_layers = [2, 7, 16] self.layer_names = ["conv1_2", "conv2_2", "conv3_4"] self.activation = {} for i, name in zip(self.feature_layers, self.layer_names): self.vgg[i].register_forward_hook(self._get_activation(name)) self.vgg.eval() self.vgg.requires_grad_(False) self.include_background = include_background self.alpha = alpha self.softmax = softmax self.sigmoid = sigmoid self.use_base_loss = use_base_loss self.base_loss = base_loss if not self.use_base_loss: if base_loss is not None: warnings.warn("base_loss is ignored beacuse use_base_component is set to false") if self.alpha != 1: warnings.warn("Alpha < 1 has no effect when no base component is used.")
[docs] def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Calculates the forward pass of the Mosin Loss. Args: input (Tensor): Input tensor of shape (batch_size, num_classes, H, W). target (Tensor): Target tensor of shape (batch_size, num_classes, H, W). Returns: Tensor: The calculated betti matching loss. Raises: ValueError: If the shape of the ground truth is different from the input shape. ValueError: If softmax=True and the number of channels for the prediction is 1. ValueError: If the input dimension is smaller than 32x32. """ if target.shape != input.shape: raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") if input.shape[2] < 32 or input.shape[3] < 32: raise ValueError(f"input dimensions must be at least 32x32, got {input.shape[2]}x{input.shape[3]}") starting_class = 0 if self.include_background else 1 num_classes = input.shape[1] if num_classes == 1: if self.softmax: raise ValueError( "softmax=True requires multiple channels for class probabilities, but received a single-channel input." ) if not self.include_background: warnings.warn( "Single-channel prediction detected. The `include_background=False` setting will be ignored." 
) starting_class = 0 # Avoiding applying transformations sigmoid and softmax before passing the input to the base loss function # These settings have to be controlled by the user when initializing the base loss function base_loss = torch.tensor(0.0) if self.alpha < 1 and self.use_base_loss and self.base_loss is not None: base_loss = self.base_loss(input, target) if self.sigmoid: input = torch.sigmoid(input) elif self.softmax: input = torch.softmax(input, 1) if self.alpha < 1 and self.use_base_loss and self.base_loss is None: base_loss = compute_default_dice_loss(input, target) mosin_loss = torch.tensor(0.0) if self.alpha > 0: mosin_loss = self.compute_mosin_loss( input[:, starting_class:].float(), target[:, starting_class:].float(), ) total_loss = mosin_loss if not self.use_base_loss else base_loss + self.alpha * mosin_loss return total_loss
[docs] def compute_mosin_loss(self, prediction, target): # for multi class input prediction = torch.flatten(prediction, start_dim=0, end_dim=1).unsqueeze(1) target = torch.flatten(target, start_dim=0, end_dim=1).unsqueeze(1) prediction = prediction.expand(-1, 3, -1, -1) target = target.argmax(dim=1, keepdim=True).to(torch.float32).expand(-1, 3, -1, -1) pred_features = self._get_features(prediction) target_features = self._get_features(target) loss = 0 for layer_name in self.layer_names: loss += F.mse_loss(pred_features[layer_name], target_features[layer_name]) return loss
def _get_activation(self, name): """Hook to save activation for a given layer""" def hook(model, input, output): self.activation[name] = output return hook def _get_features(self, x): """Extract features from specified VGG layers""" self.activation = {} # Clear previous activations self.vgg(x) # Forward pass through VGG return self.activation.copy()