Module laplace.baselaplace
Classes
class BaseLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)
-
Base class for all Laplace approximations in this library. Subclasses need to specify how the Hessian approximation is initialized, how to add up curvature over training data, how to sample from the Laplace approximation, and how to compute the functional variance.
A Laplace approximation is represented by a MAP estimate, given by the model parameter, and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). The goal of this class is to compute the posterior precision P, which is the sum P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements a different approximation to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0), and the structure of p_0 can be varied.
Parameters
model : torch.nn.Module
likelihood : {'classification', 'regression'}
    determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
    observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
    prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
    prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
    temperature of the likelihood; a lower temperature leads to a more concentrated posterior and vice versa
backend : subclasses of CurvatureInterface
    backend for access to curvature/Hessian approximations
backend_kwargs : dict, default=None
    arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations
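For orientation, a minimal construction sketch. BaseLaplace itself is abstract, so a concrete subclass such as DiagLaplace (documented below) is instantiated; the model architecture and hyperparameters are illustrative placeholders, not recommendations.

    import torch
    from laplace.baselaplace import DiagLaplace

    # Any torch.nn.Module trained to a MAP estimate can be wrapped.
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 50), torch.nn.Tanh(), torch.nn.Linear(50, 2)
    )
    la = DiagLaplace(model, likelihood='classification',
                     prior_precision=1.0, temperature=1.0)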
Ancestors
- abc.ABC
Subclasses
- DiagLaplace
- FullLaplace
- KronLaplace
- laplace.lllaplace.LLLaplace
Instance variables
var backend
var log_likelihood
-
Compute the log likelihood on the training data after .fit() has been called. The log likelihood is computed on demand based on the loss and, for example, the observation noise, which makes it differentiable in the latter for iterative updates.
Returns
log_likelihood : torch.Tensor
var scatter
-
Computes the scatter, a term of the log marginal likelihood that corresponds to L2 regularization: scatter = (\theta_{MAP} - \mu_0)^T P_0 (\theta_{MAP} - \mu_0).
Returns
scatter : torch.Tensor
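As a worked illustration of this term for a diagonal prior precision; all tensors below are placeholders, not the class internals.

    import torch

    theta_map = torch.randn(5)       # MAP parameters
    prior_mean = torch.zeros(5)      # mu_0
    p0 = torch.ones(5)               # diagonal of P_0
    delta = theta_map - prior_mean
    scatter = delta @ (p0 * delta)   # (theta_MAP - mu_0)^T P_0 (theta_MAP - mu_0)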
var log_det_prior_precision
-
Compute the log determinant of the prior precision, \log \det P_0.
Returns
log_det : torch.Tensor
var log_det_posterior_precision
-
Compute the log determinant of the posterior precision, \log \det P, which depends on the Hessian structure used by the subclass.
Returns
log_det : torch.Tensor
var log_det_ratio
-
Compute the log determinant ratio, a part of the log marginal likelihood: \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0.
Returns
log_det_ratio : torch.Tensor
var prior_precision_diag
-
Obtain the diagonal prior precision p_0 constructed from either a scalar, layer-wise, or diagonal prior precision.
Returns
prior_precision_diag : torch.Tensor
var prior_mean
var prior_precision
var sigma_noise
var posterior_precision
-
Compute or return the posterior precision P.
Returns
posterior_prec : torch.Tensor
Methods
def fit(self, train_loader)
-
Fit the local Laplace approximation at the parameters of the model.
Parameters
train_loader : torch.utils.data.DataLoader
    each iterate is a training batch (X, y); train_loader.dataset needs to be set to access N, the size of the data set
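Continuing the construction sketch above, a hedged example of fitting on synthetic placeholder data. Note that wrapping a TensorDataset in a DataLoader sets train_loader.dataset automatically.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    X = torch.randn(100, 10)
    y = torch.randint(0, 2, (100,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32)
    la.fit(train_loader)  # accumulates curvature batch by batch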
def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None)
-
Compute the Laplace approximation to the log marginal likelihood, subject to the specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. Passing prior_precision or sigma_noise overwrites the current value, which is useful for iterating on the log marginal likelihood.
Parameters
prior_precision : torch.Tensor, optional
    prior precision if it should be changed from the current prior_precision value
sigma_noise : torch.Tensor or float, optional
    observation noise standard deviation if it should be changed
Returns
log_marglik : torch.Tensor
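This differentiability enables gradient-based hyperparameter tuning. A hedged sketch of maximizing the log marginal likelihood over the prior precision, with la the fitted object from the sketches above; optimizing in log space is a common choice here, not something the API mandates.

    import torch

    log_prior_prec = torch.zeros(1, requires_grad=True)
    opt = torch.optim.Adam([log_prior_prec], lr=1e-1)
    for _ in range(100):
        opt.zero_grad()
        # negative log marginal likelihood as the minimization objective
        neg_marglik = -la.log_marginal_likelihood(prior_precision=log_prior_prec.exp())
        neg_marglik.backward()
        opt.step()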
def predictive(self, x, pred_type='glm', link_approx='mc', n_samples=100)
def predictive_samples(self, x, pred_type='glm', n_samples=100)
-
Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.
Parameters
x : torch.Tensor
    input data (batch_size, input_shape)
pred_type : {'glm', 'nn'}, default='glm'
    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
n_samples : int
    number of samples
Returns
samples : torch.Tensor
    samples (n_samples, batch_size, output_shape)
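A hedged Thompson-sampling sketch: draw a single function sample and act greedily under it. x_candidates is a placeholder batch of candidate inputs, and the argmax assumes a classification-style output.

    # samples has shape (n_samples, batch_size, output_shape)
    samples = la.predictive_samples(x_candidates, pred_type='glm', n_samples=1)
    action = samples[0].argmax(dim=-1)  # greedy choice under the sampled function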
def functional_variance(self, Jacs)
-
Compute the functional variance for the 'glm' predictive: f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T, which is an outputs x outputs predictive covariance matrix per input. Mathematically, we have for a single Jacobian \mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}} the output covariance matrix \mathcal{J} P^{-1} \mathcal{J}^T.
Parameters
Jacs : torch.Tensor
    Jacobians of the model output w.r.t. the parameters, (batch, outputs, parameters)
Returns
f_var : torch.Tensor
    output covariance (batch, outputs, outputs)
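The same quantity written out densely, e.g. for checking a subclass implementation; shapes and the posterior precision below are placeholders, where a real P would come from the fitted approximation.

    import torch

    batch, outputs, params = 8, 2, 20
    Jacs = torch.randn(batch, outputs, params)
    P = 2.0 * torch.eye(params)  # stand-in posterior precision
    # f_var[i] = Jacs[i] @ P^{-1} @ Jacs[i].T for every batch element i
    f_var = torch.einsum('bop,pq,brq->bor', Jacs, torch.inverse(P), Jacs)
    print(f_var.shape)  # torch.Size([8, 2, 2])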
def sample(self, n_samples=100)
-
Sample from the Laplace posterior approximation, i.e., \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).
Parameters
n_samples : int, default=100
    number of samples
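A hedged sketch of using the parameter draws for an 'nn'-style predictive, evaluating the model under each sample and restoring the MAP afterwards; x is a placeholder input batch, and the loop assumes draws are stacked along the first dimension.

    from torch.nn.utils import parameters_to_vector, vector_to_parameters

    theta_map = parameters_to_vector(model.parameters()).detach().clone()
    preds = []
    for theta in la.sample(n_samples=10):
        vector_to_parameters(theta, model.parameters())
        preds.append(model(x).detach())
    vector_to_parameters(theta_map, model.parameters())  # restore the MAP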
def optimize_prior_precision(self, method='marglik', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, pred_type='glm', link_approx='probit', n_samples=100, verbose=False)
-
Optimize the prior precision post-hoc using the method specified by the user.
Parameters
method : {'marglik', 'CV'}, default='marglik'
    specifies how the prior precision should be optimized
n_steps : int, default=100
    the number of gradient descent steps to take
lr : float, default=1e-1
    the learning rate to use for gradient descent
init_prior_prec : float, default=1.0
    initial prior precision before the first optimization step
val_loader : torch.utils.data.DataLoader, default=None
    DataLoader for the validation set; each iterate is a training batch (X, y)
loss : callable, default=get_nll
    loss function to use for CV
log_prior_prec_min : float, default=-4
    lower bound of the grid-search interval for CV
log_prior_prec_max : float, default=4
    upper bound of the grid-search interval for CV
grid_size : int, default=100
    number of values to consider inside the grid-search interval for CV
pred_type : {'glm', 'nn'}, default='glm'
    type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
link_approx : {'mc', 'probit', 'bridge'}, default='probit'
    how to approximate the classification link function for the 'glm' predictive. For pred_type='nn', only 'mc' is possible.
n_samples : int, default=100
    number of samples for link_approx='mc'
verbose : bool, default=False
    if true, the optimized prior precision will be printed (can be a large tensor if the prior precision is diagonal)
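Hedged usage for both optimization modes; val_loader is a placeholder validation DataLoader built like train_loader above.

    # gradient-based marginal-likelihood optimization
    la.optimize_prior_precision(method='marglik', n_steps=100, lr=1e-1)
    # or grid search with cross-validation on held-out data
    la.optimize_prior_precision(method='CV', val_loader=val_loader)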
class FullLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)
-
Laplace approximation with a full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See BaseLaplace for the full interface.
Ancestors
- BaseLaplace
- abc.ABC
Instance variables
var posterior_scale
-
Posterior scale (square root of the covariance), i.e., P^{-\frac{1}{2}}.
Returns
scale : torch.Tensor, (parameters, parameters)
var posterior_covariance
-
Posterior covariance, i.e., P^{-1}.
Returns
covariance : torch.Tensor, (parameters, parameters)
var posterior_precision
-
Posterior precision P.
Returns
precision : torch.Tensor, (parameters, parameters)
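The three properties above are related in the usual way; a small self-contained check with a placeholder matrix, where the Cholesky factor is one valid square root (the class may use a different factorization internally).

    import torch

    precision = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
    covariance = torch.inverse(precision)      # P^{-1}
    scale = torch.linalg.cholesky(covariance)  # a P^{-1/2}, lower-triangular
    assert torch.allclose(scale @ scale.T, covariance)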
class KronLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, damping=False, **backend_kwargs)
-
Laplace approximation with a Kronecker-factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for each parameter group, e.g., a torch.nn.Module, that P \approx Q \otimes H. See BaseLaplace for the full interface, and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, apply a Hessian factor (e.g. temperature), and compute posterior covariances, the marginal likelihood, etc. Damping can be enabled by setting damping=True.
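The efficiency of this structure rests on standard Kronecker identities such as (Q \otimes H)^{-1} = Q^{-1} \otimes H^{-1}, so only the small factors ever need to be inverted. A small numeric check with placeholder factors:

    import torch

    Q = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
    H = torch.tensor([[3.0, 0.0], [0.0, 1.0]])
    P = torch.kron(Q, H)  # dense equivalent of the factored precision
    assert torch.allclose(
        torch.inverse(P), torch.kron(torch.inverse(Q), torch.inverse(H)), atol=1e-5
    )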
Ancestors
- BaseLaplace
- abc.ABC
Instance variables
var posterior_precision
var prior_precision
class DiagLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)
-
Laplace approximation with a diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See BaseLaplace for the full interface.
Ancestors
- BaseLaplace
- abc.ABC
Instance variables
var posterior_precision
-
Diagonal posterior precision p.
Returns
precision : torch.Tensor, (parameters)
var posterior_scale
-
Diagonal posterior scale \sqrt{p^{-1}}.
Returns
scale : torch.Tensor, (parameters)
var posterior_variance
-
Diagonal posterior variance p^{-1}.
Returns
variance : torch.Tensor, (parameters)