Module laplace.lllaplace
Classes
class LLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Base class for all last-layer Laplace approximations in this library. Subclasses specify the structure of the Hessian approximation. See BaseLaplace for the full interface.

A Laplace approximation is represented by a MAP estimate, given by the model parameter, and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only the parameters of the last layer of the neural network are treated probabilistically. The goal of this class is to compute the posterior precision P, which is given by the sum P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements a different approximation to the log likelihood Hessians, for example a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0), and the structure of p_0 can be varied.

Parameters
model : torch.nn.Module or FeatureExtractor
likelihood : Likelihood or {'classification', 'regression'}
- determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
- observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
- prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
- prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
- temperature of the likelihood; lower temperature leads to a more concentrated posterior and vice versa
enable_backprop : bool, default=False
- whether to enable backprop to the input x through the Laplace predictive. Useful for, e.g., Bayesian optimization.
feature_reduction : FeatureReduction or str, optional, default=None
- when the last-layer features are a tensor of dim >= 3, this tells how to reduce it into a dim-2 tensor. E.g. in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: make sure that this option faithfully reflects the reduction in the model definition. When inputting a string, the available options are {'pick_first', 'pick_last', 'average'}.
dict_key_x : str, default='input_ids'
- The dictionary key under which the input tensor x is stored. Only has an effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
dict_key_y : str, default='labels'
- The dictionary key under which the target tensor y is stored. Only has an effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
backend : subclasses of CurvatureInterface
- backend for access to curvature/Hessian approximations
last_layer_name : str, default=None
- name of the model's last layer; if None, it will be determined automatically
backend_kwargs : dict, default=None
- arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations
asdl_fisher_kwargs : dict, default=None
- arguments passed specifically to an ASDL-based backend on initialization
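A minimal usage sketch via the Laplace factory (the toy classifier and random data below are placeholder assumptions; any nn.Module whose final layer is nn.Linear works, and subset_of_weights='last_layer' dispatches to the subclasses documented here):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace import Laplace

    # Placeholder classifier and data; the last nn.Linear is treated probabilistically.
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
    X, y = torch.randn(128, 10), torch.randint(0, 3, (128,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

    # hessian_structure selects the concrete subclass (here KronLLLaplace).
    la = Laplace(model, 'classification',
                 subset_of_weights='last_layer', hessian_structure='kron')
    la.fit(train_loader)
    probs = la(X[:4])  # posterior predictive probabilities, shape (4, 3)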
Ancestors
ParametricLaplace
BaseLaplace
Subclasses
FullLLLaplace
KronLLLaplace
DiagLLLaplace
Instance variables
var prior_precision_diag : torch.Tensor
-
Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.
Returns
prior_precision_diag : torch.Tensor
Methods
def fit(self, train_loader: DataLoader, override: bool = True, progress_bar: bool = False) ‑> None
-
Fit the local Laplace approximation at the parameters of the model.
Parameters
train_loader : torch.utils.data.DataLoader
- each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access N, the size of the data set.
override : bool, default=True
- whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.
progress_bar : bool, default=False
- whether to show a progress bar while fitting
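A sketch of the online setting, continuing the example above (task1_loader and task2_loader are hypothetical DataLoaders shaped like train_loader):

    la.fit(task1_loader)                  # first task; override=True by default
    la.fit(task2_loader, override=False)  # accumulate H, loss, and n_data on top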
def functional_variance_fast(self, X)
-
Compute the diagonal of the output covariance for a batch of inputs. Should be overridden if there exists a trick to make this fast!
Parameters
X : torch.Tensor of shape (batch_size, input_dim)
Returns
f_var_diag : torch.Tensor of shape (batch_size, num_outputs)
- corresponding to the diagonal of the covariance matrix of the outputs
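For instance, with the fitted la and inputs X from the first sketch above:

    f_var_diag = la.functional_variance_fast(X[:4])  # shape (4, 3)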
def state_dict(self) ‑> dict[str, typing.Any]
def load_state_dict(self, state_dict: dict[str, Any]) ‑> None
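A serialization sketch, continuing the first example (the file name is an arbitrary assumption; the restoring instance is constructed with the same structure, and the model's own weights are saved separately as usual in PyTorch):

    # Persist the fitted posterior and restore it into a fresh instance.
    torch.save(la.state_dict(), 'laplace_state.pt')

    la2 = Laplace(model, 'classification',
                  subset_of_weights='last_layer', hessian_structure='kron')
    la2.load_state_dict(torch.load('laplace_state.pt'))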
Inherited members
class FullLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace, LLLaplace, and BaseLaplace for the full interface.
Ancestors
LLLaplace
FullLaplace
ParametricLaplace
BaseLaplace
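A sketch of direct instantiation (model and train_loader as in the first example; equivalent to the factory call with hessian_structure='full'):

    from laplace.lllaplace import FullLLLaplace

    la_full = FullLLLaplace(model, likelihood='classification')
    la_full.fit(train_loader)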
Inherited members
class KronLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Last-layer Laplace approximation with Kronecker factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that P \approx Q \otimes H. See KronLaplace, LLLaplace, and BaseLaplace for the full interface, and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, account for a Hessian factor (e.g. temperature), and compute posterior covariances, marginal likelihood, etc. Damping can be used by initializing with, or setting, damping=True.
Ancestors
LLLaplace
KronLaplace
ParametricLaplace
BaseLaplace
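A sketch with damping enabled (model and train_loader as in the first example; the marginal-likelihood tuning step is optional):

    from laplace.lllaplace import KronLLLaplace

    la_kron = KronLLLaplace(model, likelihood='classification', damping=True)
    la_kron.fit(train_loader)
    la_kron.optimize_prior_precision(method='marglik')  # tune the prior precision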
Inherited members
class DiagLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
-
Last-layer Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.
Ancestors
LLLaplace
DiagLaplace
ParametricLaplace
BaseLaplace
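A regression sketch, reusing the imports from the first example (the toy model, data, and sigma_noise value are placeholder assumptions):

    from laplace.lllaplace import DiagLLLaplace

    reg_model = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 1))
    Xr, yr = torch.randn(128, 10), torch.randn(128, 1)
    reg_loader = DataLoader(TensorDataset(Xr, yr), batch_size=32)

    la_diag = DiagLLLaplace(reg_model, likelihood='regression', sigma_noise=0.5)
    la_diag.fit(reg_loader)
    f_mu, f_var = la_diag(Xr[:4])  # predictive mean and covariance of the outputs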
Inherited members
class FunctionalLLLaplace (model: nn.Module, likelihood: Likelihood | str, n_subset: int, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', last_layer_name: str | None = None, backend: type[CurvatureInterface] | None = laplace.curvature.backpack.BackPackGGN, backend_kwargs: dict[str, Any] | None = None, independent_outputs: bool = False, seed: int = 0)
-
Compared to the FunctionalLaplace class, little changes here in terms of GP inference. Since only the last layer is treated probabilistically and the rest of the network is used as a fixed feature extractor, the X \in \mathbb{R}^{M \times D} in GP inference changes to \tilde{X} \in \mathbb{R}^{M \times l_{n-1}}, where l_{n-1} is the dimension of the output of the penultimate NN layer.
See FunctionalLaplace for the full interface.
Ancestors
FunctionalLaplace
BaseLaplace
Methods
def fit(self, train_loader: DataLoader) ‑> None
-
Fit the Laplace approximation of a GP posterior.
Parameters
train_loader : torch.utils.data.DataLoader
- train_loader.dataset needs to be set to access N, the size of the data set, and train_loader.batch_size needs to be set to access the batch size b.
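A sketch (model, X, and train_loader as in the first example; n_subset=64 is an arbitrary choice, and the default BackPackGGN backend requires the backpack-for-pytorch dependency):

    from laplace.lllaplace import FunctionalLLLaplace

    la_gp = FunctionalLLLaplace(model, 'classification', n_subset=64)
    la_gp.fit(train_loader)  # reads dataset size N and batch_size from the loader
    pred = la_gp(X[:4])      # GP posterior predictive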
def state_dict(self) ‑> dict
def load_state_dict(self, state_dict: dict)
Inherited members