Coverage for src/diffusionlab/vector_fields.py: 100%
54 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 21:37 -0700
1import enum
2from typing import Callable
4import torch
6from diffusionlab.utils import pad_shape_back
class VectorFieldType(enum.Enum):
    """
    The quantity that a vector field's output represents.

    Under the data model x = alpha * x0 + sigma * eps, the members are:
        SCORE: the score function of x, as it appears in Tweedie's formula.
        X0: the denoised-data estimate E[x0 | x].
        EPS: the noise estimate E[eps | x].
        V: the velocity estimate E[v | x], where v = alpha' * x0 + sigma' * eps.
    """

    # Explicit values, identical to what enum.auto() would assign (starting at 1).
    SCORE = 1
    X0 = 2
    EPS = 3
    V = 4
class VectorField:
    """
    Callable wrapper that pairs a function f(x, t) with the kind of quantity it outputs.

    Carrying the VectorFieldType alongside the function lets downstream code know
    whether f predicts the score, x0, eps, or v without inspecting f itself.

    Attributes:
        f (Callable): Maps x of shape (N, *D) and t of shape (N,) to a tensor of
            shape (N, *D).
        vector_field_type (VectorFieldType): Which quantity f returns.
    """

    def __init__(
        self,
        f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        vector_field_type: VectorFieldType,
    ):
        """
        Store the wrapped function together with its vector field type.

        Args:
            f (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): Function taking
                x of shape (N, *D) and t of shape (N,) and returning a tensor of
                shape (N, *D).
            vector_field_type (VectorFieldType): Which quantity f returns
                (SCORE, X0, EPS, or V).
        """
        self.f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = f
        self.vector_field_type: VectorFieldType = vector_field_type

    def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the wrapped function at (x, t).

        Args:
            x (torch.Tensor): Input of shape (N, *D); N is the batch size and D the
                data dimensions.
            t (torch.Tensor): Times, of shape (N,).

        Returns:
            torch.Tensor: Whatever the wrapped function returns, of shape (N, *D).
        """
        field_fn = self.f
        return field_fn(x, t)
def convert_vector_field_type(
    x: torch.Tensor,
    fx: torch.Tensor,
    alpha: torch.Tensor,
    sigma: torch.Tensor,
    alpha_prime: torch.Tensor,
    sigma_prime: torch.Tensor,
    in_type: VectorFieldType,
    out_type: VectorFieldType,
) -> torch.Tensor:
    """
    Re-express a vector field's output as a different vector field type.

    Arguments:
        x (torch.Tensor): Tensor of shape (N, *D), where N is the batch size and D is
            the data shape (e.g. (C, H, W) for images, (D,) for vectors, or (L, D)
            for token sequences of length L).
        fx (torch.Tensor): Output of the vector field at x, of shape (N, *D),
            interpreted as in_type.
        alpha (torch.Tensor): Shape (N,), the scale parameter multiplying x0.
        sigma (torch.Tensor): Shape (N,), the noise-level parameter multiplying eps.
        alpha_prime (torch.Tensor): Shape (N,), the derivative alpha' of alpha.
        sigma_prime (torch.Tensor): Shape (N,), the derivative sigma' of sigma.
        in_type (VectorFieldType): The type fx currently represents.
        out_type (VectorFieldType): The type the result should represent.

    Returns:
        torch.Tensor: fx converted to out_type, of shape (N, *D). If no conversion
        applies (e.g. in_type == out_type), fx itself is returned.

    Derivation:
        Define
            alpha_r = alpha' / alpha
            sigma_r = sigma' / sigma
            diff_r  = sigma_r - alpha_r
        Assuming alpha > 0, sigma > 0, alpha' <= 0 and sigma' > 0, we have
        diff_r >= sigma_r > 0, so the divisions by diff_r below are well defined.

        Under the data model
            (1) x := alpha * x0 + sigma * eps
        it holds that
            (2) x = alpha * E[x0 | x] + sigma * E[eps | x]
        so
            (3) E[x0 | x] = (x - sigma * E[eps | x]) / alpha
            (4) E[eps | x] = (x - alpha * E[x0 | x]) / sigma
        Differentiating (1) gives
            (5) v := x' = alpha' * x0 + sigma' * eps
            (6) E[v | x] = alpha' * E[x0 | x] + sigma' * E[eps | x]
        Substituting (3) into (6):
            (7) E[v | x] = alpha_r * x + sigma * diff_r * E[eps | x]
            (8) E[eps | x] = (E[v | x] - alpha_r * x) / (sigma * diff_r)
        Substituting (4) into (6):
            (9) E[v | x] = sigma_r * x - alpha * diff_r * E[x0 | x]
            (10) E[x0 | x] = (sigma_r * x - E[v | x]) / (alpha * diff_r)
        Tweedie's formula links the score to the other types:
            (11) alpha * E[x0 | x] = x + sigma^2 * score(x, alpha, sigma)
            (12) E[x0 | x] = (x + sigma^2 * score(x, alpha, sigma)) / alpha
            (13) score(x, alpha, sigma) = (alpha * E[x0 | x] - x) / sigma^2
        Combining (11) with (4):
            (14) E[eps | x] = -sigma * score(x, alpha, sigma)
            (15) score(x, alpha, sigma) = -E[eps | x] / sigma
        Combining (14) with (7):
            (16) E[v | x] = alpha_r * x - sigma^2 * diff_r * score(x, alpha, sigma)
            (17) score(x, alpha, sigma) = (alpha_r * x - E[v | x]) / (sigma^2 * diff_r)
    """
    # Bring the per-sample (N,) coefficients to a shape compatible with x.
    # NOTE(review): presumably pad_shape_back appends trailing singleton dims so
    # they broadcast against x -- confirm against diffusionlab.utils.
    alpha = pad_shape_back(alpha, x.shape)
    alpha_prime = pad_shape_back(alpha_prime, x.shape)
    sigma = pad_shape_back(sigma, x.shape)
    sigma_prime = pad_shape_back(sigma_prime, x.shape)

    alpha_r = alpha_prime / alpha
    sigma_r = sigma_prime / sigma
    diff_r = sigma_r - alpha_r

    vft = VectorFieldType
    # Each entry is a zero-argument thunk so only the requested conversion runs.
    # The trailing number is the equation in the derivation above.
    conversions = {
        (vft.SCORE, vft.X0): lambda: (x + sigma**2 * fx) / alpha,  # (12)
        (vft.SCORE, vft.EPS): lambda: -sigma * fx,  # (14)
        (vft.SCORE, vft.V): lambda: alpha_r * x - sigma**2 * diff_r * fx,  # (16)
        (vft.X0, vft.SCORE): lambda: (alpha * fx - x) / sigma**2,  # (13)
        (vft.X0, vft.EPS): lambda: (x - alpha * fx) / sigma,  # (4)
        (vft.X0, vft.V): lambda: sigma_r * x - alpha * diff_r * fx,  # (9)
        (vft.EPS, vft.SCORE): lambda: -fx / sigma,  # (15)
        (vft.EPS, vft.X0): lambda: (x - sigma * fx) / alpha,  # (3)
        (vft.EPS, vft.V): lambda: alpha_r * x + sigma * diff_r * fx,  # (7)
        (vft.V, vft.SCORE): lambda: (alpha_r * x - fx) / (sigma**2 * diff_r),  # (17)
        (vft.V, vft.X0): lambda: (sigma_r * x - fx) / (alpha * diff_r),  # (10)
        (vft.V, vft.EPS): lambda: (fx - alpha_r * x) / (sigma * diff_r),  # (8)
    }

    converter = conversions.get((in_type, out_type))
    # Identical types (or any pair without an entry) pass fx through unchanged.
    if converter is None:
        return fx
    return converter()