Coverage for src/diffusionlab/losses.py: 100%
40 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-21 15:33 -0700
1from typing import Callable
3import torch
4from torch import nn
6from diffusionlab.diffusions import DiffusionProcess
7from diffusionlab.utils import pad_shape_back
8from diffusionlab.vector_fields import VectorField, VectorFieldType
class SamplewiseDiffusionLoss(nn.Module):
    """
    Sample-wise loss function for training diffusion models.

    This class implements various loss functions for diffusion models based on the specified
    target type. For each sample, the loss is the squared error between the model's prediction
    and the target, summed over all non-batch dimensions (a squared L2 norm per sample, not a
    mean). The target depends on the chosen vector field type.

    The loss supports different target types:
    - X0: Learn to predict the original clean data x_0
    - EPS: Learn to predict the noise component eps
    - V: Learn to predict the velocity field v = alpha'(t) * x_0 + sigma'(t) * eps
    - SCORE: Not directly supported (raises ValueError)

    Attributes:
        diffusion_process (DiffusionProcess): The diffusion process defining the forward dynamics.
        target_type (VectorFieldType): The type of target to learn via minimizing the loss function.
        target (Callable): Function that computes the target based on the specified target_type.
            Takes tensors of shapes (N, *D) for x_t, f_x_t, x_0, eps and (N,) for t,
            and returns a tensor of shape (N, *D). It is a bound method of this module rather
            than a local closure, so the module can be pickled and deep-copied correctly.
    """

    def __init__(
        self, diffusion_process: DiffusionProcess, target_type: VectorFieldType
    ) -> None:
        """
        Initialize the diffusion loss function.

        Args:
            diffusion_process: The diffusion process to use, containing data about the forward evolution.
            target_type: The type of target to learn via minimizing the loss function.
                Must be one of VectorFieldType.X0, VectorFieldType.EPS, or VectorFieldType.V.

        Raises:
            ValueError: If target_type is VectorFieldType.SCORE, which is not directly supported,
                or if target_type is not a recognized vector field type.
        """
        super().__init__()
        self.diffusion_process: DiffusionProcess = diffusion_process
        self.target_type: VectorFieldType = target_type

        # Bound methods (instead of closures capturing ``self``) keep the module
        # picklable, and ``copy.deepcopy`` rebinds them to the copied module.
        if target_type == VectorFieldType.X0:
            target = self._target_x0
        elif target_type == VectorFieldType.EPS:
            target = self._target_eps
        elif target_type == VectorFieldType.V:
            target = self._target_v
        elif target_type == VectorFieldType.SCORE:
            raise ValueError(
                "Direct score matching is not supported due to lack of a known target function, and other ways (like Hutchinson's trace estimator) are very high variance."
            )
        else:
            # Without this branch an unrecognized value would leave ``target``
            # unbound and fail with an unhelpful UnboundLocalError.
            raise ValueError(
                f"Unsupported target_type {target_type!r}; expected one of "
                "VectorFieldType.X0, VectorFieldType.EPS, or VectorFieldType.V."
            )

        self.target: Callable[
            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
            torch.Tensor,
        ] = target

    def _target_x0(
        self,
        x_t: torch.Tensor,
        f_x_t: torch.Tensor,
        x_0: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Target function for predicting the original clean data x_0.

        Args:
            x_t (torch.Tensor): The noised data at time t, of shape (N, *D). Unused.
            f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D). Unused.
            x_0 (torch.Tensor): The original clean data, of shape (N, *D).
            eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D). Unused.
            t (torch.Tensor): The time parameter, of shape (N,). Unused.

        Returns:
            torch.Tensor: The target tensor x_0, of shape (N, *D).
        """
        return x_0

    def _target_eps(
        self,
        x_t: torch.Tensor,
        f_x_t: torch.Tensor,
        x_0: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Target function for predicting the noise component eps.

        Args:
            x_t (torch.Tensor): The noised data at time t, of shape (N, *D). Unused.
            f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D). Unused.
            x_0 (torch.Tensor): The original clean data, of shape (N, *D). Unused.
            eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
            t (torch.Tensor): The time parameter, of shape (N,). Unused.

        Returns:
            torch.Tensor: The target tensor eps, of shape (N, *D).
        """
        return eps

    def _target_v(
        self,
        x_t: torch.Tensor,
        f_x_t: torch.Tensor,
        x_0: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Target function for predicting the velocity field v.

        Computes alpha'(t) * x_0 + sigma'(t) * eps using the diffusion process's
        ``alpha_prime`` and ``sigma_prime``. (This is the time derivative of x_t
        presuming the forward process is x_t = alpha(t) * x_0 + sigma(t) * eps —
        confirm against DiffusionProcess.forward.)

        Args:
            x_t (torch.Tensor): The noised data at time t, of shape (N, *D). Unused.
            f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D). Unused.
            x_0 (torch.Tensor): The original clean data, of shape (N, *D).
            eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
            t (torch.Tensor): The time parameter, of shape (N,).

        Returns:
            torch.Tensor: The velocity field target tensor, of shape (N, *D).
        """
        # pad_shape_back presumably appends singleton dims so the per-sample (N,)
        # coefficients broadcast against x_0 / eps of shape (N, *D).
        alpha_prime_t = pad_shape_back(self.diffusion_process.alpha_prime(t), x_0.shape)
        sigma_prime_t = pad_shape_back(self.diffusion_process.sigma_prime(t), x_0.shape)
        return alpha_prime_t * x_0 + sigma_prime_t * eps

    def forward(
        self,
        x_t: torch.Tensor,
        f_x_t: torch.Tensor,
        x_0: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the loss for each sample in the batch.

        This method calculates the squared error between the model's prediction (f_x_t)
        and the target value determined by the target_type, summed over all non-batch
        dimensions.

        Args:
            x_t (torch.Tensor): The noised data at time t, of shape (N, *D) where N is the batch size
                and D represents the data dimensions.
            f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
            x_0 (torch.Tensor): The original clean data, of shape (N, *D).
            eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
            t (torch.Tensor): The time parameter, of shape (N,).

        Returns:
            torch.Tensor: The per-sample loss values, of shape (N,) where N is the batch size.
        """
        # Compute squared error between prediction and target
        squared_residuals = (f_x_t - self.target(x_t, f_x_t, x_0, eps, t)) ** 2

        # Sum over all dimensions except batch dimension
        samplewise_loss = torch.sum(
            torch.flatten(squared_residuals, start_dim=1, end_dim=-1), dim=1
        )

        return samplewise_loss

    def batchwise_loss_factory(
        self, N_noise_draws_per_sample: int
    ) -> Callable[
        [VectorField, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
    ]:
        """
        Create a batchwise loss function that averages the samplewise loss over multiple noise draws per sample.

        This factory method returns a function that can be used during training to compute the loss
        for a batch of data. The returned function handles the process of:
        1. Repeating each sample N times to apply different noise realizations
        2. Adding noise according to the diffusion process
        3. Computing model predictions
        4. Calculating and weighting the loss

        Args:
            N_noise_draws_per_sample (int): The number of different noise realizations to use
                for each data sample. Must be at least 1. Higher values can reduce variance but
                increase computation.

        Returns:
            Callable[[VectorField, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
                A function that computes the weighted average loss across a batch with the signature:
                (vector_field, data, timesteps, sample_weights) -> scalar_loss

        Raises:
            ValueError: If N_noise_draws_per_sample is less than 1.
        """
        # Zero draws would produce empty tensors and a silent NaN from torch.mean;
        # negative draws would only fail later inside repeat_interleave.
        if N_noise_draws_per_sample < 1:
            raise ValueError(
                f"N_noise_draws_per_sample must be at least 1, got {N_noise_draws_per_sample}."
            )

        def batchwise_loss(
            f: VectorField,
            x: torch.Tensor,
            t: torch.Tensor,
            sample_weights: torch.Tensor,
        ) -> torch.Tensor:
            """
            Compute the weighted average loss across a batch with multiple noise draws per sample.

            This function:
            1. Verifies the vector field type matches the target type
            2. Repeats each sample N_noise_draws_per_sample times to apply different noise realizations
            3. Adds noise to the data according to the diffusion process at time t
            4. Computes the model's predictions
            5. Calculates the per-sample loss and applies sample weights
            6. Returns the mean loss across all samples and noise draws

            Args:
                f (VectorField): The vector field model to evaluate, must match the target type
                    of this loss function.
                x (torch.Tensor): The clean input data, of shape (N, *D).
                t (torch.Tensor): The diffusion timesteps, of shape (N,).
                sample_weights (torch.Tensor): The importance weights for each sample in the batch,
                    of shape (N,). Used to prioritize certain samples in the loss.

            Returns:
                torch.Tensor: A scalar tensor containing the weighted average loss.

            Raises:
                AssertionError: If f.vector_field_type differs from this loss's target_type.
            """
            assert f.vector_field_type == self.target_type, (
                f"Vector field type {f.vector_field_type!r} does not match the loss "
                f"target type {self.target_type!r}."
            )
            # Each sample (and its time / weight) is repeated consecutively so that
            # row i*K + k is the k-th noise draw of sample i.
            x = torch.repeat_interleave(x, N_noise_draws_per_sample, dim=0)
            t = torch.repeat_interleave(t, N_noise_draws_per_sample, dim=0)
            sample_weights = torch.repeat_interleave(
                sample_weights, N_noise_draws_per_sample, dim=0
            )

            eps = torch.randn_like(x)
            xt = self.diffusion_process.forward(x, t, eps)
            fxt = f(xt, t)

            samplewise_loss = self(xt, fxt, x, eps, t)
            mean_loss = torch.mean(samplewise_loss * sample_weights)
            return mean_loss

        return batchwise_loss