Coverage for src/diffusionlab/samplers.py: 100%
179 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 21:37 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 21:37 -0700
1from typing import Callable, Tuple
3import torch
5from diffusionlab.diffusions import DiffusionProcess
6from diffusionlab.utils import pad_shape_back
7from diffusionlab.vector_fields import (
8 VectorField,
9 VectorFieldType,
10 convert_vector_field_type,
11)
class Sampler:
    """
    Base class for generating samples by running a diffusion process backwards in time.

    A sampler couples a ``DiffusionProcess`` with a flag that selects stochastic or
    deterministic reverse dynamics. Concrete subclasses implement the eight
    single-step update rules ``sample_step_{stochastic,deterministic}_{score,x0,eps,v}``.
    This base class only picks the rule matching a vector field's type and iterates
    it over a time schedule; its own step rules all raise ``NotImplementedError``.

    Attributes:
        diffusion_process (DiffusionProcess): The process whose coefficients the step rules use.
        is_stochastic (bool): Selects the stochastic step rules when truthy, the deterministic ones otherwise.
    """

    def __init__(
        self,
        diffusion_process: DiffusionProcess,
        is_stochastic: bool,
    ):
        """
        Store the diffusion process and the sampling mode.

        Args:
            diffusion_process (DiffusionProcess): The diffusion process to sample from.
            is_stochastic (bool): Whether to use the stochastic (True) or deterministic (False) step rules.
        """
        self.diffusion_process: DiffusionProcess = diffusion_process
        self.is_stochastic: bool = is_stochastic

    def sample(
        self,
        vector_field: VectorField,
        x_init: torch.Tensor,
        zs: torch.Tensor,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run the reverse process over the schedule ``ts`` and return only the final state.

        Args:
            vector_field (VectorField): Model queried at every step; its ``vector_field_type``
                selects which step rule is used.
            x_init (torch.Tensor): Starting state, of shape (N, *D) where N is the batch size
                and D represents the data dimensions.
            zs (torch.Tensor): Per-step noise, of shape (L-1, N, *D) where L is the number of time steps.
            ts (torch.Tensor): Time schedule, of shape (L,); exactly L-1 steps are taken.

        Returns:
            torch.Tensor: The state after the last step, of shape (N, *D).
        """
        step = self.get_sample_step_function(vector_field.vector_field_type)
        num_steps = ts.shape[0] - 1
        state = x_init
        for step_idx in range(num_steps):
            state = step(vector_field, state, zs, step_idx, ts)
        return state

    def sample_trajectory(
        self,
        vector_field: VectorField,
        x_init: torch.Tensor,
        zs: torch.Tensor,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run the reverse process over ``ts`` and return every intermediate state.

        Args:
            vector_field (VectorField): Model queried at every step; its ``vector_field_type``
                selects which step rule is used.
            x_init (torch.Tensor): Starting state, of shape (N, *D) where N is the batch size
                and D represents the data dimensions.
            zs (torch.Tensor): Per-step noise, of shape (L-1, N, *D) where L is the number of time steps.
            ts (torch.Tensor): Time schedule, of shape (L,).

        Returns:
            torch.Tensor: The stacked states ``[x_init, x_1, ..., x_{L-1}]``, of shape (L, N, *D).
        """
        step = self.get_sample_step_function(vector_field.vector_field_type)
        trajectory = [x_init]
        for step_idx in range(ts.shape[0] - 1):
            # Each step consumes the most recently produced state.
            trajectory.append(step(vector_field, trajectory[-1], zs, step_idx, ts))
        return torch.stack(trajectory)

    def get_sample_step_function(
        self, vector_field_type: VectorFieldType
    ) -> Callable[
        [VectorField, torch.Tensor, torch.Tensor, int, torch.Tensor], torch.Tensor
    ]:
        """
        Return the bound single-step rule for ``vector_field_type`` and the current mode.

        The rule is ``sample_step_<mode>_<kind>`` where ``<mode>`` is ``stochastic`` or
        ``deterministic`` (from ``self.is_stochastic``) and ``<kind>`` is one of
        ``score``, ``x0``, ``eps``, ``v``.

        Args:
            vector_field_type (VectorFieldType): The type of vector field (SCORE, X0, EPS, or V).

        Returns:
            Callable: A function ``(vector_field, x, zs, idx, ts) -> next_x``, where ``x`` and
            ``next_x`` have shape (N, *D), ``zs`` has shape (L-1, N, *D), ``idx`` is the
            step index and ``ts`` has shape (L,). Returns ``None`` if
            ``vector_field_type`` matches none of the four known types.
        """
        mode = "stochastic" if self.is_stochastic else "deterministic"
        candidates = (
            (VectorFieldType.SCORE, "score"),
            (VectorFieldType.X0, "x0"),
            (VectorFieldType.EPS, "eps"),
            (VectorFieldType.V, "v"),
        )
        for member, kind in candidates:
            if vector_field_type == member:
                return getattr(self, f"sample_step_{mode}_{kind}")
        return None

    def _fix_t_shape(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Broadcast a single-element time tensor across the batch dimension of ``x``.

        Args:
            x (torch.Tensor): Data tensor of shape (N, *D); only N is used.
            t (torch.Tensor): Time tensor holding exactly one element (e.g. shape () or (1, ..., 1)).

        Returns:
            torch.Tensor: A view of ``t`` expanded to shape (N,).
        """
        flat_t = t.view((1,))
        batch_size = x.shape[0]
        return flat_t.expand(batch_size)

    def sample_step_stochastic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One stochastic reverse step driven by a score model; must be overridden.

        Args:
            score (VectorField): The score vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors, of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError

    def sample_step_deterministic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One deterministic reverse step driven by a score model; must be overridden.

        Args:
            score (VectorField): The score vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors (unused by deterministic rules), of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError

    def sample_step_stochastic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One stochastic reverse step driven by an x0 (denoiser) model; must be overridden.

        Args:
            x0 (VectorField): The x0 vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors, of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError

    def sample_step_deterministic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One deterministic reverse step driven by an x0 (denoiser) model; must be overridden.

        Args:
            x0 (VectorField): The x0 vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors (unused by deterministic rules), of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError

    def sample_step_stochastic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One stochastic reverse step driven by a noise-prediction (eps) model; must be overridden.

        Args:
            eps (VectorField): The eps vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors, of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError

    def sample_step_deterministic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One deterministic reverse step driven by a noise-prediction (eps) model; must be overridden.

        Args:
            eps (VectorField): The eps vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors (unused by deterministic rules), of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError

    def sample_step_stochastic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One stochastic reverse step driven by a velocity (v) model; must be overridden.

        Args:
            v (VectorField): The velocity vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors, of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError

    def sample_step_deterministic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One deterministic reverse step driven by a velocity (v) model; must be overridden.

        Args:
            v (VectorField): The velocity vector field model.
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors (unused by deterministic rules), of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: Next state, of shape (N, *D).
        """
        raise NotImplementedError
class EulerMaruyamaSampler(Sampler):
    """
    Sampler that discretizes the reverse-time dynamics with Euler / Euler-Maruyama steps.

    With x_t = alpha_t * x0 + sigma_t * eps, write
    alpha_ratio_t = alpha_t' / alpha_t, sigma_ratio_t = sigma_t' / sigma_t and
    diff_ratio_t = sigma_ratio_t - alpha_ratio_t. The step rules below integrate

    - deterministic (probability-flow ODE):
        dx = (alpha_ratio_t * x - sigma_t^2 * diff_ratio_t * score) dt
    - stochastic (reverse-time SDE):
        dx = (alpha_ratio_t * x - 2 * sigma_t^2 * diff_ratio_t * score) dt
             + sqrt(2 * diff_ratio_t) * sigma_t dW

    Each vector-field type is substituted into these equations using
    score = -(x - alpha_t * x0) / sigma_t^2 = -eps / sigma_t. For V, the rules
    assume v = alpha_t' * x0 + sigma_t' * eps (consistent with the deterministic
    V step, whose drift is v itself), which gives sigma_t * diff_ratio_t * eps
    = v - alpha_ratio_t * x.

    ``ts`` is expected to be decreasing, so each ``dt`` is <= 0 (see
    ``_get_step_quantities``).
    """

    def _get_step_quantities(
        self, zs: torch.Tensor, idx: int, ts: torch.Tensor
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Calculate the time-dependent quantities needed for one sampling step.

        Args:
            zs (torch.Tensor): The noise tensors for stochastic sampling, of shape (L-1, N, *D)
                where L is the number of time steps, N is the batch size, and D represents the data dimensions
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            Tuple: A tuple containing, in order:
                - t (torch.Tensor): Current time, of shape (1*), where 1* has the same number of dimensions as (N, *D)
                - t1 (torch.Tensor): Next time, of shape (1*)
                - alpha_t (torch.Tensor): Alpha at current time, of shape (1*)
                - sigma_t (torch.Tensor): Sigma at current time, of shape (1*)
                - alpha_prime_t (torch.Tensor): Derivative of alpha at current time, of shape (1*)
                - sigma_prime_t (torch.Tensor): Derivative of sigma at current time, of shape (1*)
                - dt (torch.Tensor): Time difference t1 - t, of shape (1*)
                - dwt (torch.Tensor): Scaled noise zs[idx] * sqrt(-dt), of shape (N, *D)
                - alpha_ratio_t (torch.Tensor): alpha_prime_t / alpha_t, of shape (1*)
                - sigma_ratio_t (torch.Tensor): sigma_prime_t / sigma_t, of shape (1*)
                - diff_ratio_t (torch.Tensor): sigma_ratio_t - alpha_ratio_t, of shape (1*)
        """
        x_shape = zs.shape[1:]
        t = pad_shape_back(ts[idx], x_shape)
        t1 = pad_shape_back(ts[idx + 1], x_shape)
        dt = t1 - t
        # Reverse-time sampling walks ts downward, so dt <= 0 and -dt is the
        # (non-negative) step size that scales the Brownian increment.
        dwt = zs[idx] * torch.sqrt(-dt)

        alpha_t = pad_shape_back(self.diffusion_process.alpha(ts[idx]), x_shape)
        sigma_t = pad_shape_back(self.diffusion_process.sigma(ts[idx]), x_shape)
        alpha_prime_t = pad_shape_back(
            self.diffusion_process.alpha_prime(ts[idx]), x_shape
        )
        sigma_prime_t = pad_shape_back(
            self.diffusion_process.sigma_prime(ts[idx]), x_shape
        )
        alpha_ratio_t = alpha_prime_t / alpha_t
        sigma_ratio_t = sigma_prime_t / sigma_t
        diff_ratio_t = sigma_ratio_t - alpha_ratio_t
        return (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        )

    def sample_step_deterministic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler step of the probability-flow ODE using the score vector field.

        drift = alpha_ratio_t * x - sigma_t^2 * diff_ratio_t * score(x, t)

        Args:
            score (VectorField): The score vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = alpha_ratio_t * x - (sigma_t**2) * diff_ratio_t * score(
            x, self._fix_t_shape(x, t)
        )
        return x + drift_t * dt

    def sample_step_stochastic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler-Maruyama step of the reverse-time SDE using the score vector field.

        drift     = alpha_ratio_t * x - 2 * sigma_t^2 * diff_ratio_t * score(x, t)
        diffusion = sqrt(2 * diff_ratio_t) * sigma_t

        These are the same dynamics the x0/eps/v stochastic steps implement,
        expressed directly in terms of the score.

        Args:
            score (VectorField): The score vector field model
            x (torch.Tensor): The current state tensor, of shape (N, *D)
                where N is the batch size and D represents the data dimensions
            zs (torch.Tensor): The noise tensors for stochastic sampling, of shape (L-1, N, *D)
                where L is the number of time steps
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state tensor, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        # Query the model with a per-sample (N,) time, like every other step.
        score_t = score(x, self._fix_t_shape(x, t))
        drift_t = alpha_ratio_t * x - 2 * (sigma_t**2) * diff_ratio_t * score_t
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt

    def sample_step_deterministic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler step of the probability-flow ODE using the x0 vector field.

        Substituting score = -(x - alpha_t * x0) / sigma_t^2 gives
        drift = sigma_ratio_t * x - alpha_t * diff_ratio_t * x0(x, t).

        Args:
            x0 (VectorField): The x0 vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = sigma_ratio_t * x - alpha_t * diff_ratio_t * x0(
            x, self._fix_t_shape(x, t)
        )
        return x + drift_t * dt

    def sample_step_stochastic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler-Maruyama step of the reverse-time SDE using the x0 vector field.

        drift     = (alpha_ratio_t + 2 * diff_ratio_t) * x - 2 * alpha_t * diff_ratio_t * x0(x, t)
        diffusion = sqrt(2 * diff_ratio_t) * sigma_t

        Args:
            x0 (VectorField): The x0 vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = (
            alpha_ratio_t + 2 * diff_ratio_t
        ) * x - 2 * alpha_t * diff_ratio_t * x0(x, self._fix_t_shape(x, t))
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt

    def sample_step_deterministic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler step of the probability-flow ODE using the eps vector field.

        Substituting score = -eps / sigma_t gives
        drift = alpha_ratio_t * x + sigma_t * diff_ratio_t * eps(x, t).

        Args:
            eps (VectorField): The eps vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = alpha_ratio_t * x + sigma_t * diff_ratio_t * eps(
            x, self._fix_t_shape(x, t)
        )
        return x + drift_t * dt

    def sample_step_stochastic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler-Maruyama step of the reverse-time SDE using the eps vector field.

        drift     = alpha_ratio_t * x + 2 * sigma_t * diff_ratio_t * eps(x, t)
        diffusion = sqrt(2 * diff_ratio_t) * sigma_t

        Args:
            eps (VectorField): The eps vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = alpha_ratio_t * x + 2 * sigma_t * diff_ratio_t * eps(
            x, self._fix_t_shape(x, t)
        )
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt

    def sample_step_deterministic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler step of the probability-flow ODE using the v vector field.

        The probability-flow drift equals v(x, t) itself.

        Args:
            v (VectorField): The velocity vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = v(x, self._fix_t_shape(x, t))
        return x + drift_t * dt

    def sample_step_stochastic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        One Euler-Maruyama step of the reverse-time SDE using the v vector field.

        Using sigma_t * diff_ratio_t * eps = v - alpha_ratio_t * x, the SDE drift
        alpha_ratio_t * x + 2 * sigma_t * diff_ratio_t * eps becomes
        drift     = 2 * v(x, t) - alpha_ratio_t * x
        diffusion = sqrt(2 * diff_ratio_t) * sigma_t

        Args:
            v (VectorField): The velocity vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = 2 * v(x, self._fix_t_shape(x, t)) - alpha_ratio_t * x
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt
class DDMSampler(Sampler):
    """
    Class for sampling from diffusion models using the DDPM/DDIM sampler.

    Every step first converts the model output to an x0 estimate and then takes
    either a DDPM step (stochastic mode, samples the Gaussian posterior of
    x_{t1} given x_t and x0) or a DDIM step (deterministic mode, reuses the
    implied noise). The deterministic mode uses no noise.
    """

    def _convert_to_x0(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        fx: torch.Tensor,
        fx_type: VectorFieldType,
    ):
        """
        Convert a vector-field output ``fx`` of type ``fx_type`` into an x0 estimate.

        Args:
            x (torch.Tensor): Current state, of shape (N, *D).
            t (torch.Tensor): Current time, padded to broadcast against ``x``.
            fx (torch.Tensor): Model output at (x, t).
            fx_type (VectorFieldType): The type of ``fx`` (SCORE, EPS, V, ...).

        Returns:
            The x0 estimate produced by ``convert_vector_field_type``.
        """
        x0 = convert_vector_field_type(
            x,
            fx,
            self.diffusion_process.alpha(t),
            self.diffusion_process.sigma(t),
            self.diffusion_process.alpha_prime(t),
            self.diffusion_process.sigma_prime(t),
            fx_type,
            VectorFieldType.X0,
        )
        return x0

    def _ddpm_step_x0_tensor(
        self,
        x0: torch.Tensor,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sample x_{t1} from the Gaussian posterior given x_t = ``x`` and an x0 estimate.

        With a = alpha_t / alpha_t1 and r = sigma_t1 / sigma_t:
            mean = a * r^2 * x + alpha_t1 * (1 - a^2 * r^2) * x0
            std  = sigma_t1 * sqrt(1 - (a * r)^2)

        Args:
            x0 (torch.Tensor): Estimate of the clean data, of shape (N, *D).
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors, of shape (L-1, N, *D); ``zs[idx]`` is used.
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: The next state, of shape (N, *D).
        """
        t = pad_shape_back(ts[idx], x.shape)
        t1 = pad_shape_back(ts[idx + 1], x.shape)
        alpha_t = self.diffusion_process.alpha(t)
        sigma_t = self.diffusion_process.sigma(t)
        alpha_t1 = self.diffusion_process.alpha(t1)
        sigma_t1 = self.diffusion_process.sigma(t1)

        r11 = (alpha_t / alpha_t1) * (sigma_t1 / sigma_t)
        r12 = r11 * (sigma_t1 / sigma_t)
        r22 = r12 * (alpha_t / alpha_t1)

        mean = r12 * x + alpha_t1 * (1 - r22) * x0
        std = sigma_t1 * (1 - r11**2) ** (1 / 2)
        return mean + std * zs[idx]

    def _ddim_step_x0_tensor(
        self,
        x0: torch.Tensor,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Deterministic DDIM step: x_{t1} = alpha_t1 * x0 + sigma_t1 * eps_hat.

        Here eps_hat = (x - alpha_t * x0) / sigma_t, which expands to
        (sigma_t1 / sigma_t) * x + alpha_t1 * (1 - (alpha_t / alpha_t1) * (sigma_t1 / sigma_t)) * x0.

        Args:
            x0 (torch.Tensor): Estimate of the clean data, of shape (N, *D).
            x (torch.Tensor): Current state, of shape (N, *D).
            zs (torch.Tensor): Noise tensors (unused), of shape (L-1, N, *D).
            idx (int): Current step index.
            ts (torch.Tensor): Time steps, of shape (L,).

        Returns:
            torch.Tensor: The next state, of shape (N, *D).
        """
        t = pad_shape_back(ts[idx], x.shape)
        t1 = pad_shape_back(ts[idx + 1], x.shape)
        alpha_t = self.diffusion_process.alpha(t)
        sigma_t = self.diffusion_process.sigma(t)
        alpha_t1 = self.diffusion_process.alpha(t1)
        sigma_t1 = self.diffusion_process.sigma(t1)

        r01 = sigma_t1 / sigma_t
        r11 = (alpha_t / alpha_t1) * r01

        mean = r01 * x + alpha_t1 * (1 - r11) * x0
        return mean

    def sample_step_deterministic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDIM step using the x0 model's prediction directly."""
        x0_value = x0(x, self._fix_t_shape(x, ts[idx]))
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDPM step using the x0 model's prediction directly."""
        x0_value = x0(x, self._fix_t_shape(x, ts[idx]))
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_deterministic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDIM step after converting the score prediction to an x0 estimate."""
        t = pad_shape_back(ts[idx], x.shape)
        score_value = score(x, self._fix_t_shape(x, t))
        x0_value = self._convert_to_x0(x, t, score_value, VectorFieldType.SCORE)
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDPM step after converting the score prediction to an x0 estimate."""
        t = pad_shape_back(ts[idx], x.shape)
        score_value = score(x, self._fix_t_shape(x, t))
        x0_value = self._convert_to_x0(x, t, score_value, VectorFieldType.SCORE)
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_deterministic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDIM step after converting the eps prediction to an x0 estimate."""
        t = pad_shape_back(ts[idx], x.shape)
        eps_value = eps(x, self._fix_t_shape(x, t))
        x0_value = self._convert_to_x0(x, t, eps_value, VectorFieldType.EPS)
        # Deterministic mode must not inject zs[idx]: use DDIM, not DDPM.
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDPM step after converting the eps prediction to an x0 estimate."""
        t = pad_shape_back(ts[idx], x.shape)
        eps_value = eps(x, self._fix_t_shape(x, t))
        x0_value = self._convert_to_x0(x, t, eps_value, VectorFieldType.EPS)
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_deterministic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDIM step after converting the v prediction to an x0 estimate."""
        t = pad_shape_back(ts[idx], x.shape)
        v_value = v(x, self._fix_t_shape(x, t))
        x0_value = self._convert_to_x0(x, t, v_value, VectorFieldType.V)
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """DDPM step after converting the v prediction to an x0 estimate."""
        t = pad_shape_back(ts[idx], x.shape)
        v_value = v(x, self._fix_t_shape(x, t))
        x0_value = self._convert_to_x0(x, t, v_value, VectorFieldType.V)
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)