Module laplace.utils.metrics
Classes
class RunningNLLMetric (ignore_index: int = -100)
Running negative log-likelihood (NLL) metric that accumulates the NLL over successive update calls.
Parameters
ignore_index : int, default=-100
- Class label to ignore when computing the NLL loss.
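The bookkeeping a running NLL metric needs can be sketched without torch. The class below, RunningNLLSketch, is hypothetical and is not the library's implementation. It assumes the metric keeps a summed NLL and a count of targets not equal to ignore_index, and reports their ratio. It also takes already-flattened rows of probabilities as plain Python lists instead of tensors.

```python
import math
from typing import Sequence


class RunningNLLSketch:
    """Hypothetical stand-in for RunningNLLMetric (not the library code).

    Assumes two running states: the summed NLL and the number of targets
    that were not equal to ignore_index.
    """

    def __init__(self, ignore_index: int = -100) -> None:
        self.ignore_index = ignore_index
        self.nll_sum = 0.0
        self.n_valid = 0

    def update(self, probs: Sequence[Sequence[float]], targets: Sequence[int]) -> None:
        # One row of class probabilities per target, already flattened
        for row, target in zip(probs, targets):
            if target == self.ignore_index:
                continue  # ignored labels add nothing to the sum or the count
            self.nll_sum -= math.log(row[target])
            self.n_valid += 1

    def compute(self) -> float:
        # Assumed reduction: mean NLL over every non-ignored target seen so far
        return self.nll_sum / self.n_valid


metric = RunningNLLSketch()
metric.update([[0.5, 0.5], [0.9, 0.1]], [0, -100])  # second target is ignored
metric.update([[0.25, 0.75]], [1])
print(metric.compute())  # ≈ 0.4904, i.e. (-log 0.5 - log 0.75) / 2
```

Keeping the sum and the count as separate states, rather than a running mean, weights each batch by its number of valid targets, so batches of different sizes or with different numbers of ignored labels combine correctly.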
Ancestors
- torchmetrics.metric.Metric
- torch.nn.modules.module.Module
- abc.ABC
Class variables
var is_differentiable : Optional[bool]
var higher_is_better : Optional[bool]
var full_state_update : Optional[bool]
var plot_lower_bound : Optional[float]
var plot_upper_bound : Optional[float]
var plot_legend_name : Optional[str]
Methods
def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None
Parameters
probs : torch.Tensor
- Probability tensor of shape (…, n_classes).
targets : torch.Tensor
- Integer tensor of target class labels, of shape (…).
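The shapes above allow any number of leading dimensions, for example (batch, seq_len, n_classes) for per-token predictions. One way to read that contract is that all leading dimensions are collapsed so that each probability row lines up with exactly one target. The helpers flatten_rows and flatten_targets below are hypothetical and work on nested Python lists for illustration. With tensors, the same collapse would typically be probs.reshape(-1, probs.shape[-1]) and targets.reshape(-1).

```python
from typing import Any, List


def flatten_rows(probs: Any) -> List[List[float]]:
    """Collapse the leading dims of a nested (..., n_classes) list into rows."""
    # A list of numbers is a single row along the n_classes axis
    if probs and not isinstance(probs[0], list):
        return [list(probs)]
    rows: List[List[float]] = []
    for sub in probs:
        rows.extend(flatten_rows(sub))
    return rows


def flatten_targets(targets: Any) -> List[int]:
    """Collapse a nested (...) list of integer labels into a flat list."""
    if not isinstance(targets, list):
        return [targets]
    flat: List[int] = []
    for sub in targets:
        flat.extend(flatten_targets(sub))
    return flat


# Shape (batch=2, seq_len=2, n_classes=3); -100 marks a padded position
probs = [
    [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
    [[0.3, 0.3, 0.4], [1 / 3, 1 / 3, 1 / 3]],
]
targets = [[0, 1], [2, -100]]

rows = flatten_rows(probs)
labels = flatten_targets(targets)
print(len(rows), len(labels))  # 4 4
```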
def compute(self) -> torch.Tensor
Compute the final metric value, the NLL accumulated over all previous update calls.
State variables are synchronized automatically when running with a distributed backend.
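Synchronizing state before the final division matters when processes see different numbers of valid targets. The sketch below assumes the states are a summed NLL and a valid-target count that are summed across processes; this page does not list the state reductions, so treat that as an assumption. The helper names pooled_nll and mean_of_means and the per-process numbers are hypothetical. The sketch contrasts pooling the states with naively averaging each process's own NLL.

```python
from typing import List, Tuple


def pooled_nll(states: List[Tuple[float, int]]) -> float:
    """Sum per-process (nll_sum, n_valid) states, then divide once."""
    total_sum = sum(s for s, _ in states)
    total_n = sum(n for _, n in states)
    return total_sum / total_n


def mean_of_means(states: List[Tuple[float, int]]) -> float:
    """Naive alternative: average each process's own NLL (biased if counts differ)."""
    return sum(s / n for s, n in states) / len(states)


# Process A saw 4 valid targets, process B only 1
states = [(2.0, 4), (3.0, 1)]
print(pooled_nll(states))     # 1.0
print(mean_of_means(states))  # 1.75
```

Pooling gives every valid target equal weight, which matches the result a single process would report on the combined data.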