Module laplace.utils.utils

Functions

def get_nll(out_dist: torch.Tensor, targets: torch.Tensor) -> torch.Tensor
def validate(laplace: BaseLaplace, val_loader: DataLoader, loss: torchmetrics.Metric | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], pred_type: PredType | str = PredType.GLM, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, dict_key_y: str = 'labels') -> float
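The signatures above carry no description. As a rough illustration, get_nll plausibly evaluates the mean negative log-likelihood of the targets under the predictive distribution out_dist. The following is a hedged sketch under that assumption, not the library's actual implementation (get_nll_sketch is a hypothetical name; the real function may expect log-probabilities or handle regression differently):

```python
import torch
import torch.nn.functional as F

def get_nll_sketch(out_dist: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Mean negative log-likelihood of integer class targets under
    # predictive class probabilities (rows of out_dist sum to 1).
    return F.nll_loss(out_dist.log(), targets)

probs = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
targets = torch.tensor([0, 1])
nll = get_nll_sketch(probs, targets)
```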
def parameters_per_layer(model: nn.Module) -> list[int]

Get number of parameters per layer.

Parameters

model : torch.nn.Module
 

Returns

params_per_layer : list[int]
 
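A sketch of what this likely returns, assuming each parameter tensor counts as one entry (the exact grouping convention is an assumption):

```python
import torch.nn as nn

def parameters_per_layer_sketch(model: nn.Module) -> list[int]:
    # One entry per parameter tensor, giving its number of elements.
    return [p.numel() for p in model.parameters()]

model = nn.Linear(3, 2)  # weight (2x3) and bias (2,)
counts = parameters_per_layer_sketch(model)
```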
def invsqrt_precision(M: torch.Tensor) -> torch.Tensor

Compute M^{-0.5} as a (lower) triangular matrix.

Parameters

M : torch.Tensor
 

Returns

M_invsqrt : torch.Tensor
 
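One way to compute such an inverse square root is via a Cholesky factor; the sketch below assumes M is symmetric positive definite (the library's implementation may differ in details):

```python
import torch

def invsqrt_precision_sketch(M: torch.Tensor) -> torch.Tensor:
    # If M = L L^T (Cholesky), then R = (L^{-1})^T satisfies R R^T = M^{-1},
    # i.e. R acts as a square root of the covariance M^{-1}.
    L = torch.linalg.cholesky(M)
    Linv = torch.linalg.solve_triangular(
        L, torch.eye(M.shape[0], dtype=M.dtype), upper=False
    )
    return Linv.T

M = torch.tensor([[4.0, 0.0], [0.0, 9.0]])
R = invsqrt_precision_sketch(M)
```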
def kron(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor

Compute the Kronecker product of two tensors.

Parameters

t1 : torch.Tensor
 
t2 : torch.Tensor
 

Returns

kron_product : torch.Tensor
 
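For 2-D inputs this coincides with torch.kron. A broadcasting-based sketch of the operation (whether the library implements it manually, e.g. for Kronecker-factored curvature blocks, is not stated here):

```python
import torch

def kron_sketch(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # Kronecker product of two matrices:
    # result[i*p + k, j*q + l] = t1[i, j] * t2[k, l]
    m, n = t1.shape
    p, q = t2.shape
    return (t1[:, None, :, None] * t2[None, :, None, :]).reshape(m * p, n * q)

t1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
t2 = torch.eye(2)
K = kron_sketch(t1, t2)  # shape (4, 4)
```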
def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) -> torch.Tensor

Add the scalar value to the diagonal of X.

Parameters

X : torch.Tensor
 
value : torch.Tensor or float
 

Returns

X_add_scalar : torch.Tensor
 
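Functionally this amounts to out-of-place addition on the main diagonal; a sketch assuming X is square:

```python
import torch

def diagonal_add_scalar_sketch(X: torch.Tensor, value: float) -> torch.Tensor:
    # Return a new tensor with `value` added to the main diagonal of X,
    # leaving X itself unchanged.
    return X + value * torch.eye(X.shape[0], dtype=X.dtype, device=X.device)

X = torch.zeros(3, 3)
Y = diagonal_add_scalar_sketch(X, 2.0)
```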
def symeig(M: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Symmetric eigendecomposition, avoiding failure cases by adding jitter to the diagonal and removing it afterwards.

Parameters

M : torch.Tensor
 

Returns

L : torch.Tensor
eigenvalues
W : torch.Tensor
eigenvectors
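A sketch of the retry-with-jitter idea described above (the jitter magnitude and retry count are assumptions; the real implementation may also clamp small eigenvalues):

```python
import torch

def symeig_sketch(M: torch.Tensor, jitter: float = 1e-6, max_tries: int = 5):
    # Retry torch.linalg.eigh with growing diagonal jitter if it fails or
    # yields NaNs; subtract the added jitter from the eigenvalues afterwards.
    eye = torch.eye(M.shape[0], dtype=M.dtype, device=M.device)
    added = 0.0
    for _ in range(max_tries):
        try:
            L, W = torch.linalg.eigh(M + added * eye)
            if not torch.isnan(L).any():
                return L - added, W
        except RuntimeError:
            pass
        added = jitter if added == 0.0 else added * 10.0
    raise RuntimeError("symmetric eigendecomposition failed")

M = torch.diag(torch.tensor([1.0, 2.0]))
L, W = symeig_sketch(M)
```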
def block_diag(blocks: list[torch.Tensor]) -> torch.Tensor

Compose a block-diagonal matrix from individual blocks.

Parameters

blocks : list[torch.Tensor]
 

Returns

M : torch.Tensor
 
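torch.block_diag provides the same operation for a sequence of matrices; a short usage sketch (whether the library wraps it or implements its own loop is not stated here):

```python
import torch

blocks = [torch.ones(2, 2), 3.0 * torch.eye(1)]
# Blocks are placed along the diagonal; everything off the blocks is zero.
M = torch.block_diag(*blocks)  # shape (3, 3)
```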
def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor

Expand prior precision to match the shape of the model parameters.

Parameters

prior_prec : torch.Tensor 1-dimensional
prior precision
model : torch.nn.Module
torch model with parameters that are regularized by prior_prec

Returns

expanded_prior_prec : torch.Tensor
expanded prior precision has the same shape as model parameters

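A sketch of a plausible expansion, assuming prior_prec is either a scalar or holds one precision per parameter tensor, and that the expanded result is a flat vector with one value per individual parameter (the exact layer grouping and output layout are assumptions):

```python
import torch
import torch.nn as nn

def expand_prior_precision_sketch(
    prior_prec: torch.Tensor, model: nn.Module
) -> torch.Tensor:
    # Broadcast a scalar or per-tensor prior precision to one value per
    # individual parameter, concatenated in model.parameters() order.
    params = list(model.parameters())
    n_params = sum(p.numel() for p in params)
    if prior_prec.numel() == 1:
        return prior_prec.reshape(()).expand(n_params).clone()
    return torch.cat(
        [prior_prec[i].expand(p.numel()) for i, p in enumerate(params)]
    )

model = nn.Linear(3, 2)  # 6 weights + 2 biases = 8 parameters
expanded = expand_prior_precision_sketch(torch.tensor([1.0, 2.0]), model)
```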
Classes

class SoDSampler (N, M, seed: int = 0)

Base class for all Samplers.

Every Sampler subclass has to provide an :meth:__iter__ method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and a :meth:__len__ method that returns the length of the returned iterators.

Args

data_source : Dataset
This argument is not used and will be removed in 2.2.0. You may still have custom implementations that utilize it.

Example

>>> # xdoctest: +SKIP
>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

Note: The :meth:__len__ method isn't strictly required by :class:~torch.utils.data.DataLoader, but is expected in any calculation involving the length of a :class:~torch.utils.data.DataLoader.

Ancestors

  • torch.utils.data.sampler.Sampler
  • typing.Generic
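The inherited docstring above comes from torch's Sampler base class and says nothing about SoDSampler itself. Presumably "SoD" stands for subset-of-data: the sampler draws a fixed, seeded random subset of M indices out of N once at construction and iterates over it. A hedged sketch under that assumption (the name and sampling scheme are not confirmed by this page):

```python
import torch
from torch.utils.data import Sampler

class SoDSamplerSketch(Sampler):
    # Subset-of-data sampler: a fixed, seeded random subset of M out of
    # N indices, yielded in the same order on every pass.
    def __init__(self, N: int, M: int, seed: int = 0):
        gen = torch.Generator().manual_seed(seed)
        self.indices = torch.randperm(N, generator=gen)[:M]

    def __iter__(self):
        return iter(self.indices.tolist())

    def __len__(self):
        return len(self.indices)

sampler = SoDSamplerSketch(N=10, M=3)
```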