
Models

This module contains functionality related to models.

DiffusionModel

Bases: LightningModule, VectorField

A PyTorch Lightning module for training and evaluating diffusion models.

This class implements a diffusion model that can be trained with any of several vector field types (score, x0, eps, v) and diffusion processes. It implements the Lightning training and validation steps, the diffusion loss, and the logging of evaluation metrics.

The model inherits from both LightningModule (for training) and VectorField (for sampling), making it compatible with both the Lightning training framework and the diffusion sampling algorithms.
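The four vector field types differ only in what the network regresses onto. This page does not spell out the convention, so the following assumes the common Gaussian one: the forward process gives x_t = α_t x_0 + σ_t ε with ε ~ N(0, I), and v is taken to be the velocity dx_t/dt (some papers define v differently). Under those assumptions the targets relate as follows, with the score expressed through Tweedie's formula:

```latex
\begin{aligned}
x_t &= \alpha_t x_0 + \sigma_t \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I) \\
\text{x0:}\quad f(x_t, t) &\approx \mathbb{E}[x_0 \mid x_t] \\
\text{eps:}\quad f(x_t, t) &\approx \mathbb{E}[\varepsilon \mid x_t] \\
\text{v:}\quad f(x_t, t) &\approx \mathbb{E}[\alpha_t' x_0 + \sigma_t' \varepsilon \mid x_t] \\
\text{score:}\quad f(x_t, t) &\approx \nabla_{x_t} \log p_t(x_t)
  = \frac{\alpha_t \mathbb{E}[x_0 \mid x_t] - x_t}{\sigma_t^2}
  = -\frac{\mathbb{E}[\varepsilon \mid x_t]}{\sigma_t}
\end{aligned}
```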

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| net | Module | The neural network that predicts the vector field. |
| vector_field_type | VectorFieldType | The type of vector field the model predicts. |
| diffusion_process | DiffusionProcess | The diffusion process used for training. |
| train_scheduler | Scheduler | The scheduler that generates the training time steps. |
| optimizer | Optimizer | The optimizer for training the model. |
| lr_scheduler | LRScheduler | The learning rate scheduler. |
| batchwise_metrics | ModuleDict | Metrics computed on each batch during validation. |
| batchfree_metrics | ModuleDict | Metrics computed once at the end of each validation epoch. |
| t_loss_weights | Callable | Function that weights the loss at different time steps. |
| t_loss_probs | Callable | Function that determines the sampling probability of each time step. |
| N_noise_draws_per_sample | int | Number of independent noise draws per data point. |
| samplewise_loss | SamplewiseDiffusionLoss | Loss function evaluated on each sample. |
| train_ts | Tensor | Precomputed time steps for training. |
| train_ts_loss_weights | Tensor | Precomputed loss weight for each time step. |
| train_ts_loss_probs | Tensor | Precomputed sampling probability for each time step. |
| LOG_ON_STEP_TRAIN_LOSS | bool | Whether to log the training loss on each step. Default is True. |
| LOG_ON_EPOCH_TRAIN_LOSS | bool | Whether to log the training loss on each epoch. Default is True. |
| LOG_ON_PROGRESS_BAR_TRAIN_LOSS | bool | Whether to display the training loss on the progress bar. Default is True. |
| LOG_ON_STEP_BATCHWISE_METRICS | bool | Whether to log batchwise metrics on each step. Default is False. |
| LOG_ON_EPOCH_BATCHWISE_METRICS | bool | Whether to log batchwise metrics on each epoch. Default is True. |
| LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS | bool | Whether to display batchwise metrics on the progress bar. Default is False. |
| LOG_ON_STEP_BATCHFREE_METRICS | bool | Whether to log batchfree metrics on each step. Default is False. |
| LOG_ON_EPOCH_BATCHFREE_METRICS | bool | Whether to log batchfree metrics on each epoch. Default is True. |
| LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS | bool | Whether to display batchfree metrics on the progress bar. Default is False. |

Source code in src/diffusionlab/models.py
class DiffusionModel(LightningModule, VectorField):
    """
    A PyTorch Lightning module for training and evaluating diffusion models.

    This class implements a diffusion model that can be trained using various vector field types
    (score, x0, eps, v) and diffusion processes. It handles the training loop, loss computation,
    and evaluation metrics.

    The model inherits from both LightningModule (for training) and VectorField (for sampling),
    making it compatible with both the Lightning training framework and the diffusion sampling
    algorithms.

    Attributes:
        net (nn.Module): The neural network that predicts the vector field.
        vector_field_type (VectorFieldType): The type of vector field the model predicts.
        diffusion_process (DiffusionProcess): The diffusion process used for training.
        train_scheduler (Scheduler): The scheduler for generating training time steps.
        optimizer (optim.Optimizer): The optimizer for training the model.
        lr_scheduler (optim.lr_scheduler.LRScheduler): The learning rate scheduler.
        batchwise_metrics (nn.ModuleDict): Metrics computed on each batch during validation.
        batchfree_metrics (nn.ModuleDict): Metrics computed at the end of validation epoch.
        t_loss_weights (Callable): Function that weights loss at different time steps.
        t_loss_probs (Callable): Function that determines sampling probability of time steps.
        N_noise_draws_per_sample (int): Number of noise draws per data point.
        samplewise_loss (SamplewiseDiffusionLoss): Loss function for each sample.
        train_ts (torch.Tensor): Precomputed time steps for training.
        train_ts_loss_weights (torch.Tensor): Precomputed weights for each time step.
        train_ts_loss_probs (torch.Tensor): Precomputed sampling probabilities for each time step.
        LOG_ON_STEP_TRAIN_LOSS (bool): Whether to log training loss on each step. Default is True.
        LOG_ON_EPOCH_TRAIN_LOSS (bool): Whether to log training loss on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_TRAIN_LOSS (bool): Whether to display training loss on the progress bar. Default is True.
        LOG_ON_STEP_BATCHWISE_METRICS (bool): Whether to log batchwise metrics on each step. Default is False.
        LOG_ON_EPOCH_BATCHWISE_METRICS (bool): Whether to log batchwise metrics on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS (bool): Whether to display batchwise metrics on the progress bar. Default is False.
        LOG_ON_STEP_BATCHFREE_METRICS (bool): Whether to log batchfree metrics on each step. Default is False.
        LOG_ON_EPOCH_BATCHFREE_METRICS (bool): Whether to log batchfree metrics on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS (bool): Whether to display batchfree metrics on the progress bar. Default is False.
    """

    LOG_ON_STEP_TRAIN_LOSS = True
    LOG_ON_EPOCH_TRAIN_LOSS = True
    LOG_ON_PROGRESS_BAR_TRAIN_LOSS = True

    LOG_ON_STEP_BATCHWISE_METRICS = False
    LOG_ON_EPOCH_BATCHWISE_METRICS = True
    LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS = False

    LOG_ON_STEP_BATCHFREE_METRICS = False
    LOG_ON_EPOCH_BATCHFREE_METRICS = True
    LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS = False

    def __init__(
        self,
        net: nn.Module,
        diffusion_process: DiffusionProcess,
        train_scheduler: Scheduler,
        vector_field_type: VectorFieldType,
        optimizer: optim.Optimizer,
        lr_scheduler: optim.lr_scheduler.LRScheduler,
        batchwise_metrics: Dict[str, nn.Module],
        batchfree_metrics: Dict[str, nn.Module],
        train_ts_hparams: Dict[str, Any],
        t_loss_weights: Callable[[torch.Tensor], torch.Tensor],
        t_loss_probs: Callable[[torch.Tensor], torch.Tensor],
        N_noise_draws_per_sample: int,
    ):
        """
        Initialize the diffusion model.

        Args:
            net (nn.Module): Neural network that predicts the vector field.
            diffusion_process (DiffusionProcess): The diffusion process used for training.
            train_scheduler (Scheduler): Scheduler for generating training time steps.
            vector_field_type (VectorFieldType): Type of vector field the model predicts.
            optimizer (optim.Optimizer): Optimizer for training the model.
            lr_scheduler (optim.lr_scheduler.LRScheduler): Learning rate scheduler.
            batchwise_metrics (Dict[str, nn.Module]): Metrics computed on each batch during validation. Each metric takes in (x, metadata, model) and returns a dictionary of metric (name, value) pairs.
            batchfree_metrics (Dict[str, nn.Module]): Metrics computed at the end of validation epoch. Each metric takes in (model) and returns a dictionary of metric (name, value) pairs.
            train_ts_hparams (Dict[str, Any]): Parameters for the training time step scheduler.
            t_loss_weights (Callable[[torch.Tensor], torch.Tensor]): Function that weights loss at different time steps.
            t_loss_probs (Callable[[torch.Tensor], torch.Tensor]): Function that determines sampling probability of time steps.
            N_noise_draws_per_sample (int): Number of noise draws per data point.
        """
        super().__init__()
        self.net: nn.Module = net
        self.vector_field_type: VectorFieldType = vector_field_type
        self.diffusion_process: DiffusionProcess = diffusion_process
        self.train_scheduler: Scheduler = train_scheduler
        self.optimizer: optim.Optimizer = optimizer
        self.lr_scheduler: optim.lr_scheduler.LRScheduler = lr_scheduler
        self.batchwise_metrics: nn.ModuleDict = nn.ModuleDict(batchwise_metrics)
        self.batchfree_metrics: nn.ModuleDict = nn.ModuleDict(batchfree_metrics)

        self.t_loss_weights: Callable[[torch.Tensor], torch.Tensor] = t_loss_weights
        self.t_loss_probs: Callable[[torch.Tensor], torch.Tensor] = t_loss_probs
        self.N_noise_draws_per_sample: int = N_noise_draws_per_sample

        self.samplewise_loss: SamplewiseDiffusionLoss = SamplewiseDiffusionLoss(
            diffusion_process, vector_field_type
        )

        self.register_buffer("train_ts", torch.zeros((0,)))
        self.register_buffer("train_ts_loss_weights", torch.zeros((0,)))
        self.register_buffer("train_ts_loss_probs", torch.zeros((0,)))
        self.precompute_train_schedule(train_ts_hparams)

    def precompute_train_schedule(self, train_ts_hparams: Dict[str, float]) -> None:
        """
        Precompute time steps and their associated weights for training.

        This method generates the time steps used during training and computes
        the loss weights and sampling probabilities for each time step.

        Args:
            train_ts_hparams (Dict[str, float]): Parameters for the training time step scheduler.
                Typically includes t_min, t_max, and the number of steps L.
        """
        self.train_ts = self.train_scheduler.get_ts(**train_ts_hparams).to(
            self.device, non_blocking=True
        )
        self.train_ts_loss_weights: torch.Tensor = self.t_loss_weights(self.train_ts)
        self.train_ts_loss_probs: torch.Tensor = self.t_loss_probs(self.train_ts)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Passes the input through the neural network to predict the vector field.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, *data_dims).
            t (torch.Tensor): Time tensor of shape (batch_size,).

        Returns:
            torch.Tensor: Predicted vector field of shape (batch_size, *data_dims).
        """
        return self.net(x, t)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """
        Configure optimizers and learning rate schedulers for training.

        This method is called by PyTorch Lightning to set up the optimization process.

        Returns:
            OptimizerLRScheduler: Dictionary containing the optimizer and learning rate scheduler.
        """
        return {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler}

    def loss(
        self, x: torch.Tensor, t: torch.Tensor, sample_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the loss for a batch of data at specified time steps.

        This method:
        1. Repeats each sample N_noise_draws_per_sample times
        2. Adds noise to the data according to the diffusion process
        3. Predicts the vector field
        4. Computes the loss between the prediction and the ground truth

        Args:
            x (torch.Tensor): Input data of shape (batch_size, *data_dims).
            t (torch.Tensor): Time steps of shape (batch_size,).
            sample_weights (torch.Tensor): Weights for each sample of shape (batch_size,).

        Returns:
            torch.Tensor: Scalar loss value.
        """
        x = torch.repeat_interleave(x, self.N_noise_draws_per_sample, dim=0)
        t = torch.repeat_interleave(t, self.N_noise_draws_per_sample, dim=0)
        sample_weights = torch.repeat_interleave(
            sample_weights, self.N_noise_draws_per_sample, dim=0
        )

        eps = torch.randn_like(x)
        xt = self.diffusion_process.forward(x, t, eps)
        fxt = self(xt, t)

        samplewise_loss = self.samplewise_loss(xt, fxt, x, eps, t)
        mean_loss = torch.mean(samplewise_loss * sample_weights)
        return mean_loss

    def aggregate_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss for a batch of data with randomly sampled time steps.

        This method:
        1. Samples time steps according to the training distribution
        2. Computes the loss at those time steps

        Args:
            x (torch.Tensor): Input data of shape (batch_size, *data_dims).

        Returns:
            torch.Tensor: Scalar loss value.
        """
        t_idx = torch.multinomial(
            self.train_ts_loss_probs, x.shape[0], replacement=True
        ).to(self.device, non_blocking=True)
        t = self.train_ts[t_idx]
        t_weights = self.train_ts_loss_weights[t_idx]
        mean_loss = self.loss(x, t, t_weights)
        return mean_loss

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """
        Perform a single training step.

        This method is called by PyTorch Lightning during training.

        Args:
            batch (torch.Tensor): Batch of data, typically a tuple (x, metadata).
            batch_idx (int): Index of the current batch.

        Returns:
            torch.Tensor: Loss value for the batch.
        """
        x, metadata = batch
        loss = self.aggregate_loss(x)
        self.log(
            "train_loss",
            loss,
            on_step=self.LOG_ON_STEP_TRAIN_LOSS,
            on_epoch=self.LOG_ON_EPOCH_TRAIN_LOSS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_TRAIN_LOSS,
        )
        return loss

    def validation_step(
        self, batch: torch.Tensor, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        """
        Perform a single validation step.

        This method is called by PyTorch Lightning during validation.
        It computes the loss and any batch-wise metrics.

        Args:
            batch (torch.Tensor): Batch of data, typically a tuple (x, metadata).
            batch_idx (int): Index of the current batch.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of metric values.
        """
        x, metadata = batch
        loss = self.aggregate_loss(x)
        metric_values = {"val_loss": loss}
        for metric_name, metric in self.batchwise_metrics.items():
            metric_values_dict = metric(x, metadata, self)
            for key, value in metric_values_dict.items():
                metric_values[f"{metric_name}_{key}"] = value
        self.log_dict(
            metric_values,
            on_step=self.LOG_ON_STEP_BATCHWISE_METRICS,
            on_epoch=self.LOG_ON_EPOCH_BATCHWISE_METRICS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS,
        )
        return metric_values

    def on_validation_epoch_end(self) -> None:
        """
        Perform operations at the end of a validation epoch.

        This method is called by PyTorch Lightning at the end of each validation epoch.
        It computes and logs the batch-free metrics, each of which is called with the model
        alone rather than with a validation batch.
        """
        metric_values = {}
        for metric_name, metric in self.batchfree_metrics.items():
            metric_values_dict = metric(self)
            for key, value in metric_values_dict.items():
                metric_values[f"{metric_name}_{key}"] = value
        self.log_dict(
            metric_values,
            on_step=self.LOG_ON_STEP_BATCHFREE_METRICS,
            on_epoch=self.LOG_ON_EPOCH_BATCHFREE_METRICS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS,
        )

LOG_ON_EPOCH_BATCHFREE_METRICS = True (class attribute, instance attribute)

LOG_ON_EPOCH_BATCHWISE_METRICS = True (class attribute, instance attribute)

LOG_ON_EPOCH_TRAIN_LOSS = True (class attribute, instance attribute)

LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS = False (class attribute, instance attribute)

LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS = False (class attribute, instance attribute)

LOG_ON_PROGRESS_BAR_TRAIN_LOSS = True (class attribute, instance attribute)

LOG_ON_STEP_BATCHFREE_METRICS = False (class attribute, instance attribute)

LOG_ON_STEP_BATCHWISE_METRICS = False (class attribute, instance attribute)

LOG_ON_STEP_TRAIN_LOSS = True (class attribute, instance attribute)

N_noise_draws_per_sample = N_noise_draws_per_sample (instance attribute)

batchfree_metrics = nn.ModuleDict(batchfree_metrics) (instance attribute)

batchwise_metrics = nn.ModuleDict(batchwise_metrics) (instance attribute)

diffusion_process = diffusion_process (instance attribute)

lr_scheduler = lr_scheduler (instance attribute)

net = net (instance attribute)

optimizer = optimizer (instance attribute)

samplewise_loss = SamplewiseDiffusionLoss(diffusion_process, vector_field_type) (instance attribute)

t_loss_probs = t_loss_probs (instance attribute)

t_loss_weights = t_loss_weights (instance attribute)

train_scheduler = train_scheduler (instance attribute)

vector_field_type = vector_field_type (instance attribute)

__init__(net, diffusion_process, train_scheduler, vector_field_type, optimizer, lr_scheduler, batchwise_metrics, batchfree_metrics, train_ts_hparams, t_loss_weights, t_loss_probs, N_noise_draws_per_sample)

Initialize the diffusion model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| net | Module | Neural network that predicts the vector field. | required |
| diffusion_process | DiffusionProcess | The diffusion process used for training. | required |
| train_scheduler | Scheduler | Scheduler that generates the training time steps. | required |
| vector_field_type | VectorFieldType | Type of vector field the model predicts. | required |
| optimizer | Optimizer | Optimizer for training the model. | required |
| lr_scheduler | LRScheduler | Learning rate scheduler. | required |
| batchwise_metrics | Dict[str, Module] | Metrics computed on each batch during validation. Each metric is called as metric(x, metadata, model) and returns a dictionary mapping metric names to values. | required |
| batchfree_metrics | Dict[str, Module] | Metrics computed at the end of each validation epoch. Each metric is called as metric(model) and returns a dictionary mapping metric names to values. | required |
| train_ts_hparams | Dict[str, Any] | Keyword arguments passed to train_scheduler.get_ts to generate the training time steps. | required |
| t_loss_weights | Callable[[Tensor], Tensor] | Function that weights the loss at different time steps. | required |
| t_loss_probs | Callable[[Tensor], Tensor] | Function that determines the sampling probability of each time step. | required |
| N_noise_draws_per_sample | int | Number of noise draws per data point. | required |

aggregate_loss(x)

Compute the loss for a batch of data with randomly sampled time steps.

This method:

1. Samples time steps according to the training distribution (train_ts_loss_probs).
2. Computes the weighted loss at those time steps.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input data of shape (batch_size, *data_dims). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Scalar loss value. |
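The time-step sampling in aggregate_loss draws indices with replacement in proportion to train_ts_loss_probs, then gathers the matching times and loss weights with the same indices. The dependency-free sketch below mirrors that logic using Python's `random.choices` in place of `torch.multinomial` (both accept non-negative weights that need not sum to 1); the time grid, weights and probabilities are made-up values for illustration.

```python
import random
from typing import List, Tuple


def sample_train_times(
    train_ts: List[float],
    loss_weights: List[float],
    loss_probs: List[float],
    batch_size: int,
    rng: random.Random,
) -> Tuple[List[float], List[float]]:
    """Mirror of aggregate_loss's sampling: draw indices, then gather t and weights."""
    # Like torch.multinomial(probs, batch_size, replacement=True).
    idx = rng.choices(range(len(train_ts)), weights=loss_probs, k=batch_size)
    t = [train_ts[i] for i in idx]
    t_weights = [loss_weights[i] for i in idx]
    return t, t_weights


# Illustrative 5-point grid; index 0 has zero probability and is never drawn.
ts = [0.0, 0.25, 0.5, 0.75, 1.0]
weights = [1.0, 2.0, 3.0, 4.0, 5.0]
probs = [0.0, 1.0, 1.0, 1.0, 1.0]

t, w = sample_train_times(ts, weights, probs, batch_size=8, rng=random.Random(0))
print(t)
print(w)
```

Indexing both tensors with the same draw keeps each sampled time paired with its own weight, which is what lets loss apply per-sample weights afterwards.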


configure_optimizers()

Configure optimizers and learning rate schedulers for training.

This method is called by PyTorch Lightning to set up the optimization process.

Returns:

| Type | Description |
| --- | --- |
| OptimizerLRScheduler | Dictionary with the keys "optimizer" and "lr_scheduler", holding the optimizer and learning rate scheduler passed to the constructor. |


forward(x, t)

Forward pass of the model.

Passes the input through the neural network to predict the vector field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor of shape (batch_size, *data_dims). | required |
| t | Tensor | Time tensor of shape (batch_size,). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Predicted vector field of shape (batch_size, *data_dims). |


loss(x, t, sample_weights)

Compute the loss for a batch of data at specified time steps.

This method:

1. Repeats each sample N_noise_draws_per_sample times.
2. Adds noise to the data according to the diffusion process.
3. Predicts the vector field.
4. Computes the loss between the prediction and the ground truth.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input data of shape (batch_size, *data_dims). | required |
| t | Tensor | Time steps of shape (batch_size,). | required |
| sample_weights | Tensor | Weights for each sample of shape (batch_size,). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Scalar loss value. |
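Because loss replicates inputs with `torch.repeat_interleave` along dim 0, the copies of each sample sit next to each other, and the result is a plain mean over all batch_size × N_noise_draws_per_sample rows. In symbols, with B data points and N = N_noise_draws_per_sample, and assuming samplewise_loss returns one value ℓ_ij per row, the returned value is (1 / (B·N)) Σ_i Σ_j w(t_i) ℓ_ij. The sketch below reproduces that bookkeeping with lists; the per-row losses are made-up numbers standing in for samplewise_loss output.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_interleave(values: Sequence[T], repeats: int) -> List[T]:
    """List analogue of torch.repeat_interleave(values, repeats, dim=0)."""
    return [v for v in values for _ in range(repeats)]


def weighted_mean_loss(samplewise_loss: Sequence[float], sample_weights: Sequence[float]) -> float:
    """Mirror of torch.mean(samplewise_loss * sample_weights)."""
    assert len(samplewise_loss) == len(sample_weights)
    total = sum(loss_i * w_i for loss_i, w_i in zip(samplewise_loss, sample_weights))
    return total / len(samplewise_loss)


N_noise_draws_per_sample = 2
t = [0.1, 0.9]          # one time step per data point (batch_size = 2)
t_weights = [1.0, 3.0]  # loss weight attached to each sampled time step

t_rep = repeat_interleave(t, N_noise_draws_per_sample)
w_rep = repeat_interleave(t_weights, N_noise_draws_per_sample)
print(t_rep)  # [0.1, 0.1, 0.9, 0.9]

# Hypothetical per-row losses, one per (data point, noise draw) pair.
per_row_loss = [0.5, 1.5, 0.25, 0.75]
print(weighted_mean_loss(per_row_loss, w_rep))  # 1.25
```

Since the weights are repeated along with the data, every noise draw of a sample inherits that sample's time-step weight.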


on_validation_epoch_end()

Perform operations at the end of a validation epoch.

This method is called by PyTorch Lightning at the end of each validation epoch. It computes and logs the batch-free metrics, each of which is called with the model alone rather than with a validation batch.


precompute_train_schedule(train_ts_hparams)

Precompute time steps and their associated weights for training.

This method generates the time steps used during training and computes the loss weights and sampling probabilities for each time step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| train_ts_hparams | Dict[str, float] | Keyword arguments passed to train_scheduler.get_ts. Typically includes t_min, t_max, and the number of steps L. | required |
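precompute_train_schedule evaluates the two user-supplied callables once, on the scheduler's time grid, so each must map a 1-D tensor of times to a tensor of the same length. The sketch below mirrors that order of operations with lists. The uniform grid function `uniform_ts` is a hypothetical stand-in for train_scheduler.get_ts, the hparam names follow the docstring's "typically includes t_min, t_max, and L" hint rather than a confirmed signature, and the constant weight and probability functions are illustrative, not library defaults.

```python
from typing import Any, Callable, Dict, List, Tuple


def uniform_ts(t_min: float, t_max: float, L: int) -> List[float]:
    """Hypothetical stand-in for train_scheduler.get_ts: L evenly spaced times (L >= 2)."""
    step = (t_max - t_min) / (L - 1)
    return [t_min + i * step for i in range(L)]


def precompute(
    hparams: Dict[str, Any],
    t_loss_weights: Callable[[List[float]], List[float]],
    t_loss_probs: Callable[[List[float]], List[float]],
) -> Tuple[List[float], List[float], List[float]]:
    # Same order as the method: time grid first, then weights and probabilities on that grid.
    ts = uniform_ts(**hparams)
    return ts, t_loss_weights(ts), t_loss_probs(ts)


# Example callables: constant weights and uniform (unnormalized) sampling probabilities.
ts, w, p = precompute(
    {"t_min": 0.0, "t_max": 1.0, "L": 5},
    t_loss_weights=lambda grid: [1.0 for _ in grid],
    t_loss_probs=lambda grid: [1.0 for _ in grid],
)
print(ts)  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

Precomputing once keeps the per-step cost in aggregate_loss down to an index draw and two gathers.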

training_step(batch, batch_idx)

Perform a single training step.

This method is called by PyTorch Lightning during training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Tensor | Batch of data, typically a tuple (x, metadata); only x is used to compute the training loss. | required |
| batch_idx | int | Index of the current batch. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Loss value for the batch. |


validation_step(batch, batch_idx)

Perform a single validation step.

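This method is called by PyTorch Lightning during validation. It computes the validation loss (logged as val_loss) and any batch-wise metrics, and logs them all according to the LOG_ON_*_BATCHWISE_METRICS flags.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Tensor | Batch of data, typically a tuple (x, metadata); both are passed to each batchwise metric. | required |
| batch_idx | int | Index of the current batch. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Tensor] | Dictionary of metric values: val_loss plus one {metric_name}_{key} entry for each value returned by each batchwise metric. |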
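validation_step and on_validation_epoch_end flatten each metric's returned dictionary into a single namespace by prefixing every key with the metric's name, and validation_step seeds that namespace with val_loss. The dependency-free sketch below reproduces the key scheme; `flatten_metric_values` and the `batch_stats` metric are hypothetical names used only for illustration (real metrics must be nn.Module instances, since the constructor wraps them in nn.ModuleDict).

```python
from typing import Any, Callable, Dict, Mapping, Optional


def flatten_metric_values(
    metrics: Mapping[str, Callable[..., Dict[str, Any]]],
    args: tuple,
    initial: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Mirror of the loop in validation_step / on_validation_epoch_end."""
    metric_values: Dict[str, Any] = dict(initial or {})
    for metric_name, metric in metrics.items():
        for key, value in metric(*args).items():
            metric_values[f"{metric_name}_{key}"] = value
    return metric_values


# Hypothetical batchwise metric using the (x, metadata, model) calling convention.
def batch_stats(x, metadata, model):
    return {"mean": sum(x) / len(x), "max": max(x)}


values = flatten_metric_values(
    {"stats": batch_stats}, args=([1.0, 2.0, 3.0], None, None), initial={"val_loss": 0.5}
)
print(values)  # {'val_loss': 0.5, 'stats_mean': 2.0, 'stats_max': 3.0}
```

For batch-free metrics the same loop runs with args=(model,) and no initial entry. One consequence of this naming scheme is that a metric registered as "val" that returns a "loss" key would silently overwrite val_loss.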
