from __future__ import annotations

import warnings
from typing import Optional

import networkx as nx
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
# C++ implementation of Topograph; it can be installed locally as its own package with the setup file.
# During the build process for the distribution file, a similar package is built using the CMake file.
# TODO: provide a local build of topograph that can be imported the same way during development as during the build process
# import Topograph as _topograph
from . import _topograph
from ...utils import (
AggregationType,
ThresholdDistribution,
fill_adj_matr,
new_compute_diag_diffs,
new_compute_diffs,
compute_default_dice_loss,
)
from scipy.ndimage import label
from scipy.cluster.hierarchy import DisjointSet
def reverse_pairing(pairing: int) -> tuple[int, int]:
    """Invert the pairing ``pred + 2 * gt`` back into the ``(pred, gt)`` class pair."""
    match pairing:
case 0:
return 0, 0
case 1:
return 1, 0
case 2:
return 0, 1
case 3:
return 1, 1
case _:
return -1, -1
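# Worked example of the pairing used by reverse_pairing (an illustrative sketch, not part of the original API):
# a pixel with pred = 1 and gt = 0 is paired as 1 + 2 * 0 = 1, and reverse_pairing(1) == (1, 0),
# i.e. a false-positive region; pred = 0, gt = 1 pairs to 2 and reverse_pairing(2) == (0, 1).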
def label_regions(pred: np.ndarray, gt: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
    Labels the regions in the predicted image based on the intersection of the predicted and ground truth images by assigning a unique label to each connected component.
    Args:
        pred (ndarray, [H, W]): The predicted and binarized image in argmax encoding of shape [H, W].
        gt (ndarray, [H, W]): The ground truth image in argmax encoding of shape [H, W].
    Returns:
        tuple: A tuple containing the labeled regions, the prediction labels, and the ground truth labels.
            - all_labels (ndarray, [H, W]): The labeled regions (labels starting at 0) of shape [H, W].
            - pred_labels (ndarray, [N]): The predicted classes for each region.
            - gt_labels (ndarray, [N]): The ground truth classes for each region.
    """
# create one hot encoding for each intersection class
paired_img = pred + 2 * gt
masked_imgs = np.eye(4)[paired_img].transpose(2, 0, 1).astype(np.int32)
# use map to iterate through all possible combinations of classes and create connected component labeling with masks
cc_result = map(label, masked_imgs)
all_labels = np.zeros(pred.shape, dtype=np.int32)
label_counter = 0 # counter for the number of classes that have already been set
gt_labels = []
pred_labels = []
# iterate through all possible combinations of classes and aggregate all labels
for inters_class, (labeled_regions, num_nodes) in enumerate(cc_result):
# add the labeled mask to the final image
all_labels += labeled_regions + (masked_imgs[inters_class] * label_counter)
label_counter += num_nodes
        # recover the pred and gt class by inverting the pairing pred + 2 * gt
pred_class, gt_class = reverse_pairing(inters_class)
# append pred and gt classes
pred_labels.append(np.zeros((num_nodes)) + pred_class)
gt_labels.append(np.zeros((num_nodes)) + gt_class)
# convert label lists to numpy arrays
pred_labels = np.concatenate(pred_labels)
gt_labels = np.concatenate(gt_labels)
if all_labels.max() > 0:
all_labels -= 1
return all_labels, pred_labels, gt_labels
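# Minimal output sketch for label_regions (illustrative only): with pred == gt == np.ones((2, 2), dtype=int)
# the whole image is a single true-positive region, so all_labels is all zeros (one region with id 0),
# pred_labels == [1.] and gt_labels == [1.].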
def rag(labelled_regions, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl):
max_label = labelled_regions.max()
    # if all pixels fall into a single region, there are no edges
    if max_label == 0:
        edges = np.empty((2, 0))
        special_edges = np.empty((2, 0))
else:
# get the classes of each edge
h_edges = np.stack([labelled_regions[1:, :][h_diff], labelled_regions[:-1, :][h_diff]])
v_edges = np.stack([labelled_regions[:, 1:][v_diff], labelled_regions[:, :-1][v_diff]])
# create adjacency matrix
adj = np.zeros((max_label + 1, max_label + 1), dtype=bool)
special_adj = np.zeros((max_label + 1, max_label + 1), dtype=bool)
adj = fill_adj_matr(adj, h_edges, v_edges)
dr_edges = np.stack([labelled_regions[:-1, :-1][diagr], labelled_regions[1:, 1:][diagr]])
dl_edges = np.stack([labelled_regions[:-1, 1:][diagl], labelled_regions[1:, :-1][diagl]])
special_dr_edges = np.stack(
[labelled_regions[:-1, :-1][special_diagr], labelled_regions[1:, 1:][special_diagr]]
)
special_dl_edges = np.stack(
[labelled_regions[:-1, 1:][special_diagl], labelled_regions[1:, :-1][special_diagl]]
)
adj = fill_adj_matr(adj, dr_edges, dl_edges)
special_adj = fill_adj_matr(special_adj, special_dr_edges, special_dl_edges)
# convert to edge index list
edges = np.stack(np.nonzero(adj))
special_edges = np.stack(np.nonzero(special_adj))
return edges, special_edges
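# Edge-list sketch (illustrative): for an adjacency matrix [[0, 1], [1, 0]],
# np.stack(np.nonzero(adj)) yields [[0, 1], [1, 0]], i.e. both directions of the single edge (0, 1).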
def contract_graph(graph):
# identify clusters of nodes that all have the same predicted and gt class
same_nodes = DisjointSet(graph.nodes)
for node in graph.nodes:
# skip correct background nodes because they never have only a diagonal edge
if graph.nodes[node]["predicted_classes"] == 0 and graph.nodes[node]["gt_classes"] == 0:
continue
# get the node's cluster
cur_node_cluster = same_nodes[node]
# iterate through all neighbors of the current node
for neighbor in graph[node]:
            # visit each edge only once; skip special edges
if neighbor < node or graph[node][neighbor].get("special", False):
continue
# check if the neighbor has the same predicted and gt class as the current node
if (
graph.nodes[neighbor]["predicted_classes"] == graph.nodes[node]["predicted_classes"]
and graph.nodes[neighbor]["gt_classes"] == graph.nodes[node]["gt_classes"]
):
nbr_cluster = same_nodes[neighbor]
if nbr_cluster != cur_node_cluster:
same_nodes.merge(cur_node_cluster, nbr_cluster)
# contract nodes in the graph based on the clusters
for cluster in same_nodes.subsets():
if len(cluster) == 1:
continue
# get the first node in the cluster
first_node = cluster.pop()
# Save the contracted nodes in the first node of each cluster
graph.nodes[first_node]["contracted_nodes"] = cluster
# contract all other nodes in the cluster to the first node
for node in cluster:
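            # copy=False merges `node` into `first_node` in place, rerouting its edges to first_node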
nx.contracted_nodes(graph, first_node, node, self_loops=False, copy=False)
return graph
def identify_clusters(graph):
pred_cluster = DisjointSet(graph.nodes)
gt_cluster = DisjointSet(graph.nodes)
for node in graph.nodes:
# skip correct background nodes because they're never part of a cluster
if graph.nodes[node]["predicted_classes"] == 0 and graph.nodes[node]["gt_classes"] == 0:
continue
# get the node's clusters
cur_pred_cluster = pred_cluster[node]
cur_gt_cluster = gt_cluster[node]
# iterate through all neighbors of the current node
for neighbor in graph[node]:
# visit each edge only once
if neighbor < node:
continue
# # if it is a special edge, skip it
# if graph[node][neighbor].get('special', False):
# continue
# if they are both predicted foreground, merge pred cluster
if graph.nodes[neighbor]["predicted_classes"] == 1 and graph.nodes[node]["predicted_classes"] == 1:
pred_nbr_cluster = pred_cluster[neighbor]
if pred_nbr_cluster != cur_pred_cluster:
pred_cluster.merge(cur_pred_cluster, pred_nbr_cluster)
# if they have the same gt class, merge gt cluster
if graph.nodes[neighbor]["gt_classes"] == 1 and graph.nodes[node]["gt_classes"] == 1:
gt_nbr_cluster = gt_cluster[neighbor]
if gt_nbr_cluster != cur_gt_cluster:
gt_cluster.merge(cur_gt_cluster, gt_nbr_cluster)
# add pred cluster to each node
for cluster in pred_cluster.subsets():
node = cluster.pop()
root = pred_cluster[node]
graph.nodes[node]["pred_cluster"] = root
for node in cluster:
graph.nodes[node]["pred_cluster"] = root
# add gt cluster to each node
for cluster in gt_cluster.subsets():
node = cluster.pop()
root = gt_cluster[node]
graph.nodes[node]["gt_cluster"] = root
for node in cluster:
graph.nodes[node]["gt_cluster"] = root
return graph
def create_graph(argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl):
labelled_regions, predicted_classes, gt_classes = label_regions(argmax_pred, argmax_gt)
# create a graph from the labelled regions
    if labelled_regions.max() == 0:  # if there is only a single region, create a graph with a single node
graph = nx.Graph()
graph.add_node(0)
edge_index = torch.tensor([[], []])
special_edge_index = torch.tensor([[], []])
else:
edge_index, special_edge_index = rag(
labelled_regions, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl
)
graph = nx.Graph()
graph.add_edges_from(edge_index.T)
graph.add_edges_from(special_edge_index.T, special=True)
# add node attributes
for node in graph.nodes:
graph.nodes[node]["predicted_classes"] = predicted_classes[node]
graph.nodes[node]["gt_classes"] = gt_classes[node]
graph.graph["predicted_classes"] = predicted_classes
graph = contract_graph(graph)
graph = identify_clusters(graph)
return graph, labelled_regions
def get_critical_nodes(graph):
critical_nodes = []
cluster_lengths = []
error_type = []
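    # error_type encoding (derived from the checks below), with b = number of correctly predicted
    # background neighbors and f = number of distinct neighboring foreground clusters:
    #   0 -> no correct background neighbor (b == 0, f <= 1)
    #   1 -> one background neighbor but no foreground cluster (b == 1, f == 0)
    #   2 -> too many neighbors overall (b + f >= 2), i.e. a connectivity-type error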
for node in graph.nodes:
# skip correctly predicted nodes
if graph.nodes[node]["predicted_classes"] == graph.nodes[node]["gt_classes"]:
continue
all_nbrs = list(graph[node])
fg_nbr_clusters = set()
correct_bg_nbrs_count = 0
counter_class_str = "gt_cluster" if graph.nodes[node]["predicted_classes"] == 1 else "pred_cluster"
for nbr in all_nbrs:
# if it is a special edge, skip it
if graph[node][nbr].get("special", False):
continue
nbr_gt_class = graph.nodes[nbr]["gt_classes"]
            # if the neighbor is a correctly predicted background node, count it
if nbr_gt_class == 0 and graph.nodes[nbr]["predicted_classes"] == 0: # correct background case
correct_bg_nbrs_count += 1
# If we have more than one correct background neighbor, we can stop here
if correct_bg_nbrs_count > 1:
break
else: # all other nbrs are either incorrect foreground in the counter class or correct foreground
fg_nbr_clusters.add(graph.nodes[nbr][counter_class_str])
        # if the node does not have exactly one correct background neighbor and exactly one foreground neighbor cluster, add it to the critical nodes
if correct_bg_nbrs_count != 1 or len(fg_nbr_clusters) != 1:
node_error = 0 if correct_bg_nbrs_count + len(fg_nbr_clusters) < 2 else 2
if node_error == 0:
# check what type of error it is
if correct_bg_nbrs_count == 0:
node_error = 0
else:
node_error = 1
error_type.append(node_error)
critical_nodes.append(node)
if "contracted_nodes" in graph.nodes[node]:
critical_nodes += graph.nodes[node]["contracted_nodes"]
cluster_lengths.append(len(graph.nodes[node]["contracted_nodes"]) + 1)
else:
cluster_lengths.append(1)
continue
return critical_nodes, cluster_lengths, error_type
def get_critical_nbrs(graph):
error_count = 0
for node in graph.nodes:
# skip correctly predicted nodes
if graph.nodes[node]["predicted_classes"] == graph.nodes[node]["gt_classes"]:
continue
all_nbrs = list(graph[node])
fg_nbr_clusters = set()
bg_nbrs = set()
counter_class_str = "gt_cluster" if graph.nodes[node]["predicted_classes"] == 1 else "pred_cluster"
        class_str = "gt_visited" if graph.nodes[node]["predicted_classes"] == 1 else "pred_visited"
for nbr in all_nbrs:
# if it is a special edge, skip it
if graph[node][nbr].get("special", False):
continue
nbr_gt_class = graph.nodes[nbr]["gt_classes"]
            # if the neighbor is a correctly predicted background node, remember it
if nbr_gt_class == 0 and graph.nodes[nbr]["predicted_classes"] == 0: # correct background case
bg_nbrs.add(nbr)
else: # all other nbrs are either incorrect foreground in the counter class or correct foreground
fg_nbr_clusters.add(graph.nodes[nbr][counter_class_str])
if len(bg_nbrs) == 1 and len(fg_nbr_clusters) == 1:
continue
# if a correct nbr is missing, add one to the error count
if len(bg_nbrs) == 0:
error_count += 1
elif len(bg_nbrs) > 1:
            # if there are too many background neighbors, count each one that has not been counted yet as an error
seen_nodes = 0
for error_node in bg_nbrs:
if not class_str in graph.nodes[error_node]:
graph.nodes[error_node][class_str] = True
else:
seen_nodes += 1
error_count += len(bg_nbrs) - max(seen_nodes, 1)
if len(fg_nbr_clusters) == 0:
error_count += 1
elif len(fg_nbr_clusters) > 1:
seen_nodes = 0
for error_node in fg_nbr_clusters:
if not class_str in graph.nodes[error_node]:
graph.nodes[error_node][class_str] = True
else:
seen_nodes += 1
error_count += len(fg_nbr_clusters) - max(seen_nodes, 1)
return error_count
def create_relabel_masks(critical_node_list, cluster_lengths, all_labels):
region_error_infos = []
remaining_nodes_in_cluster = 0
i = 0
cluster_counter = -1
while i < len(critical_node_list):
cluster_counter += 1
node_set = [critical_node_list[i]]
i += 1
remaining_nodes_in_cluster = cluster_lengths[cluster_counter] - 1
while remaining_nodes_in_cluster > 0:
node_set.append(critical_node_list[i])
i += 1
remaining_nodes_in_cluster -= 1
# get indices from all positions where all_labels is equal to any node in the node_set
relabel_mask = np.isin(all_labels, node_set)
index_relabel_mask = np.nonzero(relabel_mask)
region_error_infos.append(index_relabel_mask)
return region_error_infos
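# Masking sketch (illustrative): with all_labels == [[0, 1], [2, 1]] and node_set == [1],
# np.isin(all_labels, node_set) marks both pixels of region 1, and np.nonzero returns their
# (row, col) index arrays, i.e. (array([0, 1]), array([1, 1])).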
def create_relabel_masks_c(critical_node_list, cluster_lengths, all_labels):
# convert to fortran storage order
all_labels = np.asfortranarray(all_labels).astype(np.int32)
# convert list to 1-dim numpy array in fortran storage
critical_nodes = np.asfortranarray(critical_node_list).astype(np.int32)
cluster_lengths = np.asfortranarray(cluster_lengths).astype(np.int32)
relabel_indices = _topograph.get_relabel_indices(all_labels, critical_nodes, cluster_lengths)
return relabel_indices
def _single_sample_class_loss(
argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl, sample_no, use_c=True
):
# create graph
graph, labelled_regions = create_graph(
argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl
)
# identify critical nodes
critical_nodes, cluster_lengths, error_types = get_critical_nodes(graph)
    # create relabel masks for all classes
if use_c:
error_region_infos = create_relabel_masks_c(critical_nodes, cluster_lengths, labelled_regions)
else:
error_region_infos = create_relabel_masks(critical_nodes, cluster_lengths, labelled_regions)
return error_region_infos, sample_no, error_types
def _single_sample_class_metric(
argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl, sample_no
):
# create graph
graph, labelled_regions = create_graph(
argmax_pred, argmax_gt, h_diff, v_diff, diagr, diagl, special_diagr, special_diagl
)
# get error causing neighbors
error_count = get_critical_nbrs(graph)
return error_count, sample_no
def single_sample_class_loss(args: dict):
return _single_sample_class_loss(**args)
def single_sample_class_metric(args: dict):
return _single_sample_class_metric(**args)
def find_saddle_points_in_8_neighborhood(tensor):
# tensor: (1, num_classes, H, W)
unfolded = F.unfold(tensor, kernel_size=(3, 1), padding=(1, 0))
# unfold now has dim (1, num_classes*3, H*W)
# now reshape unfold to (1, num_classes, 3, H*W)
unfolded = unfolded.view(1, tensor.size(1), 3, tensor.size(2) * tensor.size(3))
    # now take the max and min for each 3x1 window
max_vertical_pooled = unfolded.max(dim=2).values
min_vertical_pooled = unfolded.min(dim=2).values
max_vertical = max_vertical_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))
min_vertical = min_vertical_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))
# now unfold in the horizontal direction
unfolded = F.unfold(tensor, kernel_size=(1, 3), padding=(0, 1))
# unfold now has dim (1, num_classes*3, H*W)
# now reshape unfold to (1, num_classes, 3, H*W)
unfolded = unfolded.view(1, tensor.size(1), 3, tensor.size(2) * tensor.size(3))
    # now take the max and min for each 1x3 window
max_horizontal_pooled = unfolded.max(dim=2).values
min_horizontal_pooled = unfolded.min(dim=2).values
max_horizontal = max_horizontal_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))
min_horizontal = min_horizontal_pooled.view(1, tensor.size(1), tensor.size(2), tensor.size(3))
# A saddle point is a maximum in one direction and a minimum in the perpendicular direction
saddle_mask = (
(
((tensor >= max_horizontal) & (tensor <= min_vertical)) # Horizontal max, vertical min
| ((tensor >= max_vertical) & (tensor <= min_horizontal)) # Vertical max, horizontal min
)
.squeeze(0)
.squeeze(0)
    )  # remove the batch dimension (and the channel dimension when there is only one class)
return saddle_mask
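# Shape sketch (illustrative): for a (1, C, H, W) input the returned boolean saddle_mask has shape
# (C, H, W); it collapses further to (H, W) only in the single-class case, see the squeeze(0) calls above.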
class TopographLoss(_Loss):
"""TopographLoss is a loss function designed to ensure strict topology preservation during image segmentation tasks.
The loss has been defined:
Lux et al (2024) Topograph: An efficient Graph-Based Framework for
Strictly Topology Preserving Image Segmentation (https://arxiv.org/pdf/2411.03228)
By default the topograph component is combined with a dice loss comnponent.
For more flexibility a custom base loss function can be passed.
"""
def __init__(
self,
num_processes=1,
use_c: bool = True,
sphere: bool = False,
eight_connectivity: bool = True,
aggregation: AggregationType | str = AggregationType.MEAN,
thres_distr: ThresholdDistribution | str = ThresholdDistribution.NONE,
thres_var: float = 0.0,
include_background: bool = False,
alpha: float = 0.1,
softmax: bool = False,
sigmoid: bool = False,
use_base_loss: bool = True,
base_loss: Optional[_Loss] = None,
) -> None:
"""
Args:
            num_processes (int): Number of parallel worker processes used for the per-sample graph computation.
                If greater than 1, a multiprocessing pool is created and the samples of a batch are distributed across it. Defaults to 1.
            use_c (bool): Whether to use the C++ implementation for computing the relabel indices
                instead of the pure Python version (mainly for performance). Defaults to True.
            sphere (bool): If True, adds padding to create periodic boundary conditions (sphere topology). Defaults to False.
            eight_connectivity (bool): Determines whether to use 8-connectivity for foreground components (i.e., diagonally adjacent pixels form a single connected component)
                versus 4-connectivity when building the component graph. Defaults to True.
            aggregation (AggregationType): Specifies the aggregation method for the loss calculation across the batch.
                Possible values are mean, sum, max, min, ce, rms, and leg. Defaults to mean.
            thres_distr (ThresholdDistribution): Determines the distribution used for sampling the binarization threshold.
                Possible values are uniform and gaussian. Defaults to NONE, which uses a constant binarization threshold of 0.5.
            thres_var (float): If a threshold distribution is set, this variable controls the magnitude of the random threshold variation applied during loss computation,
                with higher values increasing the noise. Defaults to 0.0.
            include_background (bool): If `True`, includes the background class in the topograph computation.
                Background inclusion in the base loss component should be controlled independently.
            alpha (float): Weighting factor for the topograph loss component. Only applied if a base loss is used. Defaults to 0.1.
            sigmoid (bool): If `True`, applies a sigmoid activation to the input before computing the topograph loss.
                Typically used for binary segmentation. Defaults to `False`.
            softmax (bool): If `True`, applies a softmax activation to the input before computing the topograph loss.
                This is useful for multi-class segmentation tasks. Defaults to `False`.
                For other activation functions, set sigmoid and softmax to False and apply the transformation before passing inputs to the loss.
            use_base_loss (bool): If False, the loss consists only of the Topograph component and a forward call
                returns the plain Topograph loss. base_loss and alpha are ignored if this flag is set to False.
            base_loss (_Loss, optional): The base loss function to be used alongside the Topograph loss.
                Defaults to `None`, meaning a Dice component with default parameters will be used.
Raises:
ValueError: If more than one of [sigmoid, softmax] is set to True.
"""
if sum([sigmoid, softmax]) > 1:
raise ValueError(
"At most one of [sigmoid, softmax] can be set to True. "
"You can only choose one of these options at a time or none if you already pass probabilites."
)
super(TopographLoss, self).__init__()
self.num_processes = num_processes
if self.num_processes > 1:
self.pool = mp.Pool(num_processes)
self.use_c = use_c
self.sphere = sphere
self.eight_connectivity = eight_connectivity
self.thres_distr = (
ThresholdDistribution(thres_distr) if not isinstance(thres_distr, ThresholdDistribution) else thres_distr
)
self.thres_var = thres_var
self.aggregation = (
AggregationType(aggregation) if not isinstance(aggregation, AggregationType) else aggregation
)
self.include_background = include_background
self.alpha = alpha
self.softmax = softmax
self.sigmoid = sigmoid
self.use_base_loss = use_base_loss
self.base_loss = base_loss
if not self.use_base_loss:
if base_loss is not None:
warnings.warn("base_loss is ignored beacuse use_base_component is set to false")
if self.alpha != 1:
warnings.warn("Alpha < 1 has no effect when no base component is used.")
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Calculates the forward pass of the topograph loss.
Args:
input (Tensor): Input tensor of shape (batch_size, num_classes, H, W).
target (Tensor): Target tensor of shape (batch_size, num_classes, H, W).
Returns:
Tensor: The calculated topological loss.
Raises:
ValueError: If the shape of the ground truth is different from the input shape.
ValueError: If softmax=True and the number of channels for the prediction is 1.
"""
if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
starting_class = 0 if self.include_background else 1
num_classes = input.shape[1]
if num_classes == 1:
if self.softmax:
raise ValueError(
"softmax=True requires multiple channels for class probabilities, but received a single-channel input."
)
if not self.include_background:
warnings.warn(
"Single-channel prediction detected. The `include_background=False` setting will be ignored."
)
starting_class = 0
        # Avoid applying transformations like sigmoid, softmax, or one-vs-rest before passing the input to the base loss function.
# These settings have to be controlled by the user when initializing the base loss function
base_loss = torch.tensor(0.0)
if self.alpha < 1 and self.use_base_loss and self.base_loss is not None:
base_loss = self.base_loss(input, target)
if self.sigmoid:
input = torch.sigmoid(input)
elif self.softmax:
input = torch.softmax(input, 1)
if self.alpha < 1 and self.use_base_loss and self.base_loss is None:
base_loss = compute_default_dice_loss(input, target)
topograph_loss = torch.tensor(0.0)
if self.alpha > 0:
topograph_loss = self.compute_topopgraph_loss(input.float(), target.float(), starting_class, num_classes)
total_loss = topograph_loss if not self.use_base_loss else base_loss + self.alpha * topograph_loss
return total_loss
def compute_topopgraph_loss(self, input, target, starting_class, num_classes):
if self.thres_distr != ThresholdDistribution.NONE:
# Get the random probability to add to a class
match self.thres_distr:
case ThresholdDistribution.UNIFORM:
thres_noise = torch.rand(size=[input.shape[0], 1, 1], requires_grad=False, device=input.device) * (
self.thres_var / (num_classes - 1)
)
case ThresholdDistribution.GAUSSIAN:
thres_noise = torch.randn(
size=[input.shape[0], 1, 1], requires_grad=False, device=input.device
) * (self.thres_var / (num_classes - 1))
input_detached = input.detach().clone()
# get class that is being reinforced
noise_class = torch.randint(0, num_classes, (input.shape[0],), device=input.device)
neg_noise = thres_noise / (num_classes - 1)
            # Randomly add noise to the input (we also add neg_noise because we subtract it later)
input_detached[:, noise_class] += thres_noise + neg_noise
input_detached -= neg_noise.unsqueeze(1)
            # the detached, noise-modified input is only used for the (non-differentiable) argmax binarization below
modified_input = input_detached
else:
modified_input = input
# create argmax encoding using torch
argmax_preds = torch.argmax(modified_input, dim=1)
argmax_gts = torch.argmax(target, dim=1)
if self.sphere:
argmax_preds = F.pad(argmax_preds, (1, 1, 1, 1), value=0)
argmax_gts = F.pad(argmax_gts, (1, 1, 1, 1), value=0)
single_calc_inputs = []
# get critical nodes for each class
for class_index in range(starting_class, num_classes):
# binarize image
bin_preds = torch.zeros_like(argmax_preds)
bin_gts = torch.zeros_like(argmax_gts)
bin_preds[argmax_preds == class_index] = 1
bin_gts[argmax_gts == class_index] = 1
paired_imgs = bin_preds + 2 * bin_gts
diag_val_1, diag_val_2 = (-4, 16) if self.eight_connectivity else (16, -4)
paired_imgs[paired_imgs == 0] = diag_val_1
paired_imgs[paired_imgs == 3] = diag_val_2
h_diff, v_diff = new_compute_diffs(paired_imgs)
diagr, diagl, special_diag_r, special_diag_l = new_compute_diag_diffs(paired_imgs, th=7)
# move all to cpu
# TODO: Fix device handling
bin_preds = bin_preds.cpu().numpy()
bin_gts = bin_gts.cpu().numpy()
h_diff = h_diff.cpu().numpy()
v_diff = v_diff.cpu().numpy()
diagr = diagr.cpu().numpy()
diagl = diagl.cpu().numpy()
special_diag_r = special_diag_r.cpu().numpy()
special_diag_l = special_diag_l.cpu().numpy()
for i in range(input.shape[0]):
# create dict with function arguments
single_calc_input = {
"argmax_pred": bin_preds[i],
"argmax_gt": bin_gts[i],
"h_diff": h_diff[i],
"v_diff": v_diff[i],
"diagr": diagr[i],
"diagl": diagl[i],
"special_diagr": special_diag_r[i],
"special_diagl": special_diag_l[i],
"sample_no": i,
"use_c": self.use_c,
}
single_calc_inputs.append(single_calc_input)
relabel_masks = []
if self.num_processes > 1:
chunksize = (
len(single_calc_inputs) // self.num_processes if len(single_calc_inputs) > self.num_processes else 1
)
relabel_masks = self.pool.imap_unordered(single_sample_class_loss, single_calc_inputs, chunksize=chunksize)
else:
relabel_masks = map(single_sample_class_loss, single_calc_inputs)
# calculate the topological loss for each class
g_loss = torch.tensor(0.0, device=input.device)
for region_error_infos, sample_no, error_types in relabel_masks:
if self.aggregation == AggregationType.LEG:
# clone and detach the input to avoid gradients
input_mask_util = input.detach().clone()
# calculate the local maxima
local_maxima = F.max_pool2d(
input_mask_util[sample_no, :, :, :].unsqueeze(0), kernel_size=3, stride=1, padding=1
)
# calculate the difference to all neighboring pixels
local_maxima = local_maxima.squeeze(0)
# get the local maxima in the region
local_maxima_in_region = input_mask_util[sample_no, :, :, :] >= local_maxima
# put the local maxima on the gpu
local_maxima_in_region = local_maxima_in_region.bool().to(input.device)
# get the saddle points
saddle_mask = find_saddle_points_in_8_neighborhood(input_mask_util[sample_no, :, :, :].unsqueeze(0))
print("local maxima shape: ", local_maxima_in_region.shape)
print("saddle shape: ", saddle_mask.shape)
nominator_means = [] # torch.zeros(len(region_error_infos), device=input.device)
num_elements = [] # torch.zeros(len(region_error_infos), device=input.device)
found_or_not_dict = {
"saddle_yes": 0,
"saddle_no": 0,
"single_local_maxima": 0,
"multiple_local_maxima": 0,
"no_local_maxima": 0,
}
for i, region_indices in enumerate(region_error_infos):
if self.sphere:
region_indices = torch.tensor(region_indices)
region_indices -= 1
if self.aggregation != AggregationType.CE and self.aggregation != AggregationType.LEG:
class_indices = argmax_preds[sample_no, region_indices[0], region_indices[1]]
nominator = input[sample_no, class_indices, region_indices[0], region_indices[1]]
match self.aggregation:
case AggregationType.MEAN:
g_loss += nominator.mean()
case AggregationType.RMS:
g_loss += torch.sqrt((nominator**2).mean())
case AggregationType.SUM:
g_loss += nominator.sum()
case AggregationType.MAX:
g_loss += nominator.max()
case AggregationType.MIN:
g_loss += nominator.min()
case AggregationType.LEG:
class_indices = argmax_preds[sample_no, region_indices[0], region_indices[1]]
# local_maxima_indices = np.nonzero(local_maxima_in_region[class_index, :, :])
# local_maxima_mask = local_maxima_in_region[class_index, :, :]
# check for indices that are in both region_indices and local_maxima_indices
region_mask = torch.zeros(input.shape[1:], device=input.device, dtype=torch.bool)
region_mask[class_indices, region_indices[0], region_indices[1]] = True
# create the mean loss that is for all elements in the region
# mean_loss = input[sample_no, class_index, region_indices[0], region_indices[1]].mean()
# nominator_means.append(mean_loss)
# num_elements.append(1)
# this is the isolated component case
if error_types[i] == 0 or error_types[i] == 1:
# 0 has a correct foreground neighbor and is wrongly background
# if error_types[i] == 0 and class_index == 0:
# base_loss = input[sample_no, class_index, region_indices[0], region_indices[1]].max()
# # 0 has a correct foreground neighbor and is wrongly foreground
# elif error_types[i] == 0 and class_index == 1:
# base_loss = input[sample_no, class_index, region_indices[0], region_indices[1]].min()
# # 1 has a correct background neighbor and is wrongly foreground
# elif error_types[i] == 1 and class_index == 1:
# base_loss = input[sample_no, class_index, region_indices[0], region_indices[1]].max()
# # 1 has a correct background neigbhor and is wrongly background
# elif error_types[i] == 1 and class_index == 0:
# base_loss = input[sample_no, class_index, region_indices[0], region_indices[1]].min()
                            base_loss = input[sample_no, class_indices, region_indices[0], region_indices[1]].mean()
nominator_means.append(base_loss)
num_elements.append(1)
intersection_mask = region_mask & local_maxima_in_region
if not intersection_mask.any():
found_or_not_dict["no_local_maxima"] += 1
else:
# create a loss based on all but the largest local maxima in the region
leg_indices = torch.nonzero(intersection_mask, as_tuple=True)
if len(leg_indices[0]) <= 1:
found_or_not_dict["single_local_maxima"] += 1
else:
found_or_not_dict["multiple_local_maxima"] += 1
nominator = input[sample_no, leg_indices[0], leg_indices[1], leg_indices[2]]
# print(f"Found {len(nominator)} local maxima")
# remove the largest local maxima, if there are multiple just remove one
largest_local_maxima = nominator == nominator.max()
                                # keep the indices of all but the first occurrence of the largest local maximum
remaining = largest_local_maxima.nonzero()[1:]
non_max = nominator[nominator != nominator.max()].nonzero()
all_idcs = torch.cat([remaining, non_max])
nominator = nominator[all_idcs]
# check if there are still elements in the nominator, otherwise continue
if len(nominator) < 1:
continue
# nominator = nominator[nominator != nominator.max()]
# append every element of the nominator to the nominator_means, so that we can calculate the mean later
# change nominator such that it can be appended without having different shapes
else:
mean_local_maxima = nominator.mean()
nominator_means.append(mean_local_maxima)
num_elements.append(len(nominator))
                        # this is the connectivity error case
else:
# add a mean loss as a base loss
base_loss = input[sample_no, class_indices, region_indices[0], region_indices[1]].mean()
nominator_means.append(base_loss)
num_elements.append(1)
# create the loss based on the saddle points in the region
print("region mask shape", region_mask.shape)
intersection_mask = region_mask & saddle_mask
print("intersection mask shape", intersection_mask.shape)
if not intersection_mask.any():
found_or_not_dict["saddle_no"] += 1
else:
found_or_not_dict["saddle_yes"] += 1
# create a loss based on all saddle points in the region
leg_indices = torch.nonzero(intersection_mask, as_tuple=True)
nominator = input[sample_no, leg_indices[0], leg_indices[1], leg_indices[2]]
mean_saddle_points = nominator.mean()
nominator_means.append(mean_saddle_points)
num_elements.append(1)
# num_elements.append(1)
# print(f"Found {len(nominator)} saddle points")
# check if there are local maxima in the region
# intersection_mask = region_mask & local_maxima_mask
# print the number of local maxima in the region
# print(f"Found {len(torch.nonzero(intersection_mask, as_tuple=True)[0])} local maxima for connectivity error")
case AggregationType.CE:
masked_input = input[sample_no, :, region_indices[0], region_indices[1]].unsqueeze(0)
masked_target = target[sample_no, :, region_indices[0], region_indices[1]].unsqueeze(0)
g_loss += F.cross_entropy(masked_input, masked_target, reduction="mean")
case _:
raise ValueError(f"Invalid aggregation type: {self.aggregation}")
if self.aggregation == AggregationType.LEG:
for i, lo in enumerate(nominator_means):
g_loss += lo * num_elements[i]
if sum(num_elements) != 0:
g_loss *= len(nominator_means) / sum(num_elements)
# normalize by number of classes and batch size
g_loss /= input.shape[0] * (num_classes - starting_class)
return g_loss