Coverage for src/diffusionlab/models.py: 100% (85 statements)
from typing import Any, Callable, Dict, Tuple
import torch
from torch import nn, optim
from lightning import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRScheduler

from diffusionlab.losses import SamplewiseDiffusionLoss
from diffusionlab.diffusions import DiffusionProcess
from diffusionlab.schedulers import Scheduler
from diffusionlab.vector_fields import VectorField, VectorFieldType


class DiffusionModel(LightningModule, VectorField):
    """
    A PyTorch Lightning module for training and evaluating diffusion models.

    This class implements a diffusion model that can be trained using various vector field types
    (score, x0, eps, v) and diffusion processes. It handles the training loop, loss computation,
    and evaluation metrics.

    The model inherits from both LightningModule (for training) and VectorField (for sampling),
    making it compatible with both the Lightning training framework and the diffusion sampling
    algorithms.

    Attributes:
        net (nn.Module): The neural network that predicts the vector field.
        vector_field_type (VectorFieldType): The type of vector field the model predicts.
        diffusion_process (DiffusionProcess): The diffusion process used for training.
        train_scheduler (Scheduler): The scheduler for generating training time steps.
        optimizer (optim.Optimizer): The optimizer for training the model.
        lr_scheduler (optim.lr_scheduler.LRScheduler): The learning rate scheduler.
        batchwise_metrics (nn.ModuleDict): Metrics computed on each batch during validation.
        batchfree_metrics (nn.ModuleDict): Metrics computed at the end of each validation epoch.
        t_loss_weights (Callable): Function that weights the loss at different time steps.
        t_loss_probs (Callable): Function that determines the sampling probability of each time step.
        N_noise_draws_per_sample (int): Number of noise draws per data point.
        samplewise_loss (SamplewiseDiffusionLoss): Loss function applied to each sample.
        batchwise_loss (Callable): Factory-generated function that computes the loss for a batch.
        train_ts (torch.Tensor): Precomputed time steps for training.
        train_ts_loss_weights (torch.Tensor): Precomputed loss weights for each time step.
        train_ts_loss_probs (torch.Tensor): Precomputed sampling probabilities for each time step.
        LOG_ON_STEP_TRAIN_LOSS (bool): Whether to log training loss on each step. Default is True.
        LOG_ON_EPOCH_TRAIN_LOSS (bool): Whether to log training loss on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_TRAIN_LOSS (bool): Whether to display training loss on the progress bar. Default is True.
        LOG_ON_STEP_BATCHWISE_METRICS (bool): Whether to log batchwise metrics on each step. Default is False.
        LOG_ON_EPOCH_BATCHWISE_METRICS (bool): Whether to log batchwise metrics on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS (bool): Whether to display batchwise metrics on the progress bar. Default is False.
        LOG_ON_STEP_BATCHFREE_METRICS (bool): Whether to log batchfree metrics on each step. Default is False.
        LOG_ON_EPOCH_BATCHFREE_METRICS (bool): Whether to log batchfree metrics on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS (bool): Whether to display batchfree metrics on the progress bar. Default is False.
    """

    LOG_ON_STEP_TRAIN_LOSS = True
    LOG_ON_EPOCH_TRAIN_LOSS = True
    LOG_ON_PROGRESS_BAR_TRAIN_LOSS = True

    LOG_ON_STEP_BATCHWISE_METRICS = False
    LOG_ON_EPOCH_BATCHWISE_METRICS = True
    LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS = False

    LOG_ON_STEP_BATCHFREE_METRICS = False
    LOG_ON_EPOCH_BATCHFREE_METRICS = True
    LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS = False
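
    # These class-level flags can be overridden in a subclass (or reassigned on an
    # instance) to change the logging behavior without editing the step methods.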

    def __init__(
        self,
        net: nn.Module,
        diffusion_process: DiffusionProcess,
        train_scheduler: Scheduler,
        vector_field_type: VectorFieldType,
        optimizer: optim.Optimizer,
        lr_scheduler: optim.lr_scheduler.LRScheduler,
        batchwise_metrics: Dict[str, nn.Module],
        batchfree_metrics: Dict[str, nn.Module],
        train_ts_hparams: Dict[str, Any],
        t_loss_weights: Callable[[torch.Tensor], torch.Tensor],
        t_loss_probs: Callable[[torch.Tensor], torch.Tensor],
        N_noise_draws_per_sample: int,
    ):
        """
        Initialize the diffusion model.

        Args:
            net (nn.Module): Neural network that predicts the vector field.
            diffusion_process (DiffusionProcess): The diffusion process used for training.
            train_scheduler (Scheduler): Scheduler for generating training time steps.
            vector_field_type (VectorFieldType): Type of vector field the model predicts.
            optimizer (optim.Optimizer): Optimizer for training the model.
            lr_scheduler (optim.lr_scheduler.LRScheduler): Learning rate scheduler.
            batchwise_metrics (Dict[str, nn.Module]): Metrics computed on each batch during
                validation. Each metric takes (x, metadata, model) and returns a dictionary
                of metric (name, value) pairs.
            batchfree_metrics (Dict[str, nn.Module]): Metrics computed at the end of each
                validation epoch. Each metric takes (model) and returns a dictionary of
                metric (name, value) pairs.
            train_ts_hparams (Dict[str, Any]): Parameters for the training time step scheduler.
            t_loss_weights (Callable[[torch.Tensor], torch.Tensor]): Function that weights the
                loss at different time steps.
            t_loss_probs (Callable[[torch.Tensor], torch.Tensor]): Function that determines the
                sampling probability of each time step.
            N_noise_draws_per_sample (int): Number of noise draws per data point.
        """
        super().__init__()
        # Initialize VectorField with a forward function bound to the current instance
        VectorField.__init__(self, self.forward, vector_field_type)

        self.net: nn.Module = net
        self.vector_field_type: VectorFieldType = vector_field_type
        self.diffusion_process: DiffusionProcess = diffusion_process
        self.train_scheduler: Scheduler = train_scheduler
        self.optimizer: optim.Optimizer = optimizer
        self.lr_scheduler: optim.lr_scheduler.LRScheduler = lr_scheduler
        self.batchwise_metrics: nn.ModuleDict = nn.ModuleDict(batchwise_metrics)
        self.batchfree_metrics: nn.ModuleDict = nn.ModuleDict(batchfree_metrics)

        self.t_loss_weights: Callable[[torch.Tensor], torch.Tensor] = t_loss_weights
        self.t_loss_probs: Callable[[torch.Tensor], torch.Tensor] = t_loss_probs
        self.N_noise_draws_per_sample: int = N_noise_draws_per_sample

        # Create the samplewise loss function
        self.samplewise_loss: SamplewiseDiffusionLoss = SamplewiseDiffusionLoss(
            diffusion_process, vector_field_type
        )

        # Create the batchwise loss function using the factory method
        self.batchwise_loss = self.samplewise_loss.batchwise_loss_factory(
            N_noise_draws_per_sample=N_noise_draws_per_sample
        )

        # The schedule tensors are registered as (initially empty) buffers so they
        # move with the module across devices; they are filled in immediately below.
        self.register_buffer("train_ts", torch.zeros((0,)))
        self.register_buffer("train_ts_loss_weights", torch.zeros((0,)))
        self.register_buffer("train_ts_loss_probs", torch.zeros((0,)))
        self.precompute_train_schedule(train_ts_hparams)
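
    # A minimal construction sketch (illustrative: `MyNet`, the concrete
    # `DiffusionProcess`/`Scheduler` instances, the hyperparameter values, and the
    # uniform weighting below are assumptions, not part of this module;
    # VectorFieldType is assumed to expose members matching the documented
    # types, e.g. EPS):
    #
    #   net = MyNet(...)
    #   opt = optim.Adam(net.parameters(), lr=1e-4)
    #   model = DiffusionModel(
    #       net=net,
    #       diffusion_process=my_process,    # a DiffusionProcess instance
    #       train_scheduler=my_scheduler,    # a Scheduler instance
    #       vector_field_type=VectorFieldType.EPS,
    #       optimizer=opt,
    #       lr_scheduler=optim.lr_scheduler.ConstantLR(opt),
    #       batchwise_metrics={},
    #       batchfree_metrics={},
    #       train_ts_hparams={"t_min": 1e-3, "t_max": 1.0, "L": 1000},
    #       t_loss_weights=torch.ones_like,
    #       t_loss_probs=lambda ts: torch.ones_like(ts) / ts.shape[0],
    #       N_noise_draws_per_sample=1,
    #   )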

    def precompute_train_schedule(self, train_ts_hparams: Dict[str, Any]) -> None:
        """
        Precompute the time steps and their associated weights for training.

        This method generates the time steps used during training and computes
        the loss weights and sampling probabilities for each time step.

        Args:
            train_ts_hparams (Dict[str, Any]): Parameters for the training time step
                scheduler. Typically includes t_min, t_max, and the number of steps L.
        """
        self.train_ts = self.train_scheduler.get_ts(**train_ts_hparams).to(
            self.device, non_blocking=True
        )
        self.train_ts_loss_weights: torch.Tensor = self.t_loss_weights(self.train_ts)
        self.train_ts_loss_probs: torch.Tensor = self.t_loss_probs(self.train_ts)
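
    # The schedule can also be recomputed after construction, e.g. to narrow the
    # time range mid-training (hypothetical values; the exact keys depend on the
    # Scheduler subclass in use):
    #
    #   model.precompute_train_schedule({"t_min": 1e-2, "t_max": 0.5, "L": 500})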

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Passes the input through the neural network to predict the vector field.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, *data_dims).
            t (torch.Tensor): Time tensor of shape (batch_size,).

        Returns:
            torch.Tensor: Predicted vector field of shape (batch_size, *data_dims).
        """
        return self.net(x, t)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """
        Configure optimizers and learning rate schedulers for training.

        This method is called by PyTorch Lightning to set up the optimization process.

        Returns:
            OptimizerLRScheduler: Dictionary containing the optimizer and learning rate scheduler.
        """
        return {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler}

    def loss(
        self, x: torch.Tensor, t: torch.Tensor, sample_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the loss for a batch of data at specified time steps.

        Uses the batchwise_loss function created by the SamplewiseDiffusionLoss factory
        to compute the loss for the batch.

        Args:
            x (torch.Tensor): Input data of shape (batch_size, *data_dims).
            t (torch.Tensor): Time steps of shape (batch_size,).
            sample_weights (torch.Tensor): Weights for each sample, of shape (batch_size,).

        Returns:
            torch.Tensor: Scalar loss value.
        """
        return self.batchwise_loss(self, x, t, sample_weights)

    def aggregate_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss for a batch of data with randomly sampled time steps.

        This method:
        1. Samples time steps according to the training distribution.
        2. Computes the weighted loss at those time steps.

        Args:
            x (torch.Tensor): Input data of shape (batch_size, *data_dims).

        Returns:
            torch.Tensor: Scalar loss value.
        """
        # Draw one time-step index per sample, with replacement, according to
        # the precomputed sampling probabilities.
        t_idx = torch.multinomial(
            self.train_ts_loss_probs, x.shape[0], replacement=True
        ).to(self.device, non_blocking=True)
        t = self.train_ts[t_idx]
        t_weights = self.train_ts_loss_weights[t_idx]
        mean_loss = self.loss(x, t, t_weights)
        return mean_loss
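
    # To illustrate the sampling step above (hypothetical values): with
    # train_ts_loss_probs = tensor([0.5, 0.25, 0.25]) and a batch of four samples,
    # torch.multinomial draws four indices with replacement, e.g. tensor([0, 0, 2, 1]),
    # so each sample gets its own time step train_ts[t_idx] and loss weight
    # train_ts_loss_weights[t_idx].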

    def training_step(
        self, batch: Tuple[torch.Tensor, Any], batch_idx: int
    ) -> torch.Tensor:
        """
        Perform a single training step.

        This method is called by PyTorch Lightning during training.

        Args:
            batch (Tuple[torch.Tensor, Any]): Batch of data as a tuple (x, metadata).
            batch_idx (int): Index of the current batch.

        Returns:
            torch.Tensor: Loss value for the batch.
        """
        x, metadata = batch
        loss = self.aggregate_loss(x)
        self.log(
            "train_loss",
            loss,
            on_step=self.LOG_ON_STEP_TRAIN_LOSS,
            on_epoch=self.LOG_ON_EPOCH_TRAIN_LOSS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_TRAIN_LOSS,
        )
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, Any], batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        """
        Perform a single validation step.

        This method is called by PyTorch Lightning during validation.
        It computes the loss and any batchwise metrics.

        Args:
            batch (Tuple[torch.Tensor, Any]): Batch of data as a tuple (x, metadata).
            batch_idx (int): Index of the current batch.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of metric values.
        """
        x, metadata = batch
        loss = self.aggregate_loss(x)
        metric_values = {"val_loss": loss}
        for metric_name, metric in self.batchwise_metrics.items():
            metric_values_dict = metric(x, metadata, self)
            for key, value in metric_values_dict.items():
                metric_label = self._get_metric_label(metric_name, key)
                metric_values[metric_label] = value
        self.log_dict(
            metric_values,
            on_step=self.LOG_ON_STEP_BATCHWISE_METRICS,
            on_epoch=self.LOG_ON_EPOCH_BATCHWISE_METRICS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS,
        )
        return metric_values
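
    # A toy batchwise metric sketch (hypothetical; a real metric would compute
    # something meaningful from x, metadata, and the model). Its (name, value)
    # pairs are relabeled via _get_metric_label before logging:
    #
    #   class ToyBatchwiseMetric(nn.Module):
    #       def forward(self, x, metadata, model):
    #           return {"mean_abs": x.abs().mean()}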

    def on_validation_epoch_end(self) -> None:
        """
        Perform operations at the end of a validation epoch.

        This method is called by PyTorch Lightning at the end of each validation epoch.
        It computes and logs any batchfree metrics, i.e. those that depend on the model
        as a whole rather than on a single batch.
        """
        metric_values = {}
        for metric_name, metric in self.batchfree_metrics.items():
            metric_values_dict = metric(self)
            for key, value in metric_values_dict.items():
                metric_label = self._get_metric_label(metric_name, key)
                metric_values[metric_label] = value
        self.log_dict(
            metric_values,
            on_step=self.LOG_ON_STEP_BATCHFREE_METRICS,
            on_epoch=self.LOG_ON_EPOCH_BATCHFREE_METRICS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS,
        )
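
    # A matching batchfree metric sketch (hypothetical): it receives only the
    # model, e.g. for metrics computed on freshly drawn samples or model state:
    #
    #   class ToyBatchfreeMetric(nn.Module):
    #       def forward(self, model):
    #           return {"train_ts_max": model.train_ts.max()}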

    def _get_metric_label(self, metric_name: str, key: str) -> str:
        """
        Build the label under which a metric's value is logged.

        The metric name and key are joined with an underscore when both are non-empty;
        if either is empty, the other is used on its own. For example,
        ("fid", "score") -> "fid_score", ("fid", "") -> "fid", ("", "score") -> "score".

        Args:
            metric_name (str): The name of the metric.
            key (str): The key of the value within the metric's output dictionary.

        Returns:
            str: The label for the metric's value.
        """
        metric_name, key = metric_name.strip(), key.strip()
        if len(metric_name) > 0 and len(key) > 0:
            metric_label = f"{metric_name}_{key}"
        else:
            metric_label = f"{metric_name}{key}"
        return metric_label
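

# An end-to-end training sketch (illustrative; the dataloaders and the Trainer
# arguments are assumptions, not part of this module):
#
#   import lightning as L
#
#   trainer = L.Trainer(max_epochs=100)
#   trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)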