Package laplace

Laplace

Main

The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer. The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations. The library documentation is available at https://aleximmer.github.io/Laplace.

There is also a corresponding paper, Laplace Redux — Effortless Bayesian Deep Learning, which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:

@article{daxberger2021laplace,
  title={Laplace Redux--Effortless Bayesian Deep Learning},
  author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
          and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
  journal={arXiv preprint arXiv:2106.14806},
  year={2021}
}

Setup

We assume Python 3.8, since the package was developed with that version. To install laplace with pip, run the following:

pip install laplace-torch

For development purposes, clone the repository and then install:

# or after cloning the repository for development
pip install -e .
# run tests
pip install -e .[tests]
pytest tests/

Structure

The laplace package consists of two main components:

  1. The subclasses of laplace.BaseLaplace that implement different sparsity structures: different subsets of weights ('all' and 'last_layer') and different structures of the Hessian approximation ('full', 'kron', and 'diag'). This results in six currently available options: FullLaplace, KronLaplace, DiagLaplace, and the corresponding last-layer variations FullLLLaplace, KronLLLaplace, and DiagLLLaplace, which are all subclasses of laplace.LLLaplace. All of these can be conveniently accessed via the laplace.Laplace function, as illustrated in the sketch below.
  2. The backends in laplace.curvature which provide access to Hessian approximations of the corresponding sparsity structures, for example, the diagonal GGN.

Additionally, the package provides utilities for decomposing a neural network into feature extractor and last layer for LLLaplace subclasses (laplace.feature_extractor) and effectively dealing with Kronecker factors (laplace.matrix).
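
As a quick illustration of how the string arguments select a subclass, the following sketch shows that the default options yield a last-layer Kronecker factored LA. It assumes that the classes are importable from the top-level package and reuses the load_map_model() placeholder from the examples below for a pre-trained torch.nn.Module classifier:

from laplace import Laplace, KronLLLaplace

model = load_map_model()  # placeholder for a pre-trained classifier

# 'last_layer' + 'kron' (the defaults) map to the KronLLLaplace subclass
la = Laplace(model, 'classification',
             subset_of_weights='last_layer',
             hessian_structure='kron')
assert isinstance(la, KronLLLaplace)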

Extendability

To extend the laplace package, new BaseLaplace subclasses can be designed, for example, a block-diagonal structure or subset-of-weights Laplace. Alternatively, extending or integrating backends (subclasses of curvature.CurvatureInterface) makes it possible to provide different Hessian approximations to the Laplace approximations. For example, currently the curvature.BackPackInterface based on BackPACK and the curvature.AsdlInterface based on ASDL are available. The AsdlInterface provides a Kronecker factored empirical Fisher while the BackPackInterface does not, and only the BackPackInterface provides access to Hessian approximations for a regression (MSELoss) loss function.
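
For instance, a backend can be selected explicitly via the backend argument. A minimal sketch, here simply passing the default BackPACK-based GGN, with model and train_loader assumed to be a torch.nn.Module and a DataLoader as in the examples below:

from laplace import Laplace
from laplace.curvature.backpack import BackPackGGN

# explicitly choose the curvature backend (here the default GGN via BackPACK)
la = Laplace(model, 'regression',
             subset_of_weights='all',
             hessian_structure='diag',
             backend=BackPackGGN)
la.fit(train_loader)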

Example usage

Post-hoc prior precision tuning of last-layer LA

In the following example, a pre-trained model is loaded, the Laplace approximation is fit to the training data, and the prior precision is optimized with cross-validation ('CV'). After that, the resulting LA is used for prediction with the 'probit' predictive for classification.

from laplace import Laplace

# pre-trained model
model = load_map_model()  

# User-specified LA flavor
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='diag')
la.fit(train_loader)
la.optimize_prior_precision(method='CV', val_loader=val_loader)

# User-specified predictive approx.
pred = la(x, link_approx='probit')

Differentiating the log marginal likelihood w.r.t. hyperparameters

The marginal likelihood can be used for model selection and is differentiable for continuous hyperparameters like the prior precision or observation noise. Here, we fit the library default, KFAC last-layer LA and differentiate the log marginal likelihood.

from laplace import Laplace

# Un- or pre-trained model
model = load_model()  

# Default to recommended last-layer KFAC LA:
la = Laplace(model, likelihood='regression')
la.fit(train_loader)

# ML w.r.t. prior precision and observation noise
ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()

Documentation

The documentation is available here or can be generated and/or viewed locally:

# assuming the repository was cloned
pip install -e .[docs]
# create docs and write to html
bash update_docs.sh
# .. or serve the docs directly
pdoc --http 0.0.0.0:8080 laplace --template-dir template

References

This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].

Full example: post-hoc optimization of the marginal likelihood and prediction

Sinusoidal toy data

We show how the marginal likelihood can be used after training a MAP network on a simple sinusoidal regression task. Subsequently, we use the optimized LA for prediction, which provides uncertainty on top of the MAP prediction. First, we set up the training data for the problem with observation noise \(\sigma=0.3\):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace

n_epochs = 1000
batch_size = 150  # full batch
true_sigma_noise = 0.3

# create simple sinusoid data set
X_train = (torch.rand(150) * 8).unsqueeze(-1)
y_train = torch.sin(X_train) + torch.randn_like(X_train) * true_sigma_noise
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size)
X_test = torch.linspace(-5, 13, 500).unsqueeze(-1)  # +-5 on top of the training X-range

Training a MAP

We now use PyTorch to train a neural network with a single hidden layer and Tanh activation. This is standard MAP training, so there is nothing Laplace-specific here yet:

# create and train MAP model
model = torch.nn.Sequential(torch.nn.Linear(1, 50),
                            torch.nn.Tanh(),
                            torch.nn.Linear(50, 1))

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=1e-2)
for i in range(n_epochs):
    for X, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()

Fitting and optimizing the Laplace approximation using empirical Bayes

With the MAP-trained model at hand, we can estimate the prior precision and observation noise using empirical Bayes after training. Laplace() is called to construct an LA for 'regression' with 'all' weights. By default Laplace() returns a Kronecker factored LA; since this example is small, we use the 'full' Hessian structure instead. We fit the LA to the training data and initialize log_prior and log_sigma. Using Adam, we minimize the negative log marginal likelihood for n_epochs.

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(train_loader)
log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
for i in range(n_epochs):
    hyper_optimizer.zero_grad()
    neg_marglik = - la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()

The obtained observation noise is close to the ground truth with a value of \(\sigma \approx 0.28\) without the need for any validation data. The resulting prior precision is \(\delta \approx 0.18\).
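
Since log_marginal_likelihood overwrites the current hyperparameters with the values passed to it, the estimates can be read off the LA object after the loop. A small sketch of how the numbers above can be obtained, assuming the scalar prior precision used here so that .item() applies:

# sigma_noise and prior_precision hold the last values passed to
# log_marginal_likelihood, i.e., the empirical Bayes estimates
print(f'sigma={la.sigma_noise.item():.2f}',
      f'prior precision={la.prior_precision.item():.2f}')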

Bayesian predictive

Lastly, we compare the MAP prediction to the obtained LA prediction. For LA, we have a closed-form predictive distribution on the output \(f\) which is a Gaussian \(\mathcal{N}(f(x;\theta_{MAP}), \mathbb{V}[f] + \sigma^2)\):

x = X_test.flatten().cpu().numpy()

# closed-form GLM predictive: mean and variance of the network output f
f_mu, f_var = la(X_test)
f_mu = f_mu.squeeze().detach().cpu().numpy()
f_sigma = f_var.squeeze().detach().sqrt().cpu().numpy()

# total predictive std: functional (epistemic) variance plus observation noise
pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)

In comparison to the MAP, the predictive shows useful uncertainties:

import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True,
                               figsize=(4.5, 2.8))
ax1.set_title('MAP')
ax1.scatter(X_train.flatten(), y_train.flatten(), alpha=0.7, color='tab:orange')
ax1.plot(x, f_mu, color='black', label='$f_{MAP}$')
ax1.legend()

ax2.set_title('LA')
ax2.scatter(X_train.flatten(), y_train.flatten(), alpha=0.7, color='tab:orange')
ax2.plot(x, f_mu, label=r'$\mathbb{E}[f]$')
ax2.fill_between(x, f_mu - pred_std * 2, f_mu + pred_std * 2,
                 alpha=0.3, color='tab:blue', label=r'$2\sqrt{\mathbb{V}\,[f]}$')
ax2.legend()
ax1.set_ylim([-4, 6])
ax1.set_xlim([x.min(), x.max()])
ax2.set_xlim([x.min(), x.max()])
ax1.set_ylabel('$y$')
ax1.set_xlabel('$x$')
ax2.set_xlabel('$x$')
plt.tight_layout()
plt.show()

Sub-modules

laplace.baselaplace
laplace.curvature
laplace.feature_extractor
laplace.laplace
laplace.lllaplace
laplace.matrix
laplace.utils

Functions

def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure='kron', *args, **kwargs)

Simplified Laplace access using strings instead of different classes.

Parameters

model : torch.nn.Module
 
likelihood : {'classification', 'regression'}
 
subset_of_weights : {'last_layer', 'all'}, default='last_layer'
subset of weights to consider for inference
hessian_structure : {'diag', 'kron', 'full'}, default='kron'
structure of the Hessian approximation

Returns

laplace : BaseLaplace
chosen subclass of BaseLaplace instantiated with additional arguments

Classes

class BaseLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)

Baseclass for all Laplace approximations in this library. Subclasses need to specify how the Hessian approximation is initialized, how to add up curvature over training data, how to sample from the Laplace approximation, and how to compute the functional variance.

A Laplace approximation is represented by a MAP which is given by the model parameter and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). The goal of this class is to compute the posterior precision P which sums as P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements different approximations to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 . In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
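
To make the prior term concrete, here is a small worked special case of the formula above (a sketch following the convention of this docstring, not additional library behavior): with a scalar prior precision p_0 = \delta, the prior Hessian is a scaled identity, so

P_0 = \textrm{diag}(p_0) = \delta I, \qquad
P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \delta I .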

Parameters

model : torch.nn.Module
 
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
temperature of the likelihood; lower temperature leads to more concentrated posterior and vice versa.
backend : subclasses of CurvatureInterface
backend for access to curvature/Hessian approximations
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
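
Since BaseLaplace is abstract, these arguments are passed to one of the concrete subclasses documented on this page. A minimal sketch, assuming model and train_loader are a trained torch.nn.Module and a DataLoader and that the subclasses are importable from the top-level package:

from laplace import DiagLaplace

la = DiagLaplace(model, 'regression',
                 sigma_noise=0.3,      # observation noise of the regression task
                 prior_precision=1.0,  # scalar Gaussian prior precision
                 temperature=1.0)
la.fit(train_loader)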

Ancestors

  • abc.ABC

Subclasses

  • FullLaplace
  • KronLaplace
  • DiagLaplace
  • LLLaplace

Instance variables

var backend
var log_likelihood

Compute log likelihood on the training data after .fit() has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for iterative updates.

Returns

log_likelihood : torch.Tensor
 
var scatter

Computes the scatter, a term of the log marginal likelihood that corresponds to L2 regularization: scatter = (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) .

Returns

scatter : torch.Tensor

var log_det_prior_precision

Compute log determinant of the prior precision \log \det P_0

Returns

log_det : torch.Tensor
 
var log_det_posterior_precision

Compute log determinant of the posterior precision \log \det P which depends on the structure of the Hessian approximation used by the subclass.

Returns

log_det : torch.Tensor
 
var log_det_ratio

Compute the log determinant ratio, a part of the log marginal likelihood. \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0

Returns

log_det_ratio : torch.Tensor
 
var prior_precision_diag

Obtain the diagonal prior precision p_0 constructed from either a scalar, layer-wise, or diagonal prior precision.

Returns

prior_precision_diag : torch.Tensor
 
var prior_mean
var prior_precision
var sigma_noise
var posterior_precision

Compute or return the posterior precision P.

Returns

posterior_prec : torch.Tensor
 

Methods

def fit(self, train_loader)

Fit the local Laplace approximation at the parameters of the model.

Parameters

train_loader : torch.utils.data.DataLoader
each iterate is a training batch (X, y); train_loader.dataset needs to be set to access N, the size of the data set

def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None)

Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.

Parameters

prior_precision : torch.Tensor, optional
prior precision if should be changed from current prior_precision value
sigma_noise : torch.Tensor or float, optional
observation noise standard deviation if should be changed

Returns

log_marglik : torch.Tensor
 
def predictive(self, x, pred_type='glm', link_approx='mc', n_samples=100)
def predictive_samples(self, x, pred_type='glm', n_samples=100)

Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.

Parameters

x : torch.Tensor
input data (batch_size, input_shape)
pred_type : {'glm', 'nn'}, default='glm'
type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
n_samples : int
number of samples

Returns

samples : torch.Tensor
samples (n_samples, batch_size, output_shape)
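
A brief usage sketch, assuming la has been fit and x is a batch of inputs as above; for Thompson sampling one would, for example, act greedily with respect to a single draw from the predictive:

samples = la.predictive_samples(x, pred_type='glm', n_samples=100)
# samples: (n_samples, batch_size, output_shape)
action = samples[0].argmax(-1)  # greedy action under one predictive sample
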
def functional_variance(self, Jacs)

Compute functional variance for the 'glm' predictive: f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T, which is an output x output predictive covariance matrix per input. Mathematically, we have for a single Jacobian \mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}} the output covariance matrix \mathcal{J} P^{-1} \mathcal{J}^T .

Parameters

Jacs : torch.Tensor
Jacobians of model output wrt parameters (batch, outputs, parameters)

Returns

f_var : torch.Tensor
output covariance (batch, outputs, outputs)
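
A sketch of the expected shapes with hypothetical Jacobians; the random tensor only illustrates the dimensions, not how Jacobians are actually computed, and la.n_params is assumed to give the number of parameters covered by the LA:

import torch

batch_size, n_outputs = 8, 2
Jacs = torch.randn(batch_size, n_outputs, la.n_params)  # hypothetical Jacobians
f_var = la.functional_variance(Jacs)  # (8, 2, 2): one output covariance per input
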
def sample(self, n_samples=100)

Sample from the Laplace posterior approximation, i.e., \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).

Parameters

n_samples : int, default=100
number of samples
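
For example, a sketch assuming la has been fit; each row of the result is expected to be one parameter vector drawn from the Gaussian posterior approximation:

theta_samples = la.sample(n_samples=10)  # draws from N(theta_MAP, P^{-1})
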
def optimize_prior_precision(self, method='marglik', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, pred_type='glm', link_approx='probit', n_samples=100, verbose=False)

Optimize the prior precision post-hoc using the method specified by the user.

Parameters

method : {'marglik', 'CV'}, default='marglik'
specifies how the prior precision should be optimized.
n_steps : int, default=100
the number of gradient descent steps to take.
lr : float, default=1e-1
the learning rate to use for gradient descent.
init_prior_prec : float, default=1.0
initial prior precision before the first optimization step.
val_loader : torch.utils.data.DataLoader, default=None
DataLoader for the validation set; each iterate is a batch (X, y) from the validation data.
loss : callable, default=get_nll
loss function to use for CV.
log_prior_prec_min : float, default=-4
lower bound of gridsearch interval for CV.
log_prior_prec_max : float, default=4
upper bound of gridsearch interval for CV.
grid_size : int, default=100
number of values to consider inside the gridsearch interval for CV.
pred_type : {'glm', 'nn'}, default='glm'
type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
link_approx : {'mc', 'probit', 'bridge'}, default='probit'
how to approximate the classification link function for the 'glm'. For pred_type='nn', only 'mc' is possible.
n_samples : int, default=100
number of samples for link_approx='mc'.
verbose : bool, default=False
if true, the optimized prior precision will be printed (can be a large tensor if the prior precision is diagonal).
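
For instance, marginal-likelihood optimization needs no validation data, whereas 'CV' grid-searches over log prior precisions using val_loader. A sketch assuming la has been fit and val_loader is a validation DataLoader:

# gradient-based optimization of the prior precision via the marginal likelihood
la.optimize_prior_precision(method='marglik', n_steps=100, lr=1e-1)

# ... or cross-validation on held-out data over a log-spaced grid
la.optimize_prior_precision(method='CV', val_loader=val_loader,
                            log_prior_prec_min=-4, log_prior_prec_max=4,
                            grid_size=100)
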
class FullLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)

Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See BaseLaplace for the full interface.

Ancestors

  • BaseLaplace
  • abc.ABC

Subclasses

  • FullLLLaplace

Instance variables

var posterior_scale

Posterior scale (square root of the covariance), i.e., P^{-\frac{1}{2}}.

Returns

scale : torch.tensor
(parameters, parameters)
var posterior_covariance

Posterior covariance, i.e., P^{-1}.

Returns

covariance : torch.tensor
(parameters, parameters)
var posterior_precision

Posterior precision P.

Returns

precision : torch.tensor
(parameters, parameters)

Inherited members

class KronLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, damping=False, **backend_kwargs)

Laplace approximation with Kronecker factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for each parameter group, e.g., torch.nn.Module, that P \approx Q \otimes H. See BaseLaplace for the full interface and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, account for a Hessian factor (e.g. temperature), and compute posterior covariances, the marginal likelihood, etc. Damping can be enabled by setting damping=True.

Ancestors

  • BaseLaplace
  • abc.ABC

Subclasses

  • KronLLLaplace

Instance variables

var posterior_precision

Kronecker factored posterior precision P.

Returns

precision : KronDecomposed
 
var prior_precision

Inherited members

class DiagLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)

Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See BaseLaplace for the full interface.

Ancestors

  • BaseLaplace
  • abc.ABC

Subclasses

  • DiagLLLaplace

Instance variables

var posterior_precision

Diagonal posterior precision p.

Returns

precision : torch.tensor
(parameters)
var posterior_scale

Diagonal posterior scale \sqrt{p^{-1}}.

Returns

scale : torch.tensor
(parameters)
var posterior_variance

Diagonal posterior variance p^{-1}.

Returns

variance : torch.tensor
(parameters)

Inherited members

class LLLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)

Baseclass for all last-layer Laplace approximations in this library. Subclasses specify the structure of the Hessian approximation. See BaseLaplace for the full interface.

A Laplace approximation is represented by a MAP which is given by the model parameter and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only the parameters of the last layer of the neural network are treated probabilistically. The goal of this class is to compute the posterior precision P which sums as P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements different approximations to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 . In particular, we assume a scalar or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.

Parameters

model : torch.nn.Module or FeatureExtractor
 
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0
prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1
temperature of the likelihood; lower temperature leads to more concentrated posterior and vice versa.
backend : subclasses of CurvatureInterface
backend for access to curvature/Hessian approximations
last_layer_name : str, default=None
name of the model's last layer, if None it will be determined automatically
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.

Ancestors

  • BaseLaplace
  • abc.ABC

Subclasses

  • FullLLLaplace
  • KronLLLaplace
  • DiagLLLaplace

Instance variables

var prior_precision_diag

Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.

Returns

prior_precision_diag : torch.Tensor
 

Inherited members

class FullLLLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)

Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace, LLLaplace, and BaseLaplace for the full interface.

Ancestors

  • LLLaplace
  • FullLaplace
  • BaseLaplace
  • abc.ABC

Inherited members

class KronLLLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, damping=False, **backend_kwargs)

Last-layer Laplace approximation with Kronecker factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that P \approx Q \otimes H. See KronLaplace, LLLaplace, and BaseLaplace for the full interface and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, account for a Hessian factor (e.g. temperature), and compute posterior covariances, the marginal likelihood, etc. Damping can be enabled by passing damping=True at initialization or setting it afterwards.

Ancestors

  • LLLaplace
  • KronLaplace
  • BaseLaplace
  • abc.ABC

Inherited members

class DiagLLLaplace (model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)

Last-layer Laplace approximation with diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.

Ancestors

  • LLLaplace
  • DiagLaplace
  • BaseLaplace
  • abc.ABC

Inherited members