Losses

This module contains functionality related to losses.

SamplewiseDiffusionLoss

Bases: Module

Sample-wise loss function for training diffusion models.

This class implements loss functions for diffusion models based on the specified target type. For each sample, the loss is the squared error between the model's prediction and the target, summed over all non-batch dimensions. The target depends on the chosen vector field type.

The loss supports different target types:
- X0: Learn to predict the original clean data x_0
- EPS: Learn to predict the noise component eps
- V: Learn to predict the velocity field v
- SCORE: Not directly supported (raises ValueError)
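As a rough illustration of the description above, the following standalone sketch reproduces the per-sample computation in plain PyTorch without importing the library. The helper name `samplewise_sq_error` is hypothetical and not part of diffusionlab; it assumes the target is simply `x_0` or `eps`, as for the X0 and EPS target types.

```python
import torch


def samplewise_sq_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: squared error summed over every non-batch dimension."""
    squared_residuals = (prediction - target) ** 2
    # Flatten (N, *D) to (N, prod(D)) and sum, leaving one value per sample.
    return squared_residuals.flatten(start_dim=1).sum(dim=1)


torch.manual_seed(0)
x_0 = torch.randn(4, 3, 8, 8)  # clean data, shape (N, *D)
eps = torch.randn_like(x_0)    # noise, shape (N, *D)
f_x_t = torch.zeros_like(x_0)  # stand-in for a model prediction

loss_x0 = samplewise_sq_error(f_x_t, x_0)   # X0 target
loss_eps = samplewise_sq_error(f_x_t, eps)  # EPS target
print(loss_x0.shape, loss_eps.shape)
```

Note that the reduction is a sum rather than a mean, so the magnitude of each per-sample value grows with the number of data dimensions.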

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| diffusion_process | DiffusionProcess | The diffusion process defining the forward dynamics. |
| target_type | VectorFieldType | The type of target to learn via minimizing the loss function. |
| target | Callable | Function that computes the target based on the specified target_type. Takes tensors of shape (N, *D) for x_t, f_x_t, x_0, eps and (N,) for t, and returns a tensor of shape (N, *D). |
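For the V target type, the source listing builds the target as alpha'(t) * x_0 + sigma'(t) * eps, broadcasting the per-sample coefficients over the data dimensions. The sketch below assumes a hypothetical process with alpha(t) = 1 - t and sigma(t) = t, so that alpha'(t) = -1 and sigma'(t) = 1. The function `pad_back` is an illustrative stand-in for the library's `pad_shape_back`, whose behavior is assumed here to be appending singleton dimensions for broadcasting.

```python
import torch


def pad_back(v: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """Illustrative stand-in: reshape (N,) to (N, 1, ..., 1) to broadcast against (N, *D)."""
    return v.reshape(v.shape[0], *([1] * (len(target_shape) - 1)))


# Hypothetical coefficient derivatives for alpha(t) = 1 - t and sigma(t) = t.
def alpha_prime(t: torch.Tensor) -> torch.Tensor:
    return -torch.ones_like(t)


def sigma_prime(t: torch.Tensor) -> torch.Tensor:
    return torch.ones_like(t)


def v_target(x_0: torch.Tensor, eps: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Velocity target: alpha'(t) * x_0 + sigma'(t) * eps, with shape (N, *D)."""
    return (
        pad_back(alpha_prime(t), x_0.shape) * x_0
        + pad_back(sigma_prime(t), x_0.shape) * eps
    )


torch.manual_seed(0)
x_0 = torch.randn(2, 3, 4, 4)  # clean data, shape (N, *D)
eps = torch.randn_like(x_0)    # noise, shape (N, *D)
t = torch.rand(2)              # one time value per sample, shape (N,)

v = v_target(x_0, eps, t)
print(v.shape)
```

Under this assumed process the target reduces to `eps - x_0`, which makes it a convenient sanity check.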

Source code in src/diffusionlab/losses.py
class SamplewiseDiffusionLoss(nn.Module):
    """
    Sample-wise loss function for training diffusion models.

    This class implements various loss functions for diffusion models based on the specified
    target type. For each sample, the loss is the squared error between the model's prediction
    and the target, summed over all non-batch dimensions. The target depends on the chosen
    vector field type.

    The loss supports different target types:
    - X0: Learn to predict the original clean data x_0
    - EPS: Learn to predict the noise component eps
    - V: Learn to predict the velocity field v
    - SCORE: Not directly supported (raises ValueError)

    Attributes:
        diffusion_process (DiffusionProcess): The diffusion process defining the forward dynamics
        target_type (VectorFieldType): The type of target to learn via minimizing the loss function
        target (Callable): Function that computes the target based on the specified target_type.
                          Takes tensors of shapes (N, *D) for x_t, f_x_t, x_0, eps and (N,) for t,
                          and returns a tensor of shape (N, *D).
    """

    def __init__(
        self, diffusion_process: DiffusionProcess, target_type: VectorFieldType
    ) -> None:
        """
        Initialize the diffusion loss function.

        Args:
            diffusion_process: The diffusion process to use, containing data about the forward evolution.
            target_type: The type of target to learn via minimizing the loss function.
                         Must be one of VectorFieldType.X0, VectorFieldType.EPS, or VectorFieldType.V.

        Raises:
            ValueError: If target_type is VectorFieldType.SCORE, which is not directly supported.
        """
        super().__init__()
        self.diffusion_process: DiffusionProcess = diffusion_process
        self.target_type: VectorFieldType = target_type

        if target_type == VectorFieldType.X0:

            def target(
                x_t: torch.Tensor,
                f_x_t: torch.Tensor,
                x_0: torch.Tensor,
                eps: torch.Tensor,
                t: torch.Tensor,
            ) -> torch.Tensor:
                """
                Target function for predicting the original clean data x_0.

                Args:
                    x_t (torch.Tensor): The noised data at time t, of shape (N, *D).
                    f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
                    x_0 (torch.Tensor): The original clean data, of shape (N, *D).
                    eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
                    t (torch.Tensor): The time parameter, of shape (N,).

                Returns:
                    torch.Tensor: The target tensor x_0, of shape (N, *D).
                """
                return x_0

        elif target_type == VectorFieldType.EPS:

            def target(
                x_t: torch.Tensor,
                f_x_t: torch.Tensor,
                x_0: torch.Tensor,
                eps: torch.Tensor,
                t: torch.Tensor,
            ) -> torch.Tensor:
                """
                Target function for predicting the noise component eps.

                Args:
                    x_t (torch.Tensor): The noised data at time t, of shape (N, *D).
                    f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
                    x_0 (torch.Tensor): The original clean data, of shape (N, *D).
                    eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
                    t (torch.Tensor): The time parameter, of shape (N,).

                Returns:
                    torch.Tensor: The target tensor eps, of shape (N, *D).
                """
                return eps

        elif target_type == VectorFieldType.V:

            def target(
                x_t: torch.Tensor,
                f_x_t: torch.Tensor,
                x_0: torch.Tensor,
                eps: torch.Tensor,
                t: torch.Tensor,
            ) -> torch.Tensor:
                """
                Target function for predicting the velocity field v.

                Args:
                    x_t (torch.Tensor): The noised data at time t, of shape (N, *D).
                    f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
                    x_0 (torch.Tensor): The original clean data, of shape (N, *D).
                    eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
                    t (torch.Tensor): The time parameter, of shape (N,).

                Returns:
                    torch.Tensor: The velocity field target tensor, of shape (N, *D).
                """
                return (
                    pad_shape_back(self.diffusion_process.alpha_prime(t), x_0.shape)
                    * x_0
                    + pad_shape_back(self.diffusion_process.sigma_prime(t), x_0.shape)
                    * eps
                )

        elif target_type == VectorFieldType.SCORE:
            raise ValueError(
                "Direct score matching is not supported due to lack of a known target function, and other ways (like Hutchinson's trace estimator) are very high variance."
            )

        self.target: Callable[
            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
            torch.Tensor,
        ] = target

    def forward(
        self,
        x_t: torch.Tensor,
        f_x_t: torch.Tensor,
        x_0: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the loss for each sample in the batch.

        This method calculates the squared error between the model's prediction (f_x_t)
        and the target value determined by the target_type, summed over all non-batch dimensions.

        Args:
            x_t (torch.Tensor): The noised data at time t, of shape (N, *D) where N is the batch size
                               and D represents the data dimensions.
            f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
            x_0 (torch.Tensor): The original clean data, of shape (N, *D).
            eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
            t (torch.Tensor): The time parameter, of shape (N,).

        Returns:
            torch.Tensor: The per-sample loss values, of shape (N,) where N is the batch size.
        """
        # Compute squared error between prediction and target
        squared_residuals = (f_x_t - self.target(x_t, f_x_t, x_0, eps, t)) ** 2

        # Sum over all dimensions except batch dimension
        samplewise_loss = torch.sum(
            torch.flatten(squared_residuals, start_dim=1, end_dim=-1), dim=1
        )

        return samplewise_loss

diffusion_process = diffusion_process instance-attribute

target = target instance-attribute

target_type = target_type instance-attribute

__init__(diffusion_process, target_type)

Initialize the diffusion loss function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| diffusion_process | DiffusionProcess | The diffusion process to use, containing data about the forward evolution. | required |
| target_type | VectorFieldType | The type of target to learn via minimizing the loss function. Must be one of VectorFieldType.X0, VectorFieldType.EPS, or VectorFieldType.V. | required |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If target_type is VectorFieldType.SCORE, which is not directly supported. |

Source code in src/diffusionlab/losses.py
def __init__(
    self, diffusion_process: DiffusionProcess, target_type: VectorFieldType
) -> None:
    """
    Initialize the diffusion loss function.

    Args:
        diffusion_process: The diffusion process to use, containing data about the forward evolution.
        target_type: The type of target to learn via minimizing the loss function.
                     Must be one of VectorFieldType.X0, VectorFieldType.EPS, or VectorFieldType.V.

    Raises:
        ValueError: If target_type is VectorFieldType.SCORE, which is not directly supported.
    """
    super().__init__()
    self.diffusion_process: DiffusionProcess = diffusion_process
    self.target_type: VectorFieldType = target_type

    if target_type == VectorFieldType.X0:

        def target(
            x_t: torch.Tensor,
            f_x_t: torch.Tensor,
            x_0: torch.Tensor,
            eps: torch.Tensor,
            t: torch.Tensor,
        ) -> torch.Tensor:
            """
            Target function for predicting the original clean data x_0.

            Args:
                x_t (torch.Tensor): The noised data at time t, of shape (N, *D).
                f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
                x_0 (torch.Tensor): The original clean data, of shape (N, *D).
                eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
                t (torch.Tensor): The time parameter, of shape (N,).

            Returns:
                torch.Tensor: The target tensor x_0, of shape (N, *D).
            """
            return x_0

    elif target_type == VectorFieldType.EPS:

        def target(
            x_t: torch.Tensor,
            f_x_t: torch.Tensor,
            x_0: torch.Tensor,
            eps: torch.Tensor,
            t: torch.Tensor,
        ) -> torch.Tensor:
            """
            Target function for predicting the noise component eps.

            Args:
                x_t (torch.Tensor): The noised data at time t, of shape (N, *D).
                f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
                x_0 (torch.Tensor): The original clean data, of shape (N, *D).
                eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
                t (torch.Tensor): The time parameter, of shape (N,).

            Returns:
                torch.Tensor: The target tensor eps, of shape (N, *D).
            """
            return eps

    elif target_type == VectorFieldType.V:

        def target(
            x_t: torch.Tensor,
            f_x_t: torch.Tensor,
            x_0: torch.Tensor,
            eps: torch.Tensor,
            t: torch.Tensor,
        ) -> torch.Tensor:
            """
            Target function for predicting the velocity field v.

            Args:
                x_t (torch.Tensor): The noised data at time t, of shape (N, *D).
                f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
                x_0 (torch.Tensor): The original clean data, of shape (N, *D).
                eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
                t (torch.Tensor): The time parameter, of shape (N,).

            Returns:
                torch.Tensor: The velocity field target tensor, of shape (N, *D).
            """
            return (
                pad_shape_back(self.diffusion_process.alpha_prime(t), x_0.shape)
                * x_0
                + pad_shape_back(self.diffusion_process.sigma_prime(t), x_0.shape)
                * eps
            )

    elif target_type == VectorFieldType.SCORE:
        raise ValueError(
            "Direct score matching is not supported due to lack of a known target function, and other ways (like Hutchinson's trace estimator) are very high variance."
        )

    self.target: Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        torch.Tensor,
    ] = target

forward(x_t, f_x_t, x_0, eps, t)

Compute the loss for each sample in the batch.

This method calculates the squared error between the model's prediction (f_x_t) and the target value determined by the target_type, summed over all non-batch dimensions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_t | Tensor | The noised data at time t, of shape (N, *D) where N is the batch size and D represents the data dimensions. | required |
| f_x_t | Tensor | The model's prediction at time t, of shape (N, *D). | required |
| x_0 | Tensor | The original clean data, of shape (N, *D). | required |
| eps | Tensor | The noise used to generate x_t, of shape (N, *D). | required |
| t | Tensor | The time parameter, of shape (N,). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The per-sample loss values, of shape (N,) where N is the batch size. |
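Because `forward` returns one value per sample rather than a scalar, a training step would typically reduce it before calling `backward()`. The sketch below is an assumption about typical usage, not the library's prescribed training loop. Here `samplewise_loss` is computed inline as a placeholder for the output of `forward`, and the per-sample `weights` are hypothetical.

```python
import torch

torch.manual_seed(0)
f_x_t = torch.randn(4, 3, requires_grad=True)  # placeholder prediction, shape (N, *D)
target = torch.randn(4, 3)                     # placeholder target, shape (N, *D)

# Same reduction as forward: sum squared residuals over non-batch dimensions.
samplewise_loss = ((f_x_t - target) ** 2).flatten(start_dim=1).sum(dim=1)  # shape (N,)

# Hypothetical per-sample weights (for example, time-dependent); uniform here.
weights = torch.ones(4)

# Reduce to a scalar so autograd can backpropagate.
loss = (weights * samplewise_loss).mean()
loss.backward()
print(loss.dim(), f_x_t.grad.shape)
```

Keeping the loss per-sample until this point leaves the choice of weighting and reduction to the caller.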

Source code in src/diffusionlab/losses.py
def forward(
    self,
    x_t: torch.Tensor,
    f_x_t: torch.Tensor,
    x_0: torch.Tensor,
    eps: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the loss for each sample in the batch.

    This method calculates the squared error between the model's prediction (f_x_t)
    and the target value determined by the target_type, summed over all non-batch dimensions.

    Args:
        x_t (torch.Tensor): The noised data at time t, of shape (N, *D) where N is the batch size
                           and D represents the data dimensions.
        f_x_t (torch.Tensor): The model's prediction at time t, of shape (N, *D).
        x_0 (torch.Tensor): The original clean data, of shape (N, *D).
        eps (torch.Tensor): The noise used to generate x_t, of shape (N, *D).
        t (torch.Tensor): The time parameter, of shape (N,).

    Returns:
        torch.Tensor: The per-sample loss values, of shape (N,) where N is the batch size.
    """
    # Compute squared error between prediction and target
    squared_residuals = (f_x_t - self.target(x_t, f_x_t, x_0, eps, t)) ** 2

    # Sum over all dimensions except batch dimension
    samplewise_loss = torch.sum(
        torch.flatten(squared_residuals, start_dim=1, end_dim=-1), dim=1
    )

    return samplewise_loss