# Source code for topolosses.losses.utils

import torch
from torch import Tensor

import enum
import numpy as np
from typing import List, Optional
import torch.nn.functional as F

# Restrict the public API (and therefore what Sphinx documents) to this one
# function; the enums and helpers below are internal.
__all__ = ["compute_default_dice_loss"]


def compute_default_dice_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    reduce_axis: Optional[List[int]] = None,
    smooth: float = 1e-5,
) -> torch.Tensor:
    """Compute the (weighted) Dice loss with the default settings of the base loss.

    Args:
        input (torch.Tensor): The predicted segmentation map with shape (N, C, ...),
            where N is the batch size and C is the number of classes.
        target (torch.Tensor): The ground truth segmentation map with the same
            shape as `input`.
        reduce_axis (List[int], optional): The axes over which the intersection
            and the sums of `input` and `target` are accumulated before dividing.
            Defaults to all axes from 2 onwards (the spatial axes), which yields
            one Dice value per sample and class. Include axis 0 to pool the
            intersection and union over the batch before dividing.
        smooth (float): Constant added to both numerator and denominator so the
            ratio stays finite when `input` and `target` are both empty.
            Defaults to 1e-5.

    Returns:
        torch.Tensor: The Dice loss (``1 - Dice``), averaged over all entries
        that remain after reducing over `reduce_axis`, as a scalar tensor.
    """
    if reduce_axis is None:
        # Keep the batch (0) and class (1) axes; reduce the spatial axes.
        # NOTE(review): for inputs without spatial axes this is [], and
        # torch.sum(..., dim=[]) reduces over all dims -- confirm intended.
        reduce_axis = list(range(2, input.dim()))

    intersection = torch.sum(target * input, dim=reduce_axis)
    ground_o = torch.sum(target, dim=reduce_axis)
    pred_o = torch.sum(input, dim=reduce_axis)
    denominator = ground_o + pred_o

    dice = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
    return torch.mean(dice)
# TODO for all these enums and functions check if they are used across losses or only in topograph, move accordingly
class AggregationType(enum.Enum):
    MEAN = "mean"
    SUM = "sum"
    MAX = "max"
    MIN = "min"
    CE = "ce"
    RMS = "rms"
    LEG = "leg"


class ThresholdDistribution(enum.Enum):
    UNIFORM = "uniform"
    GAUSSIAN = "gaussian"
    NONE = "none"


def new_compute_diffs(paired_img_batch: torch.Tensor):
    """Mark where neighbouring entries differ along dims 1 and 2.

    Args:
        paired_img_batch (torch.Tensor): Image batch, presumably of shape
            (B, H, W) -- TODO confirm against callers. Any dtype supporting
            ``!=`` works, including ``torch.bool``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Boolean tensors ``(h_diff, v_diff)``.
        ``h_diff[b, i, j]`` is True iff ``x[b, i, j] != x[b, i + 1, j]``
        (shape (B, H - 1, W)); ``v_diff[b, i, j]`` is True iff
        ``x[b, i, j] != x[b, i, j + 1]`` (shape (B, H, W - 1)).
    """
    # Compare directly instead of subtracting and testing against zero:
    # subtraction is not supported on bool tensors, and ``inf - inf`` is NaN,
    # which would wrongly count two equal infinite values as different.
    h_diff = paired_img_batch[:, :-1, :] != paired_img_batch[:, 1:, :]
    v_diff = paired_img_batch[:, :, :-1] != paired_img_batch[:, :, 1:]
    return h_diff, v_diff


def new_compute_diag_diffs(paired_img_batch: torch.Tensor, th: int = 11):
    """Classify every 2x2 window by its diagonal contrast.

    Each window ``[[a, b], [c, d]]`` is scored as ``a - b - c + d`` via a 2x2
    ``F.conv2d`` (a cross-correlation, so the kernel is not flipped).

    Args:
        paired_img_batch (torch.Tensor): Image batch, expected as (B, H, W); a
            channel axis is inserted for the convolution and the values are
            cast to float32.
        th (int): Threshold applied to the window score. Defaults to 11.

    Returns:
        tuple: Four boolean tensors of shape (B, H - 1, W - 1):
        ``diagr`` (score > th), ``diagl`` (score < -th),
        ``special_case_r`` (score == 7 or 4) and
        ``special_case_l`` (score == -7 or -4).
    """
    # NOTE(review): the meaning of the default threshold 11 and of the special
    # scores 7 / 4 depends on how the caller encodes the paired image values --
    # TODO confirm and document at the call site.
    weight = torch.tensor(
        [[1.0, -1.0], [-1.0, 1.0]],
        dtype=torch.float32,
        device=paired_img_batch.device,
    ).view(1, 1, 2, 2)
    # (B, H, W) -> (B, 1, H, W) -> conv -> (B, 1, H-1, W-1) -> (B, H-1, W-1)
    diag_connections = F.conv2d(paired_img_batch.unsqueeze(1).float(), weight).squeeze(1)

    diagr = diag_connections > th
    diagl = diag_connections < -th
    special_case_r = torch.logical_or(diag_connections == 7, diag_connections == 4)
    special_case_l = torch.logical_or(diag_connections == -7, diag_connections == -4)
    return diagr, diagl, special_case_r, special_case_l


def fill_adj_matr(adj_matrix, h_edges, v_edges):
    """Insert undirected edges into an adjacency matrix, in place.

    Args:
        adj_matrix (np.ndarray): 2-D array (presumably boolean) that is
            modified in place.
        h_edges: Index pairs; presumably shaped (2, E) so that
            ``tuple(h_edges)`` yields ``(rows, cols)`` -- TODO confirm.
        v_edges: Index pairs in the same layout as `h_edges`.

    Returns:
        np.ndarray: `adj_matrix` itself. Both ``(i, j)`` and ``(j, i)`` are set
        for every edge, and the diagonal is cleared so there are no self loops.
    """
    for edges in (h_edges, v_edges):
        adj_matrix[tuple(edges)] = True
        # Reversing the index pair adds the transposed edge (symmetry).
        adj_matrix[tuple(edges[::-1])] = True
    # Remove self loops.
    np.fill_diagonal(adj_matrix, False)
    return adj_matrix


# TODO Check if only used in betti matching if so move accordingly
class FiltrationType(enum.Enum):
    SUPERLEVEL = "superlevel"
    SUBLEVEL = "sublevel"
    BOTHLEVELS = "bothlevels"