Skip to content

Vector Fields

This module contains functionality related to vector fields.

VectorField

A wrapper around a function (x, t) -> f(x, t) which provides some extra data, namely the type of vector field the function f represents.

This class encapsulates a vector field function and its type, allowing for consistent handling of different vector field representations in diffusion models.

Attributes:

Name Type Description
f Callable

A function that takes tensors x of shape (N, D) and t of shape (N,) and returns a tensor of shape (N, D).

vector_field_type VectorFieldType

The type of vector field the function represents.

Source code in src/diffusionlab/vector_fields.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class VectorField:
    """
    A wrapper around a function (x, t) -> f(x, t) which provides some extra data,
    namely the type of vector field the function f represents.

    This class encapsulates a vector field function and its type, allowing for
    consistent handling of different vector field representations in diffusion models.

    Attributes:
        f (Callable): A function that takes tensors x of shape (N, *D) and t of shape (N,)
            and returns a tensor of shape (N, *D).
        vector_field_type (VectorFieldType): The type of vector field the function represents.
    """

    def __init__(
        self,
        f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        vector_field_type: VectorFieldType,
    ):
        """
        Initialize a vector field wrapper.

        Args:
            f (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): A function that takes tensors x of shape (N, *D) and t of shape (N,)
                and returns a tensor of shape (N, *D).
            vector_field_type (VectorFieldType): The type of vector field the function represents
                                (SCORE, X0, EPS, or V).
        """
        self.f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = f
        self.vector_field_type: VectorFieldType = vector_field_type

    def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Call the wrapped vector field function.

        Args:
            x (torch.Tensor): Input tensor of shape (N, *D) where N is the batch size and D represents the data dimensions.
            t (torch.Tensor): Time parameter tensor of shape (N,).

        Returns:
            torch.Tensor: Output of the vector field function, of shape (N, *D).
        """
        return self.f(x, t)

f = f instance-attribute

vector_field_type = vector_field_type instance-attribute

__call__(x, t)

Call the wrapped vector field function.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (N, *D) where N is the batch size and D represents the data dimensions.

required
t Tensor

Time parameter tensor of shape (N,).

required

Returns:

Type Description
Tensor

torch.Tensor: Output of the vector field function, of shape (N, *D).

Source code in src/diffusionlab/vector_fields.py
47
48
49
50
51
52
53
54
55
56
57
58
def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Call the wrapped vector field function.

    Args:
        x (torch.Tensor): Input tensor of shape (N, *D) where N is the batch size and D represents the data dimensions.
        t (torch.Tensor): Time parameter tensor of shape (N,).

    Returns:
        torch.Tensor: Output of the vector field function, of shape (N, *D).
    """
    return self.f(x, t)

__init__(f, vector_field_type)

Initialize a vector field wrapper.

Parameters:

Name Type Description Default
f Callable[[Tensor, Tensor], Tensor]

A function that takes tensors x of shape (N, D) and t of shape (N,) and returns a tensor of shape (N, D).

required
vector_field_type VectorFieldType

The type of vector field the function represents (SCORE, X0, EPS, or V).

required
Source code in src/diffusionlab/vector_fields.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def __init__(
    self,
    f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    vector_field_type: VectorFieldType,
):
    """
    Initialize a vector field wrapper.

    Args:
        f (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): A function that takes tensors x of shape (N, *D) and t of shape (N,)
            and returns a tensor of shape (N, *D).
        vector_field_type (VectorFieldType): The type of vector field the function represents
                            (SCORE, X0, EPS, or V).
    """
    self.f: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = f
    self.vector_field_type: VectorFieldType = vector_field_type

VectorFieldType

Bases: Enum

Source code in src/diffusionlab/vector_fields.py
 9
10
11
12
13
class VectorFieldType(enum.Enum):
    SCORE = enum.auto()
    X0 = enum.auto()
    EPS = enum.auto()
    V = enum.auto()

EPS = enum.auto() class-attribute instance-attribute

SCORE = enum.auto() class-attribute instance-attribute

V = enum.auto() class-attribute instance-attribute

X0 = enum.auto() class-attribute instance-attribute

convert_vector_field_type(x, fx, alpha, sigma, alpha_prime, sigma_prime, in_type, out_type)

Converts the output of a vector field from one type to another.

Parameters:

Name Type Description Default
x Tensor

A tensor of shape (N, *D), where N is the batch size and D is the shape of the data (e.g., (C, H, W) for images, (D,) for vectors, or (N, D) for token sequences).

required
fx Tensor

The output of the vector field f, of shape (N, *D).

required
alpha Tensor

A tensor of shape (N,) representing the scale parameter.

required
sigma Tensor

A tensor of shape (N,) representing the noise level parameter.

required
alpha_prime Tensor

A tensor of shape (N,) representing the scale derivative parameter.

required
sigma_prime Tensor

A tensor of shape (N,) representing the noise level derivative parameter.

required
in_type VectorFieldType

The type of the input vector field (e.g. Score, X0, Eps, V).

required
out_type VectorFieldType

The type of the output vector field.

required

Returns:

Type Description
Tensor

torch.Tensor: The converted output of the vector field, of shape (N, *D).

Source code in src/diffusionlab/vector_fields.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def convert_vector_field_type(
    x: torch.Tensor,
    fx: torch.Tensor,
    alpha: torch.Tensor,
    sigma: torch.Tensor,
    alpha_prime: torch.Tensor,
    sigma_prime: torch.Tensor,
    in_type: VectorFieldType,
    out_type: VectorFieldType,
) -> torch.Tensor:
    """
    Converts the output of a vector field from one type to another.

    Arguments:
        x (torch.Tensor): A tensor of shape (N, *D), where N is the batch size and D is the shape
           of the data (e.g., (C, H, W) for images, (D,) for vectors, or (N, D) for token sequences).
        fx (torch.Tensor): The output of the vector field f, of shape (N, *D).
        alpha (torch.Tensor): A tensor of shape (N,) representing the scale parameter.
        sigma (torch.Tensor): A tensor of shape (N,) representing the noise level parameter.
        alpha_prime (torch.Tensor): A tensor of shape (N,) representing the scale derivative parameter.
        sigma_prime (torch.Tensor): A tensor of shape (N,) representing the noise level derivative parameter.
        in_type (VectorFieldType): The type of the input vector field (e.g. Score, X0, Eps, V).
        out_type (VectorFieldType): The type of the output vector field.

    Returns:
        torch.Tensor: The converted output of the vector field, of shape (N, *D).
    """
    """
    Derivation:
    ----------------------------
    Define certain quantities:
    alpha_r = alpha' / alpha
    sigma_r = sigma' / sigma
    diff_r = sigma_r - alpha_r
    and note that diff_r >= 0 since alpha' < 0 and all other terms are > 0. 
    Under the data model 
    (1) x := alpha * x0 + sigma * eps
    it holds that 
    (2) x = alpha * E[x0 | x] + sigma * E[eps | x]
    Therefore 
    (3) E[x0 | x] = (x - sigma * E[eps | x]) / alpha
    (4) E[eps | x] = (x - alpha * E[x0 | x]) / sigma
    Furthermore, from (1) it holds that
    (5) v := x' = alpha' * x0 + sigma' * eps,
    or in particular
    (6) E[v | x] = alpha' * E[x0 | x] + sigma' * E[eps | x]
    Using (3), (4), (6) it holds 
    (7) E[v | x] = alpha_r * (x - sigma * E[eps | x]) + sigma' * E[eps | x] 
    => E[v | x] = alpha'/alpha * x + (sigma' - sigma * alpha'/alpha) * E[eps | x]
    => E[v | x] = alpha'/alpha * x + sigma * (sigma'/sigma - alpha'/alpha) * E[eps | x]
    => E[v | x] = alpha_r * x + sigma * diff_r * E[eps | x]
    (8) E[eps | x] = (E[v | x] - alpha_r * x) / (sigma * diff_r)
    and, similarly,
    (9) E[v | x] = alpha' * E[x0 | x] + sigma'/sigma * (x - alpha * E[x0 | x]) 
    => E[v | x] = sigma'/sigma * x + (alpha' - alpha * sigma'/sigma) * E[x0 | x]
    => E[v | x] = sigma'/sigma * x + alpha * (alpha'/alpha - sigma'/sigma) * E[x0 | x]
    => E[v | x] = sigma_r * x - alpha * diff_r * E[x0 | x]
    (10) E[x0 | x] = (sigma_r * x - E[v | x]) / (alpha * diff_r)
    To connect the score function to the other types, we use Tweedie's formula:
    (11) alpha * E[x0 | x] = x + sigma^2 * score(x, alpha, sigma).
    Therefore, from (11):
    (12) E[x0 | x] = (x + sigma^2 * score(x, alpha, sigma)) / alpha
    From (12):
    (13) score(x, alpha, sigma) = (alpha * E[x0 | x] - x) / sigma^2
    From (11) and (4):
    (14) E[eps | x] = -sigma * score(x, alpha, sigma)
    From (14):
    (15) score(x, alpha, sigma) = -E[eps | x] / sigma
    From (14) and (7):
    (16) E[v | x] = alpha_r * x - sigma^2 * diff_r * score(x, alpha, sigma)
    From (16):
    (17) score(x, alpha, sigma) = (alpha_r * x - E[v | x]) / (sigma^2 * diff_r)
    """
    alpha = pad_shape_back(alpha, x.shape)
    alpha_prime = pad_shape_back(alpha_prime, x.shape)
    sigma = pad_shape_back(sigma, x.shape)
    sigma_prime = pad_shape_back(sigma_prime, x.shape)
    alpha_ratio = alpha_prime / alpha
    sigma_ratio = sigma_prime / sigma
    ratio_diff = sigma_ratio - alpha_ratio
    converted_fx = fx

    if in_type == VectorFieldType.SCORE:
        if out_type == VectorFieldType.X0:
            converted_fx = (x + sigma**2 * fx) / alpha  # From equation (12)
        elif out_type == VectorFieldType.EPS:
            converted_fx = -sigma * fx  # From equation (14)
        elif out_type == VectorFieldType.V:
            converted_fx = (
                alpha_ratio * x - sigma**2 * ratio_diff * fx
            )  # From equation (16)

    elif in_type == VectorFieldType.X0:
        if out_type == VectorFieldType.SCORE:
            converted_fx = (alpha * fx - x) / sigma**2  # From equation (13)
        elif out_type == VectorFieldType.EPS:
            converted_fx = (x - alpha * fx) / sigma  # From equation (4)
        elif out_type == VectorFieldType.V:
            converted_fx = (
                sigma_ratio * x - alpha * ratio_diff * fx
            )  # From equation (9)

    elif in_type == VectorFieldType.EPS:
        if out_type == VectorFieldType.SCORE:
            converted_fx = -fx / sigma  # From equation (15)
        elif out_type == VectorFieldType.X0:
            converted_fx = (x - sigma * fx) / alpha  # From equation (3)
        elif out_type == VectorFieldType.V:
            converted_fx = (
                alpha_ratio * x + sigma * ratio_diff * fx
            )  # From equation (7)

    elif in_type == VectorFieldType.V:
        if out_type == VectorFieldType.SCORE:
            converted_fx = (alpha_ratio * x - fx) / (
                sigma**2 * ratio_diff
            )  # From equation (17)
        elif out_type == VectorFieldType.X0:
            converted_fx = (sigma_ratio * x - fx) / (
                alpha * ratio_diff
            )  # From equation (10)
        elif out_type == VectorFieldType.EPS:
            converted_fx = (fx - alpha_ratio * x) / (
                sigma * ratio_diff
            )  # From equation (8)

    return converted_fx