Module laplace.utils.metrics

Classes

class RunningNLLMetric (ignore_index: int = -100)

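Running negative log-likelihood (NLL) metric that accumulates the NLL of predicted class probabilities across calls to update.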

Parameters

ignore_index : int, default = -100
which class label to ignore when computing the NLL loss
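To make the effect of ignore_index concrete, here is a minimal pure-Python sketch. It assumes the NLL of one sample is -log(p[target]) and that samples whose target equals ignore_index contribute neither to the loss nor to the count of labels, which is the convention used by torch.nn.functional.nll_loss. The helper batch_nll_sum is hypothetical and not part of laplace; nested lists stand in for tensors.

```python
import math


def batch_nll_sum(probs, targets, ignore_index=-100):
    """Hypothetical helper: summed NLL over targets that are not ignore_index.

    probs:   list of per-sample probability vectors (each of length n_classes)
    targets: list of integer class labels, one per sample
    Returns (nll_sum, n_valid).
    """
    nll_sum = 0.0
    n_valid = 0
    for p, y in zip(probs, targets):
        if y == ignore_index:
            continue  # ignored labels add neither loss nor count
        nll_sum += -math.log(p[y])
        n_valid += 1
    return nll_sum, n_valid


probs = [[0.5, 0.5], [0.25, 0.75], [0.9, 0.1]]
targets = [0, 1, -100]  # the third sample is ignored
total, count = batch_nll_sum(probs, targets)
print(total / count)  # mean NLL over the two non-ignored samples
```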


Ancestors

  • torchmetrics.metric.Metric
  • torch.nn.modules.module.Module
  • abc.ABC

Class variables

var is_differentiable : Optional[bool]
var higher_is_better : Optional[bool]
var full_state_update : Optional[bool]
var plot_lower_bound : Optional[float]
var plot_upper_bound : Optional[float]
var plot_legend_name : Optional[str]

Methods

def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None

Accumulate the NLL of a batch of predicted probabilities into the running state.

Parameters

probs : torch.Tensor
probability tensor of shape (…, n_classes)
targets : torch.Tensor
integer tensor of shape (…)
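The update/compute contract can be sketched in plain Python as follows. This is a hypothetical stand-in (RunningNLLSketch, flatten_probs and flatten_targets are illustrative names), not the laplace implementation: it assumes update flattens all leading dimensions so that every probability vector is paired with one integer target, adds the summed NLL and the number of non-ignored targets to running totals, and that compute returns their ratio. Nested lists stand in for tensors.

```python
import math


def flatten_probs(probs):
    """Flatten nested lists of shape (..., n_classes) into a list of probability vectors."""
    if probs and isinstance(probs[0], (int, float)):
        return [probs]  # innermost level: a single probability vector
    flat = []
    for item in probs:
        flat.extend(flatten_probs(item))
    return flat


def flatten_targets(targets):
    """Flatten nested lists of shape (...) into a flat list of integer labels."""
    if isinstance(targets, int):
        return [targets]
    flat = []
    for item in targets:
        flat.extend(flatten_targets(item))
    return flat


class RunningNLLSketch:
    """Hypothetical running NLL accumulator mirroring the documented interface."""

    def __init__(self, ignore_index: int = -100):
        self.ignore_index = ignore_index
        self.nll_sum = 0.0  # running sum of -log p[target]
        self.n_valid = 0    # running count of non-ignored targets

    def update(self, probs, targets) -> None:
        vectors = flatten_probs(probs)
        labels = flatten_targets(targets)
        if len(vectors) != len(labels):
            raise ValueError("probs and targets must share the same leading dimensions")
        for p, y in zip(vectors, labels):
            if y == self.ignore_index:
                continue
            self.nll_sum -= math.log(p[y])
            self.n_valid += 1

    def compute(self) -> float:
        # Mean NLL over every non-ignored target seen so far.
        return self.nll_sum / self.n_valid


metric = RunningNLLSketch()
# Batch of shape (2, 2, n_classes=2) with one ignored position.
metric.update([[[0.5, 0.5], [0.25, 0.75]], [[0.9, 0.1], [0.2, 0.8]]],
              [[0, 1], [-100, 1]])
# Batch of shape (1, n_classes=2).
metric.update([[0.5, 0.5]], [1])
print(metric.compute())
```

Keeping a sum and a count, rather than a running mean, weights every valid target equally no matter how the data was split into batches.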

def compute(self) -> torch.Tensor

Compute the final metric value from the accumulated state.

State variables are synchronized automatically when running with a distributed backend.
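Why summed state matters becomes clear in the distributed case: torchmetrics reduces each state variable across processes with the reduction declared for it before compute runs. Assuming this class keeps an NLL sum and a valid-label count that are both sum-reduced (an assumption; the per-process numbers below are made up for illustration), the synchronized result is the global mean, whereas averaging per-process means is not:

```python
# Hypothetical per-process states (nll_sum, n_valid) after update() calls
# on two workers that saw different numbers of valid targets.
rank_states = [(3.0, 2), (1.0, 8)]

# Sum-reduce each state variable across processes, then divide once.
global_sum = sum(s for s, _ in rank_states)
global_count = sum(n for _, n in rank_states)
synced = global_sum / global_count  # 4.0 / 10 = 0.4

# Averaging per-process means weights each worker equally,
# regardless of how many targets it saw.
mean_of_means = sum(s / n for s, n in rank_states) / len(rank_states)  # (1.5 + 0.125) / 2 = 0.8125

print(synced, mean_of_means)
```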