Coverage for src/diffusionlab/distributions/base.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-21 15:48 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-21 15:48 -0700
1from typing import Any, Dict, Tuple, Callable
3import torch
5from diffusionlab.diffusions import DiffusionProcess
6from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type
class Distribution:
    """
    Base class for all distributions.

    This class should be subclassed by other distributions when you want to use ground truth
    scores, denoisers, noise predictors, or velocity estimators.

    Each distribution implementation provides methods to compute various vector fields
    related to the diffusion process, such as denoising (x0), noise prediction (eps),
    velocity estimation (v), and score estimation.

    Every method is a classmethod or staticmethod: parameters and hyperparameters are
    passed in explicitly as dictionaries on each call. In this base class ``x0`` and
    ``sample`` raise ``NotImplementedError``, while ``eps``, ``v`` and ``score`` are
    derived from ``cls.x0``, so a subclass that overrides ``x0`` gets the other three
    vector fields without further work.
    """

    @classmethod
    def validate_hparams(cls, dist_hparams: Dict[str, Any]) -> None:
        """
        Validate the hyperparameters for the distribution.

        The base distribution has no hyperparameters, so only an empty dictionary is accepted.

        Arguments:
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            None

        Raises:
            AssertionError: If the parameters are invalid, the assertion fails at exactly the point of failure.
        """
        # Kept as an assert (not a raise) because AssertionError is the documented contract.
        assert len(dist_hparams) == 0, (
            f"{cls.__name__} takes no hyperparameters, got: {list(dist_hparams)}"
        )

    @classmethod
    def get_vector_field_method(
        cls, vector_field_type: VectorFieldType
    ) -> Callable[
        [
            torch.Tensor,
            torch.Tensor,
            DiffusionProcess,
            Dict[str, torch.Tensor],
            Dict[str, Any],
        ],
        torch.Tensor,
    ]:
        """
        Returns the appropriate method to compute the specified vector field type.

        Arguments:
            vector_field_type: The type of vector field to compute.

        Returns:
            A method that computes the specified vector field, with signature:
            (x_t, t, diffusion_process, batched_dist_params, dist_hparams) -> tensor

        Raises:
            ValueError: If the vector field type is not recognized.
        """
        # Looked up on ``cls`` so that subclass overrides are the ones returned.
        if vector_field_type == VectorFieldType.X0:
            return cls.x0
        if vector_field_type == VectorFieldType.EPS:
            return cls.eps
        if vector_field_type == VectorFieldType.V:
            return cls.v
        if vector_field_type == VectorFieldType.SCORE:
            return cls.score
        raise ValueError(f"Unrecognized vector field type: {vector_field_type}")

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        """
        Validate the parameters for the distribution.

        The base distribution has no parameters, so only an empty dictionary is accepted.

        Arguments:
            possibly_batched_dist_params: A dictionary of parameters for the distribution.
                Each value is a PyTorch tensor, possibly having a batch dimension.

        Returns:
            None

        Raises:
            AssertionError: If the parameters are invalid, the assertion fails at exactly the point of failure.
        """
        # Kept as an assert (not a raise) because AssertionError is the documented contract.
        assert len(possibly_batched_dist_params) == 0, (
            f"{cls.__name__} takes no parameters, got: {list(possibly_batched_dist_params)}"
        )

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] at a given time t and input x_t, under the data model

            x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, *D).

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        raise NotImplementedError

    @classmethod
    def _from_x0(
        cls,
        out_type: VectorFieldType,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Evaluate ``cls.x0`` and convert its output to the vector field ``out_type``.

        Shared implementation behind ``eps``, ``v`` and ``score``.

        Arguments:
            out_type: The vector field type to convert the x0 prediction into.
            x_t, t, diffusion_process, batched_dist_params, dist_hparams: Forwarded
                unchanged to ``cls.x0``; see that method for their meaning.

        Returns:
            The result of ``convert_vector_field_type`` for the requested ``out_type``.
        """
        # The denoiser is evaluated first, so a missing ``x0`` override surfaces
        # as NotImplementedError before the diffusion process is queried.
        x0_hat = cls.x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)
        return convert_vector_field_type(
            x_t,
            x0_hat,
            diffusion_process.alpha(t),
            diffusion_process.sigma(t),
            diffusion_process.alpha_prime(t),
            diffusion_process.sigma_prime(t),
            in_type=VectorFieldType.X0,
            out_type=out_type,
        )

    @classmethod
    def eps(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the noise predictor E[eps | x_t] at a given time t and input x_t, under the data model

            x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).
        The default implementation evaluates ``cls.x0`` and converts the result with
        ``convert_vector_field_type``, so subclasses only need to override ``x0``.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of eps, of shape (N, *D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        return cls._from_x0(
            VectorFieldType.EPS,
            x_t,
            t,
            diffusion_process,
            batched_dist_params,
            dist_hparams,
        )

    @classmethod
    def v(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the velocity estimator E[d/dt x_t | x_t] at a given time t and input x_t, under the data model

            x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).
        The default implementation evaluates ``cls.x0`` and converts the result with
        ``convert_vector_field_type``, so subclasses only need to override ``x0``.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of d/dt x_t, of shape (N, *D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        return cls._from_x0(
            VectorFieldType.V,
            x_t,
            t,
            diffusion_process,
            batched_dist_params,
            dist_hparams,
        )

    @classmethod
    def score(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the score estimator grad_x log p(x_t, t) at a given time t and input x_t, under the data model

            x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).
        The default implementation evaluates ``cls.x0`` and converts the result with
        ``convert_vector_field_type``, so subclasses only need to override ``x0``.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of grad_x log p(x_t, t), of shape (N, *D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        return cls._from_x0(
            VectorFieldType.SCORE,
            x_t,
            t,
            diffusion_process,
            batched_dist_params,
            dist_hparams,
        )

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, Any]:
        """
        Draws N i.i.d. samples from the data distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, metadata), where samples is a tensor of shape (N, *D) and metadata is any additional information.
            For example, if the distribution has labels, the metadata is a tensor of shape (N, ) containing the labels.
            Note that the samples are always placed on the CPU.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        raise NotImplementedError

    @staticmethod
    def batch_dist_params(
        N: int, dist_params: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Add a batch dimension to the distribution parameters.

        Arguments:
            N: The number of samples in the batch.
            dist_params: A dictionary of parameters for the distribution.

        Returns:
            A dictionary of parameters for the distribution, with a batch dimension added.
            Each value of shape (*P) becomes a tensor of shape (N, *P), under the same key.
        """
        # unsqueeze(0) adds a leading axis of size 1; expand then broadcasts it to N.
        return {k: v.unsqueeze(0).expand(N, *v.shape) for k, v in dist_params.items()}