Coverage for src/diffusionlab/distributions/base.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-21 15:48 -0700

1from typing import Any, Dict, Tuple, Callable 

2 

3import torch 

4 

5from diffusionlab.diffusions import DiffusionProcess 

6from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type 

7 

8 

class Distribution:
    """
    Base class for all distributions.

    This class should be subclassed by other distributions when you want to use ground truth
    scores, denoisers, noise predictors, or velocity estimators.

    Each distribution implementation provides methods to compute various vector fields
    related to the diffusion process, such as denoising (x0), noise prediction (eps),
    velocity estimation (v), and score estimation.

    Every method is a classmethod or staticmethod: the distribution parameters and
    hyperparameters are always passed in as arguments instead of being read from an instance.
    In this base class only ``x0`` and ``sample`` are left unimplemented; ``eps``, ``v`` and
    ``score`` are derived from ``x0``.
    """

    @classmethod
    def validate_hparams(cls, dist_hparams: Dict[str, Any]) -> None:
        """
        Validate the hyperparameters for the distribution.

        The base class defines no hyperparameters, so only an empty dictionary is accepted.

        Arguments:
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            None

        Raises:
            AssertionError: If the hyperparameters are invalid (here: if the dictionary is not empty).
        """
        # Raised explicitly rather than through `assert`, so that the check is not
        # stripped when Python runs with -O.
        if len(dist_hparams) != 0:
            raise AssertionError(
                f"Distribution takes no hyperparameters, but got keys: {list(dist_hparams)}"
            )

    @classmethod
    def get_vector_field_method(
        cls, vector_field_type: VectorFieldType
    ) -> Callable[[torch.Tensor, torch.Tensor, DiffusionProcess, Dict[str, torch.Tensor], Dict[str, Any]], torch.Tensor]:
        """
        Returns the appropriate method to compute the specified vector field type.

        Arguments:
            vector_field_type: The type of vector field to compute.

        Returns:
            A method that computes the specified vector field, with signature:
            (x_t, t, diffusion_process, batched_dist_params, dist_hparams) -> tensor

        Raises:
            ValueError: If the vector field type is not recognized.
        """
        # Compared with `==` (not a dict lookup), so this places no hashing requirement on
        # VectorFieldType.
        methods_by_type = (
            (VectorFieldType.X0, cls.x0),
            (VectorFieldType.EPS, cls.eps),
            (VectorFieldType.V, cls.v),
            (VectorFieldType.SCORE, cls.score),
        )
        for known_type, method in methods_by_type:
            if vector_field_type == known_type:
                return method
        raise ValueError(f"Unrecognized vector field type: {vector_field_type}")

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        """
        Validate the parameters for the distribution.

        The base class defines no parameters, so only an empty dictionary is accepted.

        Arguments:
            possibly_batched_dist_params: A dictionary of parameters for the distribution.
                Each value is a PyTorch tensor, possibly having a batch dimension.

        Returns:
            None

        Raises:
            AssertionError: If the parameters are invalid (here: if the dictionary is not empty).
        """
        # Raised explicitly rather than through `assert`, so that the check is not
        # stripped when Python runs with -O.
        if len(possibly_batched_dist_params) != 0:
            raise AssertionError(
                f"Distribution takes no parameters, but got keys: {list(possibly_batched_dist_params)}"
            )

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] at a given time t and input x_t, under the data model

        x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).

        The base class does not implement this method; subclasses must override it.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, *D).

        Raises:
            NotImplementedError: Always, in the base class.

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        raise NotImplementedError

    @classmethod
    def _convert_x0_prediction(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
        out_type: VectorFieldType,
    ) -> torch.Tensor:
        """
        Computes the denoiser with ``cls.x0`` and converts it to another vector field type.

        Shared implementation behind ``eps``, ``v`` and ``score``.

        Arguments:
            x_t: The input tensor, passed to ``cls.x0`` and to the conversion.
            t: The time tensor, passed to ``cls.x0`` and used to evaluate alpha, sigma and
                their derivatives on the diffusion process.
            diffusion_process: The diffusion process providing alpha, sigma, alpha_prime and sigma_prime.
            batched_dist_params: A dictionary of batched parameters, forwarded to ``cls.x0``.
            dist_hparams: A dictionary of hyperparameters, forwarded to ``cls.x0``.
            out_type: The vector field type to convert the x0 prediction into.

        Returns:
            The result of ``convert_vector_field_type`` for the requested output type.
        """
        # The denoiser is evaluated first; a subclass's x0 override is picked up through cls.
        x0_hat = cls.x0(x_t, t, diffusion_process, batched_dist_params, dist_hparams)
        return convert_vector_field_type(
            x_t,
            x0_hat,
            diffusion_process.alpha(t),
            diffusion_process.sigma(t),
            diffusion_process.alpha_prime(t),
            diffusion_process.sigma_prime(t),
            in_type=VectorFieldType.X0,
            out_type=out_type,
        )

    @classmethod
    def eps(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the noise predictor E[eps | x_t] at a given time t and input x_t, under the data model

        x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).
        Like the denoiser, this method is stateless: everything it needs is passed in as arguments.

        The default implementation converts the output of ``x0`` into a noise prediction.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of eps, of shape (N, *D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        return cls._convert_x0_prediction(
            x_t,
            t,
            diffusion_process,
            batched_dist_params,
            dist_hparams,
            VectorFieldType.EPS,
        )

    @classmethod
    def v(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the velocity estimator E[d/dt x_t | x_t] at a given time t and input x_t, under the data model

        x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).
        Like the denoiser, this method is stateless: everything it needs is passed in as arguments.

        The default implementation converts the output of ``x0`` into a velocity estimate.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of d/dt x_t, of shape (N, *D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        return cls._convert_x0_prediction(
            x_t,
            t,
            diffusion_process,
            batched_dist_params,
            dist_hparams,
            VectorFieldType.V,
        )

    @classmethod
    def score(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the score estimator grad_x log p(x_t, t) at a given time t and input x_t, under the data model

        x_t = alpha(t) * x_0 + sigma(t) * eps

        where x_0 is drawn from the data distribution, and eps is drawn independently from N(0, I).
        Like the denoiser, this method is stateless: everything it needs is passed in as arguments.

        The default implementation converts the output of ``x0`` into a score estimate.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Each parameter is of shape (N, *P) where P is the shape of the parameter.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of grad_x log p(x_t, t), of shape (N, *D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        return cls._convert_x0_prediction(
            x_t,
            t,
            diffusion_process,
            batched_dist_params,
            dist_hparams,
            VectorFieldType.SCORE,
        )

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, Any]:
        """
        Draws N i.i.d. samples from the data distribution.

        The base class does not implement this method; subclasses must override it.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, metadata), where samples is a tensor of shape (N, *D) and metadata is any additional information.
            For example, if the distribution has labels, the metadata is a tensor of shape (N, ) containing the labels.
            Note that the samples are always placed on the CPU.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    @staticmethod
    def batch_dist_params(
        N: int, dist_params: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Add a batch dimension to the distribution parameters.

        Arguments:
            N: The number of samples in the batch.
            dist_params: A dictionary of parameters for the distribution.

        Returns:
            A dictionary of parameters for the distribution, with a batch dimension added.
            Each value has shape (N, *P), where P is the shape of the corresponding input tensor.
        """
        # unsqueeze + expand prepends a size-N leading dimension to every tensor.
        return {
            name: param.unsqueeze(0).expand(N, *param.shape)
            for name, param in dist_params.items()
        }