Coverage for src/diffusionlab/diffusions.py: 100%

27 statements  

coverage.py v7.6.12, created at 2025-03-14 21:37 -0700

from typing import Any, Callable

import torch

from diffusionlab.utils import pad_shape_back, scalar_derivative


class DiffusionProcess:
    """
    Base class for implementing various diffusion processes.

    A diffusion process defines how data evolves over time when noise is added according to
    specific dynamics. This class provides a framework for implementing different types of
    diffusion processes used in generative modeling.

    The diffusion is parameterized by two functions:
    - alpha(t): Controls how much of the original signal is preserved at time t
    - sigma(t): Controls how much noise is added at time t

    The forward process is defined as: x_t = alpha(t) * x_0 + sigma(t) * eps, where:
    - x_0 is the original data
    - x_t is the noised data at time t
    - eps is random noise sampled from a standard Gaussian distribution
    - t is the diffusion time parameter, typically in range [0, 1]

    Attributes:
        alpha (Callable): Function that determines signal preservation at time t, differentiable,
            maps any tensor to tensor of same shape
        sigma (Callable): Function that determines noise level at time t, differentiable,
            maps any tensor to tensor of same shape
        alpha_prime (Callable): Derivative of alpha, maps any tensor to tensor of same shape
        sigma_prime (Callable): Derivative of sigma, maps any tensor to tensor of same shape
    """

    def __init__(self, **dynamics_hparams: Any) -> None:
        """
        Initialize a diffusion process with specific dynamics parameters.

        Args:
            **dynamics_hparams: Keyword arguments containing the dynamics parameters.
                Must include:
                - alpha: Callable that maps time t to signal coefficient
                - sigma: Callable that maps time t to noise coefficient

        Raises:
            AssertionError: If alpha or sigma is not provided in dynamics_hparams
        """
        super().__init__()
        assert "alpha" in dynamics_hparams
        assert "sigma" in dynamics_hparams
        alpha: Callable[[torch.Tensor], torch.Tensor] = dynamics_hparams["alpha"]
        sigma: Callable[[torch.Tensor], torch.Tensor] = dynamics_hparams["sigma"]
        self.alpha: Callable[[torch.Tensor], torch.Tensor] = alpha
        self.sigma: Callable[[torch.Tensor], torch.Tensor] = sigma
        self.alpha_prime: Callable[[torch.Tensor], torch.Tensor] = scalar_derivative(
            alpha
        )
        self.sigma_prime: Callable[[torch.Tensor], torch.Tensor] = scalar_derivative(
            sigma
        )

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, eps: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the dynamics model.

        This method implements the forward diffusion process, which gradually adds noise
        to the input data according to the specified dynamics (alpha and sigma functions).

        Args:
            x (torch.Tensor): The input data tensor of shape (N, *D), where N is the batch size
                and D represents the data dimensions.
            t (torch.Tensor): The time parameter tensor of shape (N,) or broadcastable to x's shape,
                with values typically in the range [0, 1].
            eps (torch.Tensor): The Gaussian noise tensor of shape (N, *D), where N is the batch size
                and D represents the data dimensions.

        Returns:
            torch.Tensor: The noised data at time t, computed as alpha(t) * x + sigma(t) * eps,
                of shape (N, *D) matching the input shape.
        """
        alpha = pad_shape_back(self.alpha(t), x.shape)
        sigma = pad_shape_back(self.sigma(t), x.shape)
        return alpha * x + sigma * eps
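
# --- Editor's note (not part of the original module): a minimal usage sketch of the
# forward process above, x_t = alpha(t) * x_0 + sigma(t) * eps. The alpha and sigma
# choices here are illustrative assumptions, not defaults of the class.
#
#     process = DiffusionProcess(alpha=lambda t: 1 - t, sigma=lambda t: t)
#     x0 = torch.randn(4, 3, 8, 8)            # clean data, shape (N, *D)
#     t = torch.full((4,), 0.5)               # one time value per batch element
#     eps = torch.randn_like(x0)              # standard Gaussian noise
#     xt = process.forward(x0, t, eps)        # 0.5 * x0 + 0.5 * eps, shape (N, *D)
#     alpha_prime_t = process.alpha_prime(t)  # derivative of alpha at t, shape (N,)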


class VarianceExplodingProcess(DiffusionProcess):
    """
    Implements a Variance Exploding (VE) diffusion process.

    In a VE process, the signal component remains constant (alpha(t) = 1) while the
    noise component increases according to the provided sigma function. This leads to
    the variance of the process "exploding" as t increases.

    The forward process is defined as: x_t = x_0 + sigma(t) * eps

    This is used in models like NCSN (Noise Conditional Score Network) and Score SDE.
    """

    def __init__(self, sigma: Callable[[torch.Tensor], torch.Tensor]) -> None:
        """
        Initialize a Variance Exploding diffusion process.

        Args:
            sigma (Callable): Function that determines how noise scales with time t.
                Should map a tensor of time values of shape (N,) to noise
                coefficients of the same shape.
        """
        super().__init__(alpha=lambda t: torch.ones_like(t), sigma=sigma)
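
# --- Editor's note: because alpha(t) = 1 for all t, alpha_prime(t) = 0, and the marginal
# variance of x_t = x_0 + sigma(t) * eps is the data variance plus sigma(t)², which keeps
# growing as sigma(t) increases; this is what "variance exploding" refers to.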


class OrnsteinUhlenbeckProcess(DiffusionProcess):
    """
    Implements an Ornstein-Uhlenbeck diffusion process.

    The Ornstein-Uhlenbeck process is a mean-reverting stochastic process that describes
    the velocity of a particle undergoing Brownian motion while being subject to friction.

    In this implementation:
    - alpha(t) = sqrt(1 - t²)
    - sigma(t) = t

    This process has properties that make it useful for certain generative modeling tasks,
    particularly when a smooth transition between clean and noisy states is desired.
    """

    def __init__(self) -> None:
        """
        Initialize an Ornstein-Uhlenbeck diffusion process with predefined dynamics.

        The process uses:
        - alpha(t) = sqrt(1 - t²)
        - sigma(t) = t

        Both functions map tensors of shape (N,) to tensors of the same shape.
        """
        super().__init__(alpha=lambda t: torch.sqrt(1 - t**2), sigma=lambda t: t)
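
# --- Editor's note: with alpha(t) = sqrt(1 - t²) and sigma(t) = t we have
# alpha(t)² + sigma(t)² = 1, so for unit-variance data the marginal variance of
# x_t = alpha(t) * x_0 + sigma(t) * eps stays constant: x_t moves from clean data
# at t = 0 to pure Gaussian noise at t = 1.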


class FlowMatchingProcess(DiffusionProcess):
    """
    Implements a Flow Matching diffusion process.

    Flow Matching is a technique used in generative modeling where the goal is to learn
    a continuous transformation (flow) between a simple distribution and a complex data
    distribution.

    In this implementation:
    - alpha(t) = 1 - t
    - sigma(t) = t

    This creates a linear interpolation between the original data (at t=0) and
    the noise (at t=1), which is useful for training flow-based generative models.
    """

    def __init__(self) -> None:
        """
        Initialize a Flow Matching diffusion process with predefined dynamics.

        The process uses:
        - alpha(t) = 1 - t
        - sigma(t) = t

        Both functions map tensors of shape (N,) to tensors of the same shape.
        This creates a linear interpolation between the original data and noise.
        """
        super().__init__(alpha=lambda t: 1 - t, sigma=lambda t: t)
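

# --- Editor's note: the demo below is an illustrative addition and is not part of the
# original module (the coverage figures above refer only to the code listed before this
# point). It is a minimal sketch exercising each process; the sigma passed to the VE
# process is an assumed choice, not one made by this library.
if __name__ == "__main__":
    x0 = torch.randn(2, 3)          # toy batch of clean data, shape (N, *D)
    eps = torch.randn_like(x0)      # standard Gaussian noise
    t = torch.full((2,), 0.25)      # one time value per batch element

    ve = VarianceExplodingProcess(sigma=lambda t: t)  # illustrative sigma
    ou = OrnsteinUhlenbeckProcess()
    fm = FlowMatchingProcess()

    # VE keeps the signal: x_t = x_0 + sigma(t) * eps
    print(torch.allclose(ve.forward(x0, t, eps), x0 + 0.25 * eps))

    # OU: x_t = sqrt(1 - t²) * x_0 + t * eps
    expected_ou = torch.sqrt(torch.tensor(1 - 0.25**2)) * x0 + 0.25 * eps
    print(torch.allclose(ou.forward(x0, t, eps), expected_ou))

    # Flow Matching: linear interpolation x_t = (1 - t) * x_0 + t * eps
    print(torch.allclose(fm.forward(x0, t, eps), 0.75 * x0 + 0.25 * eps))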