Coverage for src/diffusionlab/vector_fields.py: 100%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 21:37 -0700

1import enum 

2from typing import Callable 

3 

4import torch 

5 

6from diffusionlab.utils import pad_shape_back 

7 

8 

class VectorFieldType(enum.Enum):
    """
    Tags naming the kind of quantity a vector field function outputs.

    The member values are assigned by ``enum.auto()`` and are not meaningful in
    themselves; members are intended to be compared against one another.
    """

    # Score function, written score(x, alpha, sigma) in the conversion formulas.
    SCORE = enum.auto()

    # Prediction of the clean data, E[x0 | x].
    X0 = enum.auto()

    # Prediction of the noise, E[eps | x].
    EPS = enum.auto()

    # Prediction of the velocity, E[v | x] with v := x'.
    V = enum.auto()

14 

15 

class VectorField:
    """
    Pairs a callable ``f(x, t)`` with a tag describing what kind of vector field it is.

    Keeping the function and its ``VectorFieldType`` together lets downstream code
    treat the different vector field representations used in diffusion models
    in a uniform way.

    Attributes:
        f (Callable): The wrapped function. It accepts tensors x of shape (N, *D)
            and t of shape (N,), and produces a tensor of shape (N, *D).
        vector_field_type (VectorFieldType): Which kind of vector field ``f`` represents.
    """

    def __init__(
        self,
        f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        vector_field_type: VectorFieldType,
    ):
        """
        Store the function together with its vector field type.

        Args:
            f (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): Function mapping
                x of shape (N, *D) and t of shape (N,) to a tensor of shape (N, *D).
            vector_field_type (VectorFieldType): The kind of vector field that ``f``
                represents (SCORE, X0, EPS, or V).
        """
        self.vector_field_type: VectorFieldType = vector_field_type
        self.f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = f

    def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the wrapped function at (x, t).

        Args:
            x (torch.Tensor): Input of shape (N, *D), with N the batch size and D the data dimensions.
            t (torch.Tensor): Time parameter of shape (N,).

        Returns:
            torch.Tensor: Whatever the wrapped function returns, of shape (N, *D).
        """
        # Pure delegation: the arguments are forwarded untouched.
        output = self.f(x, t)
        return output

59 

60 

def convert_vector_field_type(
    x: torch.Tensor,
    fx: torch.Tensor,
    alpha: torch.Tensor,
    sigma: torch.Tensor,
    alpha_prime: torch.Tensor,
    sigma_prime: torch.Tensor,
    in_type: VectorFieldType,
    out_type: VectorFieldType,
) -> torch.Tensor:
    """
    Re-express a vector field output given as one type in terms of another type.

    When ``in_type`` and ``out_type`` are the same, or the pair matches none of the
    conversions below, ``fx`` is returned unchanged.

    Arguments:
        x (torch.Tensor): Tensor of shape (N, *D), where N is the batch size and D is the
            shape of the data (e.g., (C, H, W) for images, (D,) for vectors, or a
            two-dimensional shape for token sequences).
        fx (torch.Tensor): Output of the vector field f evaluated at x, of shape (N, *D).
        alpha (torch.Tensor): Scale parameter, of shape (N,).
        sigma (torch.Tensor): Noise level parameter, of shape (N,).
        alpha_prime (torch.Tensor): Scale derivative parameter, of shape (N,).
        sigma_prime (torch.Tensor): Noise level derivative parameter, of shape (N,).
        in_type (VectorFieldType): Type that ``fx`` currently represents (e.g. Score, X0, Eps, V).
        out_type (VectorFieldType): Type to convert ``fx`` into.

    Returns:
        torch.Tensor: The converted vector field output, of shape (N, *D).
    """
    # Derivation:
    # ----------------------------
    # Define certain quantities:
    #     alpha_r = alpha' / alpha
    #     sigma_r = sigma' / sigma
    #     diff_r = sigma_r - alpha_r
    # and note that diff_r >= 0 since alpha' < 0 and all other terms are > 0.
    # Under the data model
    #     (1) x := alpha * x0 + sigma * eps
    # it holds that
    #     (2) x = alpha * E[x0 | x] + sigma * E[eps | x]
    # Therefore
    #     (3) E[x0 | x] = (x - sigma * E[eps | x]) / alpha
    #     (4) E[eps | x] = (x - alpha * E[x0 | x]) / sigma
    # Furthermore, from (1) it holds that
    #     (5) v := x' = alpha' * x0 + sigma' * eps,
    # or in particular
    #     (6) E[v | x] = alpha' * E[x0 | x] + sigma' * E[eps | x]
    # Using (3), (4), (6) it holds
    #     (7) E[v | x] = alpha_r * (x - sigma * E[eps | x]) + sigma' * E[eps | x]
    #         => E[v | x] = alpha'/alpha * x + (sigma' - sigma * alpha'/alpha) * E[eps | x]
    #         => E[v | x] = alpha'/alpha * x + sigma * (sigma'/sigma - alpha'/alpha) * E[eps | x]
    #         => E[v | x] = alpha_r * x + sigma * diff_r * E[eps | x]
    #     (8) E[eps | x] = (E[v | x] - alpha_r * x) / (sigma * diff_r)
    # and, similarly,
    #     (9) E[v | x] = alpha' * E[x0 | x] + sigma'/sigma * (x - alpha * E[x0 | x])
    #         => E[v | x] = sigma'/sigma * x + (alpha' - alpha * sigma'/sigma) * E[x0 | x]
    #         => E[v | x] = sigma'/sigma * x + alpha * (alpha'/alpha - sigma'/sigma) * E[x0 | x]
    #         => E[v | x] = sigma_r * x - alpha * diff_r * E[x0 | x]
    #     (10) E[x0 | x] = (sigma_r * x - E[v | x]) / (alpha * diff_r)
    # To connect the score function to the other types, we use Tweedie's formula:
    #     (11) alpha * E[x0 | x] = x + sigma^2 * score(x, alpha, sigma).
    # Therefore, from (11):
    #     (12) E[x0 | x] = (x + sigma^2 * score(x, alpha, sigma)) / alpha
    # From (12):
    #     (13) score(x, alpha, sigma) = (alpha * E[x0 | x] - x) / sigma^2
    # From (11) and (4):
    #     (14) E[eps | x] = -sigma * score(x, alpha, sigma)
    # From (14):
    #     (15) score(x, alpha, sigma) = -E[eps | x] / sigma
    # From (14) and (7):
    #     (16) E[v | x] = alpha_r * x - sigma^2 * diff_r * score(x, alpha, sigma)
    # From (16):
    #     (17) score(x, alpha, sigma) = (alpha_r * x - E[v | x]) / (sigma^2 * diff_r)

    # Reshape the per-sample schedule values against x.shape so they can be
    # combined elementwise with x and fx.
    # NOTE(review): presumably pad_shape_back appends trailing singleton dims — confirm in diffusionlab.utils.
    alpha = pad_shape_back(alpha, x.shape)
    alpha_prime = pad_shape_back(alpha_prime, x.shape)
    sigma = pad_shape_back(sigma, x.shape)
    sigma_prime = pad_shape_back(sigma_prime, x.shape)

    # alpha_r, sigma_r and diff_r from the derivation above.
    alpha_rel = alpha_prime / alpha
    sigma_rel = sigma_prime / sigma
    rel_gap = sigma_rel - alpha_rel

    if in_type == VectorFieldType.SCORE:
        if out_type == VectorFieldType.X0:
            return (x + sigma**2 * fx) / alpha  # Equation (12)
        if out_type == VectorFieldType.EPS:
            return -sigma * fx  # Equation (14)
        if out_type == VectorFieldType.V:
            return alpha_rel * x - sigma**2 * rel_gap * fx  # Equation (16)
        return fx

    if in_type == VectorFieldType.X0:
        if out_type == VectorFieldType.SCORE:
            return (alpha * fx - x) / sigma**2  # Equation (13)
        if out_type == VectorFieldType.EPS:
            return (x - alpha * fx) / sigma  # Equation (4)
        if out_type == VectorFieldType.V:
            return sigma_rel * x - alpha * rel_gap * fx  # Equation (9)
        return fx

    if in_type == VectorFieldType.EPS:
        if out_type == VectorFieldType.SCORE:
            return -fx / sigma  # Equation (15)
        if out_type == VectorFieldType.X0:
            return (x - sigma * fx) / alpha  # Equation (3)
        if out_type == VectorFieldType.V:
            return alpha_rel * x + sigma * rel_gap * fx  # Equation (7)
        return fx

    if in_type == VectorFieldType.V:
        if out_type == VectorFieldType.SCORE:
            return (alpha_rel * x - fx) / (sigma**2 * rel_gap)  # Equation (17)
        if out_type == VectorFieldType.X0:
            return (sigma_rel * x - fx) / (alpha * rel_gap)  # Equation (10)
        if out_type == VectorFieldType.EPS:
            return (fx - alpha_rel * x) / (sigma * rel_gap)  # Equation (8)
        return fx

    # No conversion applies: hand back the input untouched.
    return fx