Coverage for src/diffusionlab/utils.py: 100% (33 statements)

from math import prod
from typing import Callable, cast

import torch
from torch.func import jacrev  # type: ignore # this is a good function, somehow not exposed by PyTorch


def scalar_derivative(
    f: Callable[[torch.Tensor], torch.Tensor],
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Computes the scalar derivative of a function f: R -> R.
    Returns a function f_prime: R -> R that computes the derivative of f at a given point
    and broadcasts under the same rules as f.

    Arguments:
        f: A function whose input is a scalar (0-dimensional PyTorch tensor) and whose output is
            a scalar, and which can be broadcast over a tensor of any shape.

    Returns:
        f_prime: A function that computes the derivative of f at a given point and broadcasts
            under the same rules as f. For input of shape (N,), the output has shape (N,).
    """
    df = jacrev(f)

    def f_prime(x: torch.Tensor) -> torch.Tensor:
        dfx = cast(torch.Tensor, df(x))
        if dfx.ndim > 1:
            # For non-scalar input, jacrev returns the full Jacobian; since f acts
            # elementwise, only its diagonal is nonzero, so extract it.
            x_size = prod(x.shape)
            dfx = dfx.reshape(x_size, x_size)
            dfx = dfx.diagonal(dim1=0, dim2=1)
            dfx = dfx.reshape(x.shape)
        return dfx

    return f_prime
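

# Illustrative usage (an added sketch, not part of the original module; the helper
# name _example_scalar_derivative is hypothetical): the elementwise derivative of
# torch.sin should match torch.cos.
def _example_scalar_derivative() -> None:
    d_sin = scalar_derivative(torch.sin)
    x = torch.linspace(0.0, 3.0, steps=5)
    assert torch.allclose(d_sin(x), torch.cos(x))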


def pad_shape_front(x: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """
    Pads the front of a tensor with singleton dimensions until it can broadcast with target_shape.

    Arguments:
        x: A tensor of any shape, say (P, Q, R, S).
        target_shape: A shape to which x can broadcast, say (M, N, O, P, Q, R, S).

    Returns:
        x_padded: The tensor x reshaped to be broadcastable with target_shape, say (1, 1, 1, P, Q, R, S).
            The returned tensor has shape (1, ..., 1, *x.shape) with enough leading 1s to match
            the dimensionality of target_shape.

    Note:
        This function does not use any additional memory, returning a different view of the same underlying data.
    """
    ndim_target = len(target_shape)
    ndim_x = x.ndim
    expand_dims = (1,) * max(ndim_target - ndim_x, 0)
    return x.view(*expand_dims, *x.shape)
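

# Illustrative usage (an added sketch, not part of the original module; the helper
# name _example_pad_shape_front is hypothetical): a (3, 4) tensor padded against a
# (2, 5, 3, 4) target becomes a (1, 1, 3, 4) view.
def _example_pad_shape_front() -> None:
    x = torch.randn(3, 4)
    target = torch.Size((2, 5, 3, 4))
    assert pad_shape_front(x, target).shape == (1, 1, 3, 4)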


def pad_shape_back(x: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """
    Pads the back of a tensor with singleton dimensions until it can broadcast with target_shape.

    Arguments:
        x: A tensor of any shape, say (P, Q, R, S).
        target_shape: A shape to which x can broadcast, say (P, Q, R, S, T, U, V).

    Returns:
        x_padded: The tensor x reshaped to be broadcastable with target_shape, say (P, Q, R, S, 1, 1, 1).
            The returned tensor has shape (*x.shape, 1, ..., 1) with enough trailing 1s to match
            the dimensionality of target_shape.

    Note:
        This function does not use any additional memory, returning a different view of the same underlying data.
    """
    ndim_target = len(target_shape)
    ndim_x = x.ndim
    expand_dims = (1,) * max(ndim_target - ndim_x, 0)
    return x.view(*x.shape, *expand_dims)
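

# Illustrative usage (an added sketch, not part of the original module; the helper
# name _example_pad_shape_back is hypothetical): pad a per-sample scale of shape (B,)
# so it can multiply a (B, C, H, W) batch.
def _example_pad_shape_back() -> None:
    scale = torch.rand(8)
    images = torch.randn(8, 3, 32, 32)
    scaled = pad_shape_back(scale, images.shape) * images
    assert scaled.shape == images.shape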


def logdet_pd(A: torch.Tensor) -> torch.Tensor:
    """
    Computes the log-determinant of a positive-definite matrix A, broadcasting over A.

    Arguments:
        A: A positive-definite matrix of shape (..., N, N) where ... represents any number of batch dimensions.

    Returns:
        logdet_A: The log-determinant of A of shape (...) with the same batch dimensions as A.
    """
    # If A = L L^T with L lower-triangular, then logdet(A) = 2 * sum(log(diag(L))).
    L = torch.linalg.cholesky(A)
    diag_L = torch.diagonal(L, dim1=-2, dim2=-1)
    return 2 * torch.sum(torch.log(diag_L), dim=-1)
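

# Illustrative usage (an added sketch, not part of the original module; the helper
# name _example_logdet_pd is hypothetical): agree with torch.logdet on a random
# batch of positive-definite matrices.
def _example_logdet_pd() -> None:
    B = torch.randn(4, 3, 3)
    A = B @ B.transpose(-2, -1) + 0.1 * torch.eye(3)  # positive-definite by construction
    assert torch.allclose(logdet_pd(A), torch.logdet(A), atol=1e-4)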


def sqrt_psd(A: torch.Tensor) -> torch.Tensor:
    """
    Computes the matrix square root of a positive-semidefinite matrix A, broadcasting over A.

    Arguments:
        A: A positive-semidefinite matrix of shape (..., N, N) where ... represents any number of batch dimensions.

    Returns:
        sqrt_A: The matrix square root of A of shape (..., N, N) with the same shape as A.
    """
    L, Q = torch.linalg.eigh(A)
    # Zero out negative eigenvalues (numerical noise for a PSD matrix) before taking the root.
    L_new = torch.where(L > 0, torch.sqrt(L), torch.zeros_like(L))
    return Q @ torch.diag_embed(L_new) @ Q.transpose(-2, -1)
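

# Illustrative usage (an added sketch, not part of the original module; the helper
# name _example_sqrt_psd is hypothetical): the symmetric square root should
# multiply back to the original matrix.
def _example_sqrt_psd() -> None:
    B = torch.randn(4, 3, 3)
    A = B @ B.transpose(-2, -1)  # positive-semidefinite by construction
    R = sqrt_psd(A)
    assert torch.allclose(R @ R, A, atol=1e-4)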