Module laplace.utils.utils
Functions
def get_nll(out_dist: torch.Tensor, targets: torch.Tensor) -> torch.Tensor
def validate(laplace: BaseLaplace, val_loader: DataLoader, loss: torchmetrics.Metric | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], pred_type: PredType | str = PredType.GLM, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, dict_key_y: str = 'labels') -> float
def parameters_per_layer(model: nn.Module) -> list[int]
-
Get the number of parameters per layer.
Parameters
model : torch.nn.Module
Returns
params_per_layer : list[int]
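A minimal sketch of the idea, assuming that "layer" here means each parameter tensor yielded by model.parameters(), so a weight matrix and its bias are counted separately. The helper name count_params_per_tensor is hypothetical and not part of laplace.

```python
from torch import nn


def count_params_per_tensor(model: nn.Module) -> list[int]:
    # One entry per parameter tensor, in the order model.parameters() yields them.
    return [p.numel() for p in model.parameters()]


model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
print(count_params_per_tensor(model))  # [12, 4, 4, 1]
```

Summing the list gives the total parameter count, which is the length of the flattened parameter vector.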
def invsqrt_precision(M: torch.Tensor) -> torch.Tensor
-
Compute M^{-0.5} as a lower-triangular matrix, i.e. a factor L with L L^T = M^{-1}.
Parameters
M : torch.Tensor
Returns
M_invsqrt : torch.Tensor
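A lower-triangular L with L L^T = M^{-1} can be obtained without explicitly inverting M by a "flip Cholesky" trick: take the Cholesky factor of M with rows and columns reversed, flip it back, and invert the resulting triangular factor. The sketch below assumes M is symmetric positive definite; invsqrt_precision_sketch is a hypothetical name, and the library's own code may differ.

```python
import torch


def invsqrt_precision_sketch(M: torch.Tensor) -> torch.Tensor:
    """Return lower-triangular L with L @ L.T == inverse(M), for SPD M (sketch)."""
    # Reverse rows and columns, then factor: J M J = Lf Lf^T with Lf lower-triangular.
    Lf = torch.linalg.cholesky(torch.flip(M, (-2, -1)))
    # U = J Lf J is upper-triangular with M = U U^T; its transpose U^T is lower-triangular.
    Ut = torch.flip(Lf, (-2, -1)).transpose(-2, -1)
    # L = (U^T)^{-1} is lower-triangular and L L^T = U^{-T} U^{-1} = M^{-1}.
    eye = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device)
    return torch.linalg.solve_triangular(Ut, eye, upper=False)


A = torch.randn(4, 4, dtype=torch.float64)
M = A @ A.T + 4 * torch.eye(4, dtype=torch.float64)  # an SPD precision matrix
L = invsqrt_precision_sketch(M)
print(torch.allclose(L @ L.T, torch.linalg.inv(M)))  # True
```

Such a factor is what a sampler needs: if z is standard normal, L z has covariance M^{-1}.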
def kron(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor
-
Computes the Kronecker product between two tensors.
Parameters
t1 : torch.Tensor
t2 : torch.Tensor
Returns
kron_product : torch.Tensor
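For 2-D inputs the Kronecker product can be sketched with torch.einsum: entry (i*c + k, j*d + l) of the result is t1[i, j] * t2[k, l]. The name kron2d and the restriction to matrices are assumptions of this sketch; recent PyTorch versions also provide torch.kron, which is handy for checking the result.

```python
import torch


def kron2d(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # Kronecker product of two 2-D tensors: block (i, j) of the result is t1[i, j] * t2.
    a, b = t1.shape
    c, d = t2.shape
    # out[i, k, j, l] = t1[i, j] * t2[k, l]; merging (i, k) into rows and (j, l)
    # into columns yields the Kronecker layout.
    return torch.einsum("ij,kl->ikjl", t1, t2).reshape(a * c, b * d)


print(kron2d(torch.tensor([[1.0, 2.0]]), torch.eye(2)).tolist())
# [[1.0, 0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 2.0]]
```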
def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) -> torch.Tensor
-
Add the scalar argument value to the diagonal of X.
Parameters
X : torch.Tensor
value : torch.Tensor or float
Returns
X_add_scalar : torch.Tensor
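A minimal out-of-place sketch, assuming X is a 2-D tensor and value is a float or a 0-dimensional tensor. The name add_to_diagonal is hypothetical; the library may perform the update differently.

```python
from __future__ import annotations

import torch


def add_to_diagonal(X: torch.Tensor, value: float | torch.Tensor) -> torch.Tensor:
    # Work on a copy so the caller's X is left unchanged.
    out = X.clone()
    # diagonal() returns a view into out, so the in-place add writes through to out.
    out.diagonal().add_(value)
    return out


print(add_to_diagonal(torch.ones(2, 2), 0.5).tolist())  # [[1.5, 1.0], [1.0, 1.5]]
```

A typical use is adding a prior precision or a small jitter to a Hessian approximation before factorizing it.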
def symeig(M: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
-
Symmetric eigendecomposition that avoids failure cases by adding jitter to the diagonal and removing it from the resulting eigenvalues.
Parameters
M : torch.Tensor
Returns
L : torch.Tensor
    Eigenvalues.
W : torch.Tensor
    Eigenvectors.
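The jitter strategy works because adding c * I to a symmetric M shifts every eigenvalue by c and leaves the eigenvectors unchanged, so the shift can be subtracted afterwards. The sketch below assumes torch.linalg.eigh as the solver, a unit jitter, and a positive semi-definite M (hence the clamp at zero); symeig_sketch is a hypothetical name, not the library function.

```python
import torch


def symeig_sketch(M: torch.Tensor, jitter: float = 1.0) -> tuple[torch.Tensor, torch.Tensor]:
    """Eigendecomposition of a symmetric M with a jitter fallback (sketch)."""
    try:
        L, W = torch.linalg.eigh(M)
    except RuntimeError:  # torch.linalg.LinAlgError is a subclass of RuntimeError
        eye = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device)
        # eig(M + c*I) = (L + c, W): retry on the shifted matrix, then undo the shift.
        L, W = torch.linalg.eigh(M + jitter * eye)
        L = L - jitter
    # Assumes M should be PSD: clip round-off negatives to zero.
    return L.clamp(min=0.0), W


A = torch.randn(5, 5, dtype=torch.float64)
L, W = symeig_sketch(A @ A.T)
print(torch.allclose(W @ torch.diag(L) @ W.T, A @ A.T))  # True
```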
def block_diag(blocks: list[torch.Tensor]) -> torch.Tensor
-
Compose a block-diagonal matrix from individual blocks.
Parameters
blocks : list[torch.Tensor]
Returns
M : torch.Tensor
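A minimal sketch of composing the blocks by slice assignment into a zero matrix; it also accepts rectangular blocks. The name block_diag_sketch is hypothetical, and PyTorch's built-in torch.block_diag computes the same layout.

```python
import torch


def block_diag_sketch(blocks: list[torch.Tensor]) -> torch.Tensor:
    # The result's size is the sum of the block heights and widths.
    n_rows = sum(b.shape[0] for b in blocks)
    n_cols = sum(b.shape[1] for b in blocks)
    first = blocks[0]
    M = torch.zeros(n_rows, n_cols, dtype=first.dtype, device=first.device)
    r = c = 0
    for b in blocks:
        h, w = b.shape
        # Place each block on the diagonal, right after the previous one.
        M[r:r + h, c:c + w] = b
        r += h
        c += w
    return M


blocks = [torch.ones(2, 2), 2 * torch.ones(1, 1)]
print(block_diag_sketch(blocks).tolist())
# [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 2.0]]
```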
def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor
-
Expand the prior precision to match the shape of the model parameters.
Parameters
prior_prec : torch.Tensor (1-dimensional)
    Prior precision.
model : torch.nn.Module
    Torch model whose parameters are regularized by prior_prec.
Returns
expanded_prior_prec : torch.Tensor
    Expanded prior precision, with one entry per model parameter (the same shape as the flattened parameter vector).
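A sketch of how such an expansion can work, assuming prior_prec holds either a single scalar, one value per parameter tensor ("layer"), or one value per individual parameter. The name expand_prior_precision_sketch and the set of accepted lengths are assumptions of this sketch, not a statement of the library's exact behavior.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector


def expand_prior_precision_sketch(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor:
    params = list(model.parameters())
    theta = parameters_to_vector(params)  # flattened parameter vector
    if prior_prec.ndim != 1:
        raise ValueError("prior_prec must be 1-dimensional")
    prior_prec = prior_prec.to(theta.device)
    if prior_prec.numel() == 1:
        # Scalar prior: the same precision for every parameter.
        return prior_prec * torch.ones_like(theta)
    if prior_prec.numel() == theta.numel():
        # Already one precision per individual parameter.
        return prior_prec
    if prior_prec.numel() == len(params):
        # One precision per parameter tensor, repeated over that tensor's entries.
        return torch.cat([d * torch.ones_like(p).flatten() for d, p in zip(prior_prec, params)])
    raise ValueError("unsupported prior_prec length")


model = nn.Linear(3, 2)  # 6 weights + 2 biases = 8 parameters
print(expand_prior_precision_sketch(torch.tensor([1.0, 5.0]), model).tolist())
# [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0]
```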
Classes
class SoDSampler (N, M, seed: int = 0)
-
Base class for all Samplers.
Every Sampler subclass has to provide an __iter__ method, which gives a way to iterate over indices or lists of indices (batches) of dataset elements, and a __len__ method that returns the length of the returned iterators.
Args
data_source : Dataset
    This argument is not used and will be removed in 2.2.0. You may still have custom implementations that utilize it.
Example
>>> # xdoctest: +SKIP
>>> from typing import Iterator, List
>>> import torch
>>> from torch.utils.data import Sampler
>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()
Note: The __len__ method isn't strictly required by torch.utils.data.DataLoader, but is expected in any calculation involving the length of a DataLoader.
Ancestors
- torch.utils.data.sampler.Sampler
- typing.Generic
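The signature SoDSampler(N, M, seed) suggests a subset-of-data sampler that reproducibly draws M of N dataset indices from a seed. The sketch below assumes exactly that behavior (sampling without replacement) and implements the __iter__/__len__ protocol the Sampler description requires; SubsetOfDataSampler is a hypothetical name, not the laplace class.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SubsetOfDataSampler(Sampler[int]):
    """Hypothetical sketch: yields M distinct indices drawn from range(N)."""

    def __init__(self, N: int, M: int, seed: int = 0) -> None:
        if not 0 <= M <= N:
            raise ValueError("need 0 <= M <= N")
        gen = torch.Generator().manual_seed(seed)
        # The first M entries of a seeded random permutation are M indices without replacement.
        self.indices = torch.randperm(N, generator=gen)[:M].tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.indices)


dataset = TensorDataset(torch.arange(10.0))
loader = DataLoader(dataset, batch_size=2, sampler=SubsetOfDataSampler(len(dataset), 4))
print(sum(batch[0].numel() for batch in loader))  # 4
```

Passing the sampler to DataLoader restricts every pass over the data to the same fixed subset, which keeps repeated evaluations cheap and reproducible.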