Coverage for src/diffusionlab/diffusions.py: 100%
27 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 21:37 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 21:37 -0700
1from typing import Any, Callable
2import torch
4from diffusionlab.utils import pad_shape_back, scalar_derivative
class DiffusionProcess:
    """
    Shared base for diffusion processes driven by two time-dependent schedules.

    A process is described by a pair of callables:
        - alpha(t): coefficient applied to the clean signal at time t
        - sigma(t): coefficient applied to the noise at time t

    Noised samples are produced as x_t = alpha(t) * x_0 + sigma(t) * eps, with:
        - x_0 the clean data
        - eps the noise sample (expected to be standard Gaussian)
        - t the diffusion time, typically taken in [0, 1]

    Attributes:
        alpha (Callable): Signal schedule; expected to be differentiable and to map a
            tensor to a tensor of the same shape.
        sigma (Callable): Noise schedule; same expectations as ``alpha``.
        alpha_prime (Callable): Result of ``scalar_derivative(alpha)``.
        sigma_prime (Callable): Result of ``scalar_derivative(sigma)``.
    """

    def __init__(self, **dynamics_hparams: Any) -> None:
        """
        Store the schedules and build their derivatives.

        Args:
            **dynamics_hparams: Keyword arguments describing the dynamics. Two keys
                are required:
                    - alpha: Callable mapping time t to the signal coefficient
                    - sigma: Callable mapping time t to the noise coefficient

        Raises:
            AssertionError: If ``alpha`` or ``sigma`` is absent from ``dynamics_hparams``.
        """
        super().__init__()

        # Both schedules are mandatory, checked in the order alpha, then sigma.
        # NOTE: these are plain asserts, so the check is skipped under ``python -O``.
        for required_key in ("alpha", "sigma"):
            assert required_key in dynamics_hparams

        self.alpha: Callable[[torch.Tensor], torch.Tensor] = dynamics_hparams["alpha"]
        self.sigma: Callable[[torch.Tensor], torch.Tensor] = dynamics_hparams["sigma"]

        # Derivatives are built once here from the stored schedules.
        self.alpha_prime: Callable[[torch.Tensor], torch.Tensor] = scalar_derivative(
            self.alpha
        )
        self.sigma_prime: Callable[[torch.Tensor], torch.Tensor] = scalar_derivative(
            self.sigma
        )

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, eps: torch.Tensor
    ) -> torch.Tensor:
        """
        Noise ``x`` to time ``t`` using the supplied noise sample ``eps``.

        Args:
            x (torch.Tensor): Clean data, expected shape (N, *D) with N the batch size.
            t (torch.Tensor): Diffusion times, expected shape (N,) or broadcastable
                to ``x``; values typically lie in [0, 1].
            eps (torch.Tensor): Noise sample, expected to have the same shape as ``x``.

        Returns:
            torch.Tensor: ``alpha(t) * x + sigma(t) * eps``.
        """
        # pad_shape_back is handed x.shape; presumably it reshapes each coefficient
        # so it broadcasts against x (helper lives in diffusionlab.utils).
        signal_scale = pad_shape_back(self.alpha(t), x.shape)
        noise_scale = pad_shape_back(self.sigma(t), x.shape)
        return signal_scale * x + noise_scale * eps
class VarianceExplodingProcess(DiffusionProcess):
    """
    Variance Exploding (VE) diffusion process.

    The signal coefficient is fixed at alpha(t) = 1, and the noise coefficient is the
    caller-supplied sigma(t). The forward process therefore reads:

        x_t = x_0 + sigma(t) * eps

    The variance grows with sigma(t) while the signal is never attenuated. This is
    the formulation used by NCSN (Noise Conditional Score Network) and Score SDE
    style models.
    """

    def __init__(self, sigma: Callable[[torch.Tensor], torch.Tensor]) -> None:
        """
        Build a VE process from a noise schedule.

        Args:
            sigma (Callable): Noise schedule. Expected to map a tensor of times of
                shape (N,) to noise coefficients of the same shape.
        """

        def constant_one(t: torch.Tensor) -> torch.Tensor:
            # ones_like keeps the shape, dtype and device of the time tensor.
            return torch.ones_like(t)

        super().__init__(alpha=constant_one, sigma=sigma)
class OrnsteinUhlenbeckProcess(DiffusionProcess):
    """
    Ornstein-Uhlenbeck diffusion process with fixed schedules.

    The schedules are hard-wired to:
        - alpha(t) = sqrt(1 - t²)
        - sigma(t) = t

    With these choices alpha(t)² + sigma(t)² = 1 for every t, so the signal fades
    out as the noise fades in. alpha(t) is only real-valued for |t| <= 1.
    """

    def __init__(self) -> None:
        """
        Build the process; it takes no arguments because both schedules are fixed.

        The schedules are:
            - alpha(t) = sqrt(1 - t²)
            - sigma(t) = t

        Each one is applied elementwise, so a tensor of shape (N,) maps to a tensor
        of shape (N,).
        """

        def signal_schedule(t: torch.Tensor) -> torch.Tensor:
            return torch.sqrt(1 - t**2)

        def noise_schedule(t: torch.Tensor) -> torch.Tensor:
            return t

        super().__init__(alpha=signal_schedule, sigma=noise_schedule)
class FlowMatchingProcess(DiffusionProcess):
    """
    Flow Matching process with fixed linear schedules.

    The schedules are hard-wired to:
        - alpha(t) = 1 - t
        - sigma(t) = t

    The forward process is therefore the straight-line interpolation
    x_t = (1 - t) * x_0 + t * eps: pure data at t = 0 and pure noise at t = 1.
    This is the interpolation typically used when training flow-based generative
    models.
    """

    def __init__(self) -> None:
        """
        Build the process; it takes no arguments because both schedules are fixed.

        The schedules are:
            - alpha(t) = 1 - t
            - sigma(t) = t

        Each one is applied elementwise, so a tensor of shape (N,) maps to a tensor
        of shape (N,).
        """

        def signal_schedule(t: torch.Tensor) -> torch.Tensor:
            return 1 - t

        def noise_schedule(t: torch.Tensor) -> torch.Tensor:
            return t

        super().__init__(alpha=signal_schedule, sigma=noise_schedule)