Source code for topolosses.losses.topograph.src.topograph_loss

from __future__ import annotations

import warnings
from typing import Optional

import networkx as nx
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

# C++ implementation of Topograph; it can be installed locally as its own package via the setup file.
# During the build process for the distribution file, a similar package is built using the CMake file.
# TODO: make a local build of Topograph that can be imported the same way during development as during the build process
# import Topograph as _topograph

from . import _topograph

from ...utils import (
    AggregationType,
    ThresholdDistribution,
    fill_adj_matr,
    new_compute_diag_diffs,
    new_compute_diffs,
    compute_default_dice_loss,
)
from scipy.ndimage import label
from scipy.cluster.hierarchy import DisjointSet


def reverse_pairing(pairing: int) -> tuple[int, int]:
    match pairing:
        case 0:
            return 0, 0
        case 1:
            return 1, 0
        case 2:
            return 0, 1
        case 3:
            return 1, 1
        case _:
            return -1, -1

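# The pairing used throughout this module encodes (pred, gt) as pred + 2 * gt; the four
# intersection classes therefore decode as follows (illustrative):
# >>> [reverse_pairing(p) for p in range(4)]
# [(0, 0), (1, 0), (0, 1), (1, 1)]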

def label_regions(pred: np.ndarray, gt: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Labels the regions in the predicted image based on the intersection of the predicted and ground truth images by assign a unique class to each connected component.

    Args:
        pred (ndarray, [H, W]): The predicted and binarized image in argmax encoding of shape [H, W].
        gt (ndarray, [H, W]): The ground truth image in argmax encoding of shape [H, W].

    Returns:
        tuple: A tuple containing the labeled regions, masks, prediction labels, and ground truth labels.
            - all_labels (ndarray, [H, W]): The labeled regions starting at 1 of shape [H, W].
            - pred_labels (ndarray, [N]): The predicted classes for each region.
            - gt_labels (ndarray, [N]): The ground truth classes for each region.
    """
    # create one hot encoding for each intersection class
    paired_img = pred + 2 * gt
    masked_imgs = np.eye(4)[paired_img].transpose(2, 0, 1).astype(np.int32)

    # use map to iterate through all possible combinations of classes and create connected component labeling with masks
    cc_result = map(label, masked_imgs)

    all_labels = np.zeros(pred.shape, dtype=np.int32)
    label_counter = 0  # counter for the number of classes that have already been set

    gt_labels = []
    pred_labels = []

    # iterate through all possible combinations of classes and aggregate all labels
    for inters_class, (labeled_regions, num_nodes) in enumerate(cc_result):
        # add the labeled mask to the final image
        all_labels += labeled_regions + (masked_imgs[inters_class] * label_counter)
        label_counter += num_nodes

        # decode the pred and gt class from the intersection class (pairing = pred + 2 * gt)
        pred_class, gt_class = reverse_pairing(inters_class)

        # append pred and gt classes
        pred_labels.append(np.zeros((num_nodes)) + pred_class)
        gt_labels.append(np.zeros((num_nodes)) + gt_class)

    # convert label lists to numpy arrays
    pred_labels = np.concatenate(pred_labels)
    gt_labels = np.concatenate(gt_labels)

    # shift the labels so they start at 0
    if all_labels.max() > 0:
        all_labels -= 1

    return all_labels, pred_labels, gt_labels

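# Worked sketch (toy arrays, illustrative only): three regions emerge from the intersection
# classes, labeled 0..2, with each region's pred/gt class decoded alongside.
# >>> pred = np.array([[1, 0], [0, 0]])
# >>> gt = np.array([[1, 1], [0, 0]])
# >>> labels, pred_lab, gt_lab = label_regions(pred, gt)
# >>> labels
# array([[2, 1],
#        [0, 0]], dtype=int32)
# >>> pred_lab, gt_lab
# (array([0., 0., 1.]), array([0., 1., 1.]))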

def rag(labelled_regions, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl):
    max_label = labelled_regions.max()

    # if all voxels have the same class, there are no edges
    if max_label == 0:
        edges = np.empty((2, 0))
        special_edges = np.empty((2, 0))
    else:
        # get the classes of each edge
        h_edges = np.stack([labelled_regions[1:, :][h_diff], labelled_regions[:-1, :][h_diff]])
        v_edges = np.stack([labelled_regions[:, 1:][v_diff], labelled_regions[:, :-1][v_diff]])

        # create adjacency matrix
        adj = np.zeros((max_label + 1, max_label + 1), dtype=bool)
        special_adj = np.zeros((max_label + 1, max_label + 1), dtype=bool)
        adj = fill_adj_matr(adj, h_edges, v_edges)

        dr_edges = np.stack([labelled_regions[:-1, :-1][diagr], labelled_regions[1:, 1:][diagr]])
        dl_edges = np.stack([labelled_regions[:-1, 1:][diagl], labelled_regions[1:, :-1][diagl]])
        special_dr_edges = np.stack(
            [labelled_regions[:-1, :-1][special_diagr], labelled_regions[1:, 1:][special_diagr]]
        )
        special_dl_edges = np.stack(
            [labelled_regions[:-1, 1:][special_diagl], labelled_regions[1:, :-1][special_diagl]]
        )
        adj = fill_adj_matr(adj, dr_edges, dl_edges)
        special_adj = fill_adj_matr(special_adj, special_dr_edges, special_dl_edges)

        # convert to edge index list
        edges = np.stack(np.nonzero(adj))
        special_edges = np.stack(np.nonzero(special_adj))

    return edges, special_edges

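# The last step of rag converts the boolean adjacency matrix into the 2 x E edge index list
# consumed by create_graph; a toy matrix for illustration:
# >>> adj = np.zeros((3, 3), dtype=bool)
# >>> adj[0, 1] = adj[1, 2] = True
# >>> np.stack(np.nonzero(adj))
# array([[0, 1],
#        [1, 2]])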

def contract_graph(graph):
    # identify clusters of nodes that all have the same predicted and gt class
    same_nodes = DisjointSet(graph.nodes)

    for node in graph.nodes:
        # skip correct background nodes because they never have only a diagonal edge
        if graph.nodes[node]["predicted_classes"] == 0 and graph.nodes[node]["gt_classes"] == 0:
            continue

        # get the node's cluster
        cur_node_cluster = same_nodes[node]

        # iterate through all neighbors of the current node
        for neighbor in graph[node]:
            # visit each edge only once or if it is a special edge, skip it
            if neighbor < node or graph[node][neighbor].get("special", False):
                continue
            # check if the neighbor has the same predicted and gt class as the current node
            if (
                graph.nodes[neighbor]["predicted_classes"] == graph.nodes[node]["predicted_classes"]
                and graph.nodes[neighbor]["gt_classes"] == graph.nodes[node]["gt_classes"]
            ):
                nbr_cluster = same_nodes[neighbor]

                if nbr_cluster != cur_node_cluster:
                    same_nodes.merge(cur_node_cluster, nbr_cluster)

    # contract nodes in the graph based on the clusters
    for cluster in same_nodes.subsets():
        if len(cluster) == 1:
            continue

        # get the first node in the cluster
        first_node = cluster.pop()

        # Save the contracted nodes in the first node of each cluster
        graph.nodes[first_node]["contracted_nodes"] = cluster

        # contract all other nodes in the cluster to the first node
        for node in cluster:
            nx.contracted_nodes(graph, first_node, node, self_loops=False, copy=False)

    return graph

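# contract_graph relies on scipy's DisjointSet to group nodes of equal (pred, gt) class before
# contracting them; a minimal sketch of that primitive (toy elements, illustrative only):
# >>> ds = DisjointSet([0, 1, 2])
# >>> ds.merge(0, 1)
# True
# >>> sorted(map(sorted, ds.subsets()))
# [[0, 1], [2]]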

def identify_clusters(graph):
    pred_cluster = DisjointSet(graph.nodes)
    gt_cluster = DisjointSet(graph.nodes)

    for node in graph.nodes:
        # skip correct background nodes because they're never part of a cluster
        if graph.nodes[node]["predicted_classes"] == 0 and graph.nodes[node]["gt_classes"] == 0:
            continue

        # get the node's clusters
        cur_pred_cluster = pred_cluster[node]
        cur_gt_cluster = gt_cluster[node]

        # iterate through all neighbors of the current node
        for neighbor in graph[node]:
            # visit each edge only once
            if neighbor < node:
                continue
            # # if it is a special edge, skip it
            # if graph[node][neighbor].get('special', False):
            #     continue
            # if they are both predicted foreground, merge pred cluster
            if graph.nodes[neighbor]["predicted_classes"] == 1 and graph.nodes[node]["predicted_classes"] == 1:
                pred_nbr_cluster = pred_cluster[neighbor]

                if pred_nbr_cluster != cur_pred_cluster:
                    pred_cluster.merge(cur_pred_cluster, pred_nbr_cluster)

            # if they have the same gt class, merge gt cluster
            if graph.nodes[neighbor]["gt_classes"] == 1 and graph.nodes[node]["gt_classes"] == 1:
                gt_nbr_cluster = gt_cluster[neighbor]

                if gt_nbr_cluster != cur_gt_cluster:
                    gt_cluster.merge(cur_gt_cluster, gt_nbr_cluster)

    # add pred cluster to each node
    for cluster in pred_cluster.subsets():
        node = cluster.pop()
        root = pred_cluster[node]
        graph.nodes[node]["pred_cluster"] = root

        for node in cluster:
            graph.nodes[node]["pred_cluster"] = root

    # add gt cluster to each node
    for cluster in gt_cluster.subsets():
        node = cluster.pop()
        root = gt_cluster[node]
        graph.nodes[node]["gt_cluster"] = root

        for node in cluster:
            graph.nodes[node]["gt_cluster"] = root

    return graph


def create_graph(argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl):
    labelled_regions, predicted_classes, gt_classes = label_regions(argmax_pred, argmax_gt)

    # create a graph from the labelled regions
    graph = nx.Graph()
    if labelled_regions.max() == 0:
        # if there is only one region, the graph consists of a single node and no edges
        graph.add_node(0)
    else:
        edge_index, special_edge_index = rag(
            labelled_regions, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl
        )
        graph.add_edges_from(edge_index.T)
        graph.add_edges_from(special_edge_index.T, special=True)

    # add node attributes
    for node in graph.nodes:
        graph.nodes[node]["predicted_classes"] = predicted_classes[node]
        graph.nodes[node]["gt_classes"] = gt_classes[node]

    graph.graph["predicted_classes"] = predicted_classes

    graph = contract_graph(graph)

    graph = identify_clusters(graph)

    return graph, labelled_regions


def get_critical_nodes(graph):
    critical_nodes = []
    cluster_lengths = []
    error_type = []

    for node in graph.nodes:
        # skip correctly predicted nodes
        if graph.nodes[node]["predicted_classes"] == graph.nodes[node]["gt_classes"]:
            continue

        all_nbrs = list(graph[node])

        fg_nbr_clusters = set()
        correct_bg_nbrs_count = 0
        counter_class_str = "gt_cluster" if graph.nodes[node]["predicted_classes"] == 1 else "pred_cluster"

        for nbr in all_nbrs:
            # if it is a special edge, skip it
            if graph[node][nbr].get("special", False):
                continue

            nbr_gt_class = graph.nodes[nbr]["gt_classes"]

            # count correctly predicted background neighbors
            if nbr_gt_class == 0 and graph.nodes[nbr]["predicted_classes"] == 0:  # correct background case
                correct_bg_nbrs_count += 1
                # If we have more than one correct background neighbor, we can stop here
                if correct_bg_nbrs_count > 1:
                    break
            else:  # all other nbrs are either incorrect foreground in the counter class or correct foreground
                fg_nbr_clusters.add(graph.nodes[nbr][counter_class_str])

        # a node is critical unless it has exactly one correct background neighbor and exactly one foreground nbr cluster
        if correct_bg_nbrs_count != 1 or len(fg_nbr_clusters) != 1:
            # error types: 0 = missing background neighbor, 1 = missing foreground cluster, 2 = too many neighbors
            if correct_bg_nbrs_count + len(fg_nbr_clusters) >= 2:
                node_error = 2
            elif correct_bg_nbrs_count == 0:
                node_error = 0
            else:
                node_error = 1
            error_type.append(node_error)
            critical_nodes.append(node)

            if "contracted_nodes" in graph.nodes[node]:
                critical_nodes += graph.nodes[node]["contracted_nodes"]
                cluster_lengths.append(len(graph.nodes[node]["contracted_nodes"]) + 1)
            else:
                cluster_lengths.append(1)

    return critical_nodes, cluster_lengths, error_type

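# Minimal hand-built sketch (toy graph, illustrative only): node 1 is predicted foreground on
# ground-truth background and lacks a correct background neighbor, so it is reported as
# critical with error type 0 (missing background neighbor).
# >>> g = nx.Graph()
# >>> g.add_edge(1, 2)
# >>> g.nodes[1].update(predicted_classes=1, gt_classes=0)
# >>> g.nodes[2].update(predicted_classes=1, gt_classes=1, gt_cluster=2)
# >>> get_critical_nodes(g)
# ([1], [1], [0])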

def get_critical_nbrs(graph):
    error_count = 0

    for node in graph.nodes:
        # skip correctly predicted nodes
        if graph.nodes[node]["predicted_classes"] == graph.nodes[node]["gt_classes"]:
            continue

        all_nbrs = list(graph[node])

        fg_nbr_clusters = set()
        bg_nbrs = set()
        counter_class_str = "gt_cluster" if graph.nodes[node]["predicted_classes"] == 1 else "pred_cluster"
        class_str = "gt_visited" if graph.nodes[node]["predicted_classes"] == 1 else "pred_visited"

        for nbr in all_nbrs:
            # if it is a special edge, skip it
            if graph[node][nbr].get("special", False):
                continue

            nbr_gt_class = graph.nodes[nbr]["gt_classes"]

            # collect correct background neighbors and foreground neighbor clusters
            if nbr_gt_class == 0 and graph.nodes[nbr]["predicted_classes"] == 0:  # correct background case
                bg_nbrs.add(nbr)
            else:  # all other nbrs are either incorrect foreground in the counter class or correct foreground
                fg_nbr_clusters.add(graph.nodes[nbr][counter_class_str])

        if len(bg_nbrs) == 1 and len(fg_nbr_clusters) == 1:
            continue

        # if a correct nbr is missing, add one to the error count
        if len(bg_nbrs) == 0:
            error_count += 1
        elif len(bg_nbrs) > 1:
            # if there are too many bg nbrs, count each one that has not been counted yet as an error
            seen_nodes = 0
            for error_node in bg_nbrs:
                if class_str not in graph.nodes[error_node]:
                    graph.nodes[error_node][class_str] = True
                else:
                    seen_nodes += 1

            error_count += len(bg_nbrs) - max(seen_nodes, 1)

        if len(fg_nbr_clusters) == 0:
            error_count += 1
        elif len(fg_nbr_clusters) > 1:
            seen_nodes = 0
            for error_node in fg_nbr_clusters:
                if class_str not in graph.nodes[error_node]:
                    graph.nodes[error_node][class_str] = True
                else:
                    seen_nodes += 1

            error_count += len(fg_nbr_clusters) - max(seen_nodes, 1)

    return error_count


def create_relabel_masks(critical_node_list, cluster_lengths, all_labels):
    region_error_infos = []
    remaining_nodes_in_cluster = 0
    i = 0
    cluster_counter = -1

    while i < len(critical_node_list):
        cluster_counter += 1
        node_set = [critical_node_list[i]]
        i += 1
        remaining_nodes_in_cluster = cluster_lengths[cluster_counter] - 1

        while remaining_nodes_in_cluster > 0:
            node_set.append(critical_node_list[i])
            i += 1
            remaining_nodes_in_cluster -= 1

        # get indices from all positions where all_labels is equal to any node in the node_set
        relabel_mask = np.isin(all_labels, node_set)

        index_relabel_mask = np.nonzero(relabel_mask)

        region_error_infos.append(index_relabel_mask)

    return region_error_infos

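# Worked sketch (toy labels, illustrative only): one critical cluster containing region labels
# 1 and 2 yields the (row, col) indices of all pixels in either region.
# >>> all_labels = np.array([[0, 1], [2, 0]])
# >>> create_relabel_masks([1, 2], [2], all_labels)
# [(array([0, 1]), array([1, 0]))]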

def create_relabel_masks_c(critical_node_list, cluster_lengths, all_labels):

    # convert to fortran storage order
    all_labels = np.asfortranarray(all_labels).astype(np.int32)
    # convert list to 1-dim numpy array in fortran storage
    critical_nodes = np.asfortranarray(critical_node_list).astype(np.int32)
    cluster_lengths = np.asfortranarray(cluster_lengths).astype(np.int32)

    relabel_indices = _topograph.get_relabel_indices(all_labels, critical_nodes, cluster_lengths)

    return relabel_indices


def _single_sample_class_loss(
    argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl, sample_no, use_c=True
):
    # create graph
    graph, labelled_regions = create_graph(
        argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl
    )

    # identify critical nodes
    critical_nodes, cluster_lengths, error_types = get_critical_nodes(graph)

    # create relabel masks for all classes
    if use_c:
        error_region_infos = create_relabel_masks_c(critical_nodes, cluster_lengths, labelled_regions)
    else:
        error_region_infos = create_relabel_masks(critical_nodes, cluster_lengths, labelled_regions)

    return error_region_infos, sample_no, error_types


def _single_sample_class_metric(
    argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl, sample_no
):
    # create graph
    graph, labelled_regions = create_graph(
        argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl
    )

    # get error causing neighbors
    error_count = get_critical_nbrs(graph)

    return error_count, sample_no


def single_sample_class_loss(args: dict):
    return _single_sample_class_loss(**args)


def single_sample_class_metric(args: dict):
    return _single_sample_class_metric(**args)


def find_saddle_points_in_8_neighborhood(tensor):
    # tensor: (1, num_classes, H, W)
    unfolded = F.unfold(tensor, kernel_size=(3, 1), padding=(1, 0))

    # unfold now has dim (1, num_classes*3, H*W)
    # now reshape unfold to (1, num_classes, 3, H*W)
    unfolded = unfolded.view(1, tensor.size(1), 3, tensor.size(2) * tensor.size(3))
    # now take the max for each 3x1 window
    max_vertical_pooled = unfolded.max(dim=2).values
    min_vertical_pooled = unfolded.min(dim=2).values

    max_vertical = max_vertical_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))
    min_vertical = min_vertical_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))

    # now unfold in the horizontal direction
    unfolded = F.unfold(tensor, kernel_size=(1, 3), padding=(0, 1))
    # unfold now has dim (1, num_classes*3, H*W)
    # now reshape unfold to (1, num_classes, 3, H*W)
    unfolded = unfolded.view(1, tensor.size(1), 3, tensor.size(2) * tensor.size(3))

    # now take the max for each 1x3 window
    max_horizontal_pooled = unfolded.max(dim=2).values
    min_horizontal_pooled = unfolded.min(dim=2).values

    max_horizontal = max_horizontal_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))
    min_horizontal = min_horizontal_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))

    # A saddle point is a maximum in one direction and a minimum in the perpendicular direction
    saddle_mask = (
        (
            ((tensor >= max_horizontal) & (tensor <= min_vertical))  # Horizontal max, vertical min
            | ((tensor >= max_vertical) & (tensor <= min_horizontal))  # Vertical max, horizontal min
        )
        .squeeze(0)
        .squeeze(0)
    )  # remove the batch dimension (and the channel dimension if it is 1)

    return saddle_mask

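# A minimal illustrative check (not part of the library): the center pixel below is a
# horizontal maximum ([0, 1, 0]) and a vertical minimum ([2, 1, 2]), so it is flagged as a saddle.
# >>> t = torch.tensor([[2.0, 2.0, 2.0], [0.0, 1.0, 0.0], [2.0, 2.0, 2.0]]).view(1, 1, 3, 3)
# >>> mask = find_saddle_points_in_8_neighborhood(t)
# >>> bool(mask[1, 1])
# True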

class TopographLoss(_Loss):
    """TopographLoss is a loss function designed to ensure strict topology preservation during image segmentation tasks.

    The loss is defined in: Lux et al. (2024), Topograph: An Efficient Graph-Based Framework for
    Strictly Topology Preserving Image Segmentation (https://arxiv.org/pdf/2411.03228).

    By default the Topograph component is combined with a Dice loss component. For more flexibility
    a custom base loss function can be passed.
    """

    def __init__(
        self,
        num_processes=1,
        use_c: bool = True,
        sphere: bool = False,
        eight_connectivity: bool = True,
        aggregation: AggregationType | str = AggregationType.MEAN,
        thres_distr: ThresholdDistribution | str = ThresholdDistribution.NONE,
        thres_var: float = 0.0,
        include_background: bool = False,
        alpha: float = 0.1,
        softmax: bool = False,
        sigmoid: bool = False,
        use_base_loss: bool = True,
        base_loss: Optional[_Loss] = None,
    ) -> None:
        """
        Args:
            num_processes (int): Number of parallel processes used for the per-sample graph
                computation. Values greater than 1 create a multiprocessing pool.
            use_c (bool): Whether to use the C++ implementation of the relabeling step (faster)
                instead of the pure Python version. Defaults to True.
            sphere (bool): If True, adds padding to create periodic boundary conditions (sphere
                topology). Defaults to False.
            eight_connectivity (bool): If True, uses 8-connectivity for foreground components
                (i.e., diagonally adjacent pixels form a single connected component) when building
                the component graph; otherwise 4-connectivity is used. Defaults to True.
            aggregation (AggregationType): Aggregation method for the loss over critical regions
                in the batch. Possible values are mean, sum, max, min, ce, rms, and leg.
                Defaults to mean.
            thres_distr (ThresholdDistribution): Distribution used for sampling the binarization
                threshold. Possible values are uniform and gaussian. Defaults to NONE, i.e. a
                constant binarization threshold of 0.5.
            thres_var (float): If a threshold distribution is set, this variable controls the
                magnitude of the random threshold variation applied during loss computation, with
                higher values increasing the noise. Defaults to 0.0.
            include_background (bool): If True, includes the background class in the Topograph
                computation. Background inclusion in the base loss component should be controlled
                independently.
            alpha (float): Weighting factor for the Topograph loss component. Only applied if a
                base loss is used. Defaults to 0.1.
            sigmoid (bool): If True, applies a sigmoid activation to the input before computing
                the Topograph loss. Typically used for binary segmentation. Defaults to False.
            softmax (bool): If True, applies a softmax activation to the input before computing
                the Topograph loss. This is useful for multi-class segmentation tasks. Defaults
                to False. For other activation functions, set sigmoid and softmax to False and
                apply the transformation before passing inputs to the loss.
            use_base_loss (bool): If False, the loss consists only of the Topograph component and
                a forward call returns the bare Topograph loss; base_loss and alpha are ignored
                in that case.
            base_loss (_Loss, optional): The base loss function to be used alongside the Topograph
                loss. Defaults to None, meaning a Dice component with default parameters will be
                used.

        Raises:
            ValueError: If more than one of [sigmoid, softmax] is set to True.
        """
        if sum([sigmoid, softmax]) > 1:
            raise ValueError(
                "At most one of [sigmoid, softmax] can be set to True. "
                "You can only choose one of these options at a time or none if you already pass probabilities."
            )

        super(TopographLoss, self).__init__()
        self.num_processes = num_processes
        if self.num_processes > 1:
            self.pool = mp.Pool(num_processes)
        self.use_c = use_c
        self.sphere = sphere
        self.eight_connectivity = eight_connectivity
        self.thres_distr = (
            ThresholdDistribution(thres_distr) if not isinstance(thres_distr, ThresholdDistribution) else thres_distr
        )
        self.thres_var = thres_var
        self.aggregation = (
            AggregationType(aggregation) if not isinstance(aggregation, AggregationType) else aggregation
        )
        self.include_background = include_background
        self.alpha = alpha
        self.softmax = softmax
        self.sigmoid = sigmoid
        self.use_base_loss = use_base_loss
        self.base_loss = base_loss

        if not self.use_base_loss:
            if base_loss is not None:
                warnings.warn("base_loss is ignored because use_base_loss is set to False")
            if self.alpha != 1:
                warnings.warn("alpha != 1 has no effect when no base component is used.")
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Calculates the forward pass of the Topograph loss.

        Args:
            input (Tensor): Input tensor of shape (batch_size, num_classes, H, W).
            target (Tensor): Target tensor of shape (batch_size, num_classes, H, W).

        Returns:
            Tensor: The calculated topological loss.

        Raises:
            ValueError: If the shape of the ground truth is different from the input shape.
            ValueError: If softmax=True and the number of channels for the prediction is 1.
        """
        if target.shape != input.shape:
            raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        starting_class = 0 if self.include_background else 1
        num_classes = input.shape[1]

        if num_classes == 1:
            if self.softmax:
                raise ValueError(
                    "softmax=True requires multiple channels for class probabilities, "
                    "but received a single-channel input."
                )
            if not self.include_background:
                warnings.warn(
                    "Single-channel prediction detected. The `include_background=False` setting will be ignored."
                )
                starting_class = 0

        # Transformations like sigmoid, softmax, or one-vs-rest are not applied before passing the input
        # to the base loss function; these settings have to be controlled by the user when initializing
        # the base loss function.
        base_loss = torch.tensor(0.0)
        if self.alpha < 1 and self.use_base_loss and self.base_loss is not None:
            base_loss = self.base_loss(input, target)

        if self.sigmoid:
            input = torch.sigmoid(input)
        elif self.softmax:
            input = torch.softmax(input, 1)

        if self.alpha < 1 and self.use_base_loss and self.base_loss is None:
            base_loss = compute_default_dice_loss(input, target)

        topograph_loss = torch.tensor(0.0)
        if self.alpha > 0:
            topograph_loss = self.compute_topograph_loss(input.float(), target.float(), starting_class, num_classes)

        total_loss = topograph_loss if not self.use_base_loss else base_loss + self.alpha * topograph_loss

        return total_loss
    def compute_topograph_loss(self, input, target, starting_class, num_classes):
        if self.thres_distr != ThresholdDistribution.NONE:
            # sample the random noise that is added to one class's probabilities
            match self.thres_distr:
                case ThresholdDistribution.UNIFORM:
                    thres_noise = torch.rand(size=[input.shape[0], 1, 1], requires_grad=False, device=input.device) * (
                        self.thres_var / (num_classes - 1)
                    )
                case ThresholdDistribution.GAUSSIAN:
                    thres_noise = torch.randn(
                        size=[input.shape[0], 1, 1], requires_grad=False, device=input.device
                    ) * (self.thres_var / (num_classes - 1))

            input_detached = input.detach().clone()

            # get the class that is being reinforced
            noise_class = torch.randint(0, num_classes, (input.shape[0],), device=input.device)
            neg_noise = thres_noise / (num_classes - 1)

            # randomly add noise to the input (we also add neg_noise because we subtract it later)
            input_detached[:, noise_class] += thres_noise + neg_noise
            input_detached -= neg_noise.unsqueeze(1)

            # the modified input is detached and only used for binarization, so gradients are unaffected
            modified_input = input_detached
        else:
            modified_input = input

        # create argmax encoding using torch
        argmax_preds = torch.argmax(modified_input, dim=1)
        argmax_gts = torch.argmax(target, dim=1)

        if self.sphere:
            argmax_preds = F.pad(argmax_preds, (1, 1, 1, 1), value=0)
            argmax_gts = F.pad(argmax_gts, (1, 1, 1, 1), value=0)

        single_calc_inputs = []

        # get critical nodes for each class
        for class_index in range(starting_class, num_classes):
            # binarize image
            bin_preds = torch.zeros_like(argmax_preds)
            bin_gts = torch.zeros_like(argmax_gts)
            bin_preds[argmax_preds == class_index] = 1
            bin_gts[argmax_gts == class_index] = 1

            paired_imgs = bin_preds + 2 * bin_gts
            diag_val_1, diag_val_2 = (-4, 16) if self.eight_connectivity else (16, -4)
            paired_imgs[paired_imgs == 0] = diag_val_1
            paired_imgs[paired_imgs == 3] = diag_val_2

            h_diff, v_diff = new_compute_diffs(paired_imgs)
            diagr, diagl, special_diag_r, special_diag_l = new_compute_diag_diffs(paired_imgs, th=7)

            # move all to cpu  TODO: Fix device handling
            bin_preds = bin_preds.cpu().numpy()
            bin_gts = bin_gts.cpu().numpy()
            h_diff = h_diff.cpu().numpy()
            v_diff = v_diff.cpu().numpy()
            diagr = diagr.cpu().numpy()
            diagl = diagl.cpu().numpy()
            special_diag_r = special_diag_r.cpu().numpy()
            special_diag_l = special_diag_l.cpu().numpy()

            for i in range(input.shape[0]):
                # create dict with function arguments
                single_calc_input = {
                    "argmax_pred": bin_preds[i],
                    "argmax_gt": bin_gts[i],
                    "h_diff": h_diff[i],
                    "v_diff": v_diff[i],
                    "diagr": diagr[i],
                    "diagl": diagl[i],
                    "special_diagr": special_diag_r[i],
                    "special_diagl": special_diag_l[i],
                    "sample_no": i,
                    "use_c": self.use_c,
                }
                single_calc_inputs.append(single_calc_input)

        if self.num_processes > 1:
            chunksize = (
                len(single_calc_inputs) // self.num_processes if len(single_calc_inputs) > self.num_processes else 1
            )
            relabel_masks = self.pool.imap_unordered(single_sample_class_loss, single_calc_inputs, chunksize=chunksize)
        else:
            relabel_masks = map(single_sample_class_loss, single_calc_inputs)

        # calculate the topological loss for each sample and class
        g_loss = torch.tensor(0.0, device=input.device)

        for region_error_infos, sample_no, error_types in relabel_masks:
            if self.aggregation == AggregationType.LEG:
                # clone and detach the input to avoid gradients
                input_mask_util = input.detach().clone()
                # calculate the local maxima of each 3x3 neighborhood
                local_maxima = F.max_pool2d(
                    input_mask_util[sample_no, :, :, :].unsqueeze(0), kernel_size=3, stride=1, padding=1
                )
                local_maxima = local_maxima.squeeze(0)

                # mark pixels that attain the maximum of their 3x3 neighborhood
                local_maxima_in_region = input_mask_util[sample_no, :, :, :] >= local_maxima
                local_maxima_in_region = local_maxima_in_region.bool().to(input.device)

                # get the saddle points
                saddle_mask = find_saddle_points_in_8_neighborhood(input_mask_util[sample_no, :, :, :].unsqueeze(0))

                print("local maxima shape: ", local_maxima_in_region.shape)
                print("saddle shape: ", saddle_mask.shape)

                nominator_means = []
                num_elements = []
                found_or_not_dict = {
                    "saddle_yes": 0,
                    "saddle_no": 0,
                    "single_local_maxima": 0,
                    "multiple_local_maxima": 0,
                    "no_local_maxima": 0,
                }

            for i, region_indices in enumerate(region_error_infos):
                if self.sphere:
                    region_indices = torch.tensor(region_indices)
                    region_indices -= 1

                if self.aggregation != AggregationType.CE and self.aggregation != AggregationType.LEG:
                    class_indices = argmax_preds[sample_no, region_indices[0], region_indices[1]]
                    nominator = input[sample_no, class_indices, region_indices[0], region_indices[1]]

                match self.aggregation:
                    case AggregationType.MEAN:
                        g_loss += nominator.mean()
                    case AggregationType.RMS:
                        g_loss += torch.sqrt((nominator**2).mean())
                    case AggregationType.SUM:
                        g_loss += nominator.sum()
                    case AggregationType.MAX:
                        g_loss += nominator.max()
                    case AggregationType.MIN:
                        g_loss += nominator.min()
                    case AggregationType.LEG:
                        class_indices = argmax_preds[sample_no, region_indices[0], region_indices[1]]

                        # boolean mask covering the pixels of the current error region
                        region_mask = torch.zeros(input.shape[1:], device=input.device, dtype=torch.bool)
                        region_mask[class_indices, region_indices[0], region_indices[1]] = True

                        # error types 0 and 1 are the isolated component cases
                        if error_types[i] == 0 or error_types[i] == 1:
                            base_loss = input[sample_no, class_indices, region_indices[0], region_indices[1]].mean()
                            nominator_means.append(base_loss)
                            num_elements.append(1)

                            intersection_mask = region_mask & local_maxima_in_region
                            if not intersection_mask.any():
                                found_or_not_dict["no_local_maxima"] += 1
                            else:
                                # create a loss based on all but the largest local maximum in the region
                                leg_indices = torch.nonzero(intersection_mask, as_tuple=True)
                                if len(leg_indices[0]) <= 1:
                                    found_or_not_dict["single_local_maxima"] += 1
                                else:
                                    found_or_not_dict["multiple_local_maxima"] += 1
                                    nominator = input[sample_no, leg_indices[0], leg_indices[1], leg_indices[2]]

                                    # remove the largest local maximum; if there are multiple, remove just one
                                    largest_local_maxima = nominator == nominator.max()
                                    remaining = largest_local_maxima.nonzero()[1:]
                                    non_max = nominator[nominator != nominator.max()].nonzero()
                                    all_idcs = torch.cat([remaining, non_max])
                                    nominator = nominator[all_idcs]

                                    # check if there are still elements in the nominator, otherwise continue
                                    if len(nominator) < 1:
                                        continue

                                    mean_local_maxima = nominator.mean()
                                    nominator_means.append(mean_local_maxima)
                                    num_elements.append(len(nominator))
                        # error type 2 is the connectivity error case
                        else:
                            # add a mean loss as a base loss
                            base_loss = input[sample_no, class_indices, region_indices[0], region_indices[1]].mean()
                            nominator_means.append(base_loss)
                            num_elements.append(1)

                            # create the loss based on the saddle points in the region
                            print("region mask shape", region_mask.shape)
                            intersection_mask = region_mask & saddle_mask
                            print("intersection mask shape", intersection_mask.shape)
                            if not intersection_mask.any():
                                found_or_not_dict["saddle_no"] += 1
                            else:
                                found_or_not_dict["saddle_yes"] += 1
                                # create a loss based on all saddle points in the region
                                leg_indices = torch.nonzero(intersection_mask, as_tuple=True)
                                nominator = input[sample_no, leg_indices[0], leg_indices[1], leg_indices[2]]
                                mean_saddle_points = nominator.mean()
                                nominator_means.append(mean_saddle_points)
                                num_elements.append(1)
                    case AggregationType.CE:
                        masked_input = input[sample_no, :, region_indices[0], region_indices[1]].unsqueeze(0)
                        masked_target = target[sample_no, :, region_indices[0], region_indices[1]].unsqueeze(0)
                        g_loss += F.cross_entropy(masked_input, masked_target, reduction="mean")
                    case _:
                        raise ValueError(f"Invalid aggregation type: {self.aggregation}")

            if self.aggregation == AggregationType.LEG:
                for i, lo in enumerate(nominator_means):
                    g_loss += lo * num_elements[i]

                if sum(num_elements) != 0:
                    g_loss *= len(nominator_means) / sum(num_elements)

        # normalize by number of classes and batch size
        g_loss /= input.shape[0] * (num_classes - starting_class)

        return g_loss
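
# Minimal usage sketch (illustrative shapes and values; use_c=False selects the pure Python
# relabeling path so the compiled extension is not required):
# >>> loss_fn = TopographLoss(softmax=True, alpha=0.5, use_c=False)
# >>> logits = torch.randn(2, 2, 32, 32)  # (batch, classes, H, W)
# >>> target = F.one_hot(torch.randint(0, 2, (2, 32, 32)), num_classes=2).permute(0, 3, 1, 2).float()
# >>> loss = loss_fn(logits, target)  # scalar tensor: Dice component + 0.5 * Topograph component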