Coverage for src/diffusionlab/utils.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 21:37 -0700

1from math import prod 

2from typing import Callable, cast 

3 

4import torch 

5from torch.func import jacrev # type: ignore # this is a good function, somehow not exposed by Pytorch 

6 

7 

def scalar_derivative(
    f: Callable[[torch.Tensor], torch.Tensor],
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Builds the derivative f_prime of a scalar function f: R -> R.

    f is expected to act elementwise, so it can be applied to a tensor of any shape by
    broadcasting. The returned f_prime follows the same broadcasting rules.

    Arguments:
        f: A function that maps a scalar (0-dimensional Pytorch tensor) to a scalar and
            broadcasts to tensors of any shape.

    Returns:
        f_prime: A function returning the derivative of f evaluated at each entry of its
            input. Its output has the same shape as its input; e.g. an input of shape (N,)
            yields an output of shape (N,).
    """
    jacobian_fn = jacrev(f)

    def f_prime(x: torch.Tensor) -> torch.Tensor:
        jac = cast(torch.Tensor, jacobian_fn(x))
        # A 0-dim input gives a 0-dim Jacobian, which already is the derivative.
        if jac.ndim <= 1:
            return jac
        # For a non-scalar input the Jacobian has shape (*x.shape, *x.shape). Flatten
        # it to a square (n, n) matrix; its diagonal holds d f(x_i) / d x_i.
        n = prod(x.shape)
        square = jac.reshape(n, n)
        return square.diagonal(dim1=0, dim2=1).reshape(x.shape)

    return f_prime

36 

37 

def pad_shape_front(x: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """
    Pads the front of a tensor with singleton dimensions until it can broadcast with target_shape.

    Arguments:
        x: A tensor of any shape, say (P, Q, R, S).
        target_shape: A shape to which x can broadcast, say (M, N, O, P, Q, R, S).

    Returns:
        x_padded: The tensor x reshaped to be broadcastable with target_shape, say (1, 1, 1, P, Q, R, S).
            The returned tensor has shape (1, ..., 1, *x.shape) with enough leading 1s to match
            the dimensionality of target_shape. If x already has at least as many dimensions as
            target_shape, no dimensions are added.

    Note:
        This function does not use any additional memory, returning a different view of the same underlying data.
    """
    num_missing = max(len(target_shape) - x.ndim, 0)
    # Pass the new shape as a single tuple rather than unpacking it as varargs:
    # for a 0-dim x with nothing to add, unpacking would produce ``x.view()``
    # with no arguments, which PyTorch rejects, whereas ``x.view(())`` is valid.
    new_shape = (1,) * num_missing + tuple(x.shape)
    return x.view(new_shape)

58 

59 

def pad_shape_back(x: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """
    Pads the back of a tensor with singleton dimensions until it can broadcast with target_shape.

    Arguments:
        x: A tensor of any shape, say (P, Q, R, S).
        target_shape: A shape to which x can broadcast, say (P, Q, R, S, T, U, V).

    Returns:
        x_padded: The tensor x reshaped to be broadcastable with target_shape, say (P, Q, R, S, 1, 1, 1).
            The returned tensor has shape (*x.shape, 1, ..., 1) with enough trailing 1s to match
            the dimensionality of target_shape. If x already has at least as many dimensions as
            target_shape, no dimensions are added.

    Note:
        This function does not use any additional memory, returning a different view of the same underlying data.
    """
    num_missing = max(len(target_shape) - x.ndim, 0)
    # Pass the new shape as a single tuple rather than unpacking it as varargs:
    # for a 0-dim x with nothing to add, unpacking would produce ``x.view()``
    # with no arguments, which PyTorch rejects, whereas ``x.view(())`` is valid.
    new_shape = tuple(x.shape) + (1,) * num_missing
    return x.view(new_shape)

80 

81 

def logdet_pd(A: torch.Tensor) -> torch.Tensor:
    """
    Computes the log-determinant of a positive-definite matrix A, broadcasting over A.

    Uses the Cholesky factorization A = L L^T: since det(A) = prod(diag(L))^2,
    log det(A) = 2 * sum(log(diag(L))).

    Arguments:
        A: A positive-definite matrix of shape (..., N, N) where ... represents any number of batch dimensions.

    Returns:
        logdet_A: The log-determinant of A of shape (...) with the same batch dimensions as A.
    """
    chol = torch.linalg.cholesky(A)
    # The diagonal of the Cholesky factor (not the eigenvalues of A).
    log_chol_diag = chol.diagonal(dim1=-2, dim2=-1).log()
    return 2 * log_chol_diag.sum(dim=-1)

95 

96 

def sqrt_psd(A: torch.Tensor) -> torch.Tensor:
    """
    Computes the matrix square root of a positive-semidefinite matrix A, broadcasting over A.

    With the eigendecomposition A = Q diag(lam) Q^T, the result is Q diag(s) Q^T where
    s_i = sqrt(lam_i) if lam_i > 0 and s_i = 0 otherwise. Eigenvalues that are zero or
    slightly negative (e.g. from round-off) therefore contribute nothing.

    Arguments:
        A: A positive-semidefinite matrix of shape (..., N, N) where ... represents any number of batch dimensions.

    Returns:
        sqrt_A: The matrix square root of A of shape (..., N, N) with the same shape as A.

    Note:
        Gradients with respect to A stay finite when some eigenvalues are <= 0. They may
        still be ill-conditioned when A has (nearly) repeated eigenvalues, which is a
        property of differentiating the eigendecomposition itself.
    """
    eigvals, Q = torch.linalg.eigh(A)
    positive = eigvals > 0
    # torch.where backpropagates through BOTH branches. Taking sqrt directly of a
    # non-positive eigenvalue makes its backward 0/0 or 0/NaN = NaN even though the
    # value is discarded. So feed sqrt a safe placeholder (1) on those entries.
    safe_eigvals = torch.where(positive, eigvals, torch.ones_like(eigvals))
    sqrt_eigvals = torch.where(positive, torch.sqrt(safe_eigvals), torch.zeros_like(eigvals))
    return Q @ torch.diag_embed(sqrt_eigvals) @ Q.transpose(-2, -1)