Coverage for src/diffusionlab/distributions/gmm.py: 100%

212 statements  

coverage.py v7.6.12, created at 2025-03-14 21:37 -0700

from typing import Any, Dict, Tuple

import torch

from diffusionlab.diffusions import DiffusionProcess
from diffusionlab.distributions.base import Distribution
from diffusionlab.utils import logdet_pd, sqrt_psd


class GMMDistribution(Distribution):
    """
    A Gaussian Mixture Model (GMM) with K components.
    Formally, the distribution is defined as:

    mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, Sigma_i)(B)

    where mu_i is the mean of the ith component, Sigma_i is the covariance matrix of the ith component,
    and pi_i is the prior probability of the ith component.

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - covs: A tensor of shape (K, D, D) containing the covariance matrices of the components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "covs" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        covs = possibly_batched_dist_params["covs"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(covs.shape) == 3
            assert len(priors.shape) == 1
            means = means[None, :, :]
            covs = covs[None, :, :, :]
            priors = priors[None, :]

        assert len(means.shape) == 3
        assert len(covs.shape) == 4
        assert len(priors.shape) == 2

        N, K, D = means.shape
        assert (
            len(covs.shape) == 4
            and covs.shape[0] == N
            and covs.shape[1] == K
            and covs.shape[2] == D
            and covs.shape[3] == D
        )
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == covs.device == priors.device

        assert torch.all(priors >= 0)
        sum_priors = torch.sum(priors, dim=-1)
        assert torch.allclose(sum_priors, torch.ones_like(sum_priors))

        evals = torch.linalg.eigvalsh(covs)
        assert torch.all(
            evals >= -D * torch.finfo(evals.dtype).eps
        )  # Allow for numerical errors

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for a GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - covs: A tensor of shape (N, K, D, D) containing the covariance matrices of the components.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        means = batched_dist_params["means"]  # (N, K, D)
        covs = batched_dist_params["covs"]  # (N, K, D, D)
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D = means.shape

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )

        covs_t = (alpha[:, None, None, None] ** 2) * covs + (
            sigma[:, None, None, None] ** 2
        ) * torch.eye(D, device=x_t.device)[None, None, :, :]  # (N, K, D, D)
        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        covs_t_inv_centered_x = torch.linalg.lstsq(
            covs_t,  # (N, K, D, D)
            centered_x[..., None],  # (N, K, D, 1)
        ).solution[..., 0]  # (N, K, D, 1) -> (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * covs_t_inv_centered_x, dim=-1
        )  # (N, K)
        logdets_covs_t = logdet_pd(covs_t)  # (N, K)
        w = (
            torch.log(priors) - 1 / 2 * logdets_covs_t - 1 / 2 * mahalanobis_dists
        )  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * covs_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - covs: A tensor of shape (K, D, D) containing the covariance matrices of the components.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        covs = dist_params["covs"]  # (K, D, D)
        priors = dist_params["priors"]  # (K, )

        K, D = means.shape

        device = priors.device
        y = torch.multinomial(priors, N, replacement=True)  # (N, )
        X = torch.empty((N, D), device=device)
        for k in range(K):
            idx = y == k
            X[idx] = (
                torch.randn((X[idx].shape[0], D), device=device) @ sqrt_psd(covs[k])
                + means[k][None, :]
            )
        return X.to("cpu"), y.to("cpu")

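# Illustrative sketches below (not part of the covered API above): `_StubProcess`
# is a hypothetical stand-in that only provides the two callables (alpha, sigma)
# that the x0 methods actually use; the real DiffusionProcess constructor may
# differ, so treat these as sketches under that assumption.
class _StubProcess:
    def alpha(self, t: torch.Tensor) -> torch.Tensor:
        return torch.sqrt(1 - t)

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        return torch.sqrt(t)


# For a single component (K = 1), GMMDistribution.x0 should reduce to the
# closed-form Gaussian posterior mean
#     E[x_0 | x_t] = mu + alpha * Sigma @ (alpha^2 * Sigma + sigma^2 * I)^(-1) @ (x_t - alpha * mu).
def _check_gmm_x0_single_component() -> bool:
    torch.manual_seed(0)
    N, D = 4, 3
    mean = torch.randn(1, D)  # (K=1, D)
    factor = torch.randn(D, D)
    cov = (factor @ factor.T + torch.eye(D))[None, :, :]  # (1, D, D), positive definite
    prior = torch.ones(1)  # (1, )

    x_t = torch.randn(N, D)
    t = torch.full((N,), 0.3)
    process = _StubProcess()

    batched = {
        "means": mean[None].expand(N, -1, -1),  # (N, 1, D)
        "covs": cov[None].expand(N, -1, -1, -1),  # (N, 1, D, D)
        "priors": prior[None].expand(N, -1),  # (N, 1)
    }
    x0_gmm = GMMDistribution.x0(x_t, t, process, batched, {})

    # Closed-form single-Gaussian posterior mean for comparison.
    alpha, sigma = process.alpha(t), process.sigma(t)
    cov_t = alpha[:, None, None] ** 2 * cov[0] + sigma[:, None, None] ** 2 * torch.eye(D)
    residual = x_t - alpha[:, None] * mean[0]  # (N, D)
    x0_ref = mean[0] + (
        alpha[:, None, None] * cov[0] @ torch.linalg.solve(cov_t, residual[:, :, None])
    )[:, :, 0]
    return torch.allclose(x0_gmm, x0_ref, atol=1e-4)
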

class IsoGMMDistribution(Distribution):
    """
    An isotropic (i.e., spherical covariances) Gaussian Mixture Model (GMM) with K components.
    Formally, the distribution is defined as:

    mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, tau_i^2 * I_D)(B)

    where mu_i is the mean of the ith component, tau_i is the standard deviation of the ith component,
    and pi_i is the prior probability of the ith component.

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - vars: A tensor of shape (K, ) containing the variances of the components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "vars" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        vars_ = possibly_batched_dist_params["vars"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(vars_.shape) == 1
            assert len(priors.shape) == 1
            means = means[None, :, :]
            vars_ = vars_[None, :]
            priors = priors[None, :]

        assert len(means.shape) == 3
        N, K, D = means.shape
        assert len(vars_.shape) == 2 and vars_.shape[0] == N and vars_.shape[1] == K
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == vars_.device == priors.device

        assert torch.all(priors >= 0)
        priors_sum = torch.sum(priors, dim=-1)
        assert torch.allclose(priors_sum, torch.ones_like(priors_sum))
        assert torch.all(
            vars_ >= -D * torch.finfo(vars_.dtype).eps
        )  # Allow for numerical errors

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for an isotropic GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - vars: A tensor of shape (N, K) containing the variances of the components.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        means = batched_dist_params["means"]  # (N, K, D)
        vars_ = batched_dist_params["vars"]  # (N, K)
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D = means.shape

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )

        vars_t = (alpha[:, None] ** 2) * vars_ + (sigma[:, None] ** 2)  # (N, K)
        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        vars_t_inv_centered_x = centered_x / vars_t[:, :, None]  # (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * vars_t_inv_centered_x, dim=-1
        )  # (N, K)
        w = (
            torch.log(priors) - D / 2 * torch.log(vars_t) - 1 / 2 * mahalanobis_dists
        )  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * vars_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the isotropic GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - vars: A tensor of shape (K, ) containing the variances of the components.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        vars_ = dist_params["vars"]  # (K, )
        priors = dist_params["priors"]  # (K, )

        K, D = means.shape
        covs = (
            torch.eye(D, device=vars_.device)[None, :, :].expand(K, -1, -1)
            * vars_[:, None, None]
        )
        return GMMDistribution.sample(
            N, {"means": means, "covs": covs, "priors": priors}, dict()
        )

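# Illustrative consistency sketch (same hypothetical `_StubProcess` stand-in as
# above): with covs_k = vars_k * I, IsoGMMDistribution.x0 should agree with
# GMMDistribution.x0, since the two only differ in how the per-component
# covariance (alpha^2 * vars_k + sigma^2) * I is inverted.
def _check_iso_gmm_matches_full_gmm() -> bool:
    torch.manual_seed(0)
    N, K, D = 5, 3, 2
    means = torch.randn(N, K, D)
    vars_ = torch.rand(N, K) + 0.1
    priors = torch.softmax(torch.randn(N, K), dim=-1)
    covs = vars_[:, :, None, None] * torch.eye(D)  # (N, K, D, D)

    x_t = torch.randn(N, D)
    t = torch.full((N,), 0.5)
    process = _StubProcess()

    x0_iso = IsoGMMDistribution.x0(
        x_t, t, process, {"means": means, "vars": vars_, "priors": priors}, {}
    )
    x0_full = GMMDistribution.x0(
        x_t, t, process, {"means": means, "covs": covs, "priors": priors}, {}
    )
    return torch.allclose(x0_iso, x0_full, atol=1e-4)
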

class IsoHomoGMMDistribution(Distribution):
    """
    An isotropic homoscedastic (i.e., equal spherical covariances) Gaussian Mixture Model (GMM) with K components.
    Formally, the distribution is defined as:

    mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, tau^2 * I_D)(B)

    where mu_i is the mean of the ith component, tau is the standard deviation shared by all components,
    and pi_i is the prior probability of the ith component.

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - var: A tensor of shape () containing the variance shared by all components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "var" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        var = possibly_batched_dist_params["var"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(var.shape) == 0
            assert len(priors.shape) == 1
            means = means[None, :, :]
            var = var[None]
            priors = priors[None, :]

        assert len(means.shape) == 3
        N, K, D = means.shape
        assert len(var.shape) == 1 and var.shape[0] == N
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == var.device == priors.device

        assert torch.all(priors >= 0)
        priors_sum = torch.sum(priors, dim=-1)
        assert torch.allclose(priors_sum, torch.ones_like(priors_sum))
        assert torch.all(
            var >= -D * torch.finfo(var.dtype).eps
        )  # Allow for numerical errors

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for an isotropic homoscedastic GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - var: A tensor of shape (N, ) containing the shared variance of all components.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        means = batched_dist_params["means"]  # (N, K, D)
        var = batched_dist_params["var"]  # (N, )
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D = means.shape

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )

        var_t = (alpha**2) * var + (sigma**2)  # (N, )
        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        vars_t_inv_centered_x = centered_x / var_t[:, None, None]  # (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * vars_t_inv_centered_x, dim=-1
        )  # (N, K)
        # The -D/2 * log(var_t) term is shared by all components, so it cancels in the softmax below.
        w = torch.log(priors) - 1 / 2 * mahalanobis_dists  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * vars_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the isotropic homoscedastic GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - var: A tensor of shape () containing the shared variance of all components.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        var = dist_params["var"]  # ()
        priors = dist_params["priors"]  # (K, )

        K, D = means.shape
        covs = torch.eye(D, device=var.device)[None, :, :].expand(K, -1, -1) * var
        return GMMDistribution.sample(
            N, {"means": means, "covs": covs, "priors": priors}, dict()
        )

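# Illustrative consistency sketch (same hypothetical `_StubProcess` stand-in as
# above): when every component shares the variance `var`, IsoHomoGMMDistribution.x0
# should agree with IsoGMMDistribution.x0; dropping the shared D/2 * log(var_t)
# term only shifts every logit by the same constant, which the softmax ignores.
def _check_iso_homo_matches_iso() -> bool:
    torch.manual_seed(0)
    N, K, D = 5, 3, 2
    means = torch.randn(N, K, D)
    var = torch.rand(N) + 0.1  # (N, ), shared by all components
    priors = torch.softmax(torch.randn(N, K), dim=-1)

    x_t = torch.randn(N, D)
    t = torch.full((N,), 0.4)
    process = _StubProcess()

    x0_homo = IsoHomoGMMDistribution.x0(
        x_t, t, process, {"means": means, "var": var, "priors": priors}, {}
    )
    x0_iso = IsoGMMDistribution.x0(
        x_t,
        t,
        process,
        {"means": means, "vars": var[:, None].expand(-1, K), "priors": priors},
        {},
    )
    return torch.allclose(x0_homo, x0_iso, atol=1e-5)
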

class LowRankGMMDistribution(Distribution):
    """
    A Gaussian Mixture Model (GMM) with K low-rank components.
    Formally, the distribution is defined as:

    mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, Sigma_i)(B)

    where mu_i is the mean of the ith component, Sigma_i is the covariance matrix of the ith component,
    and pi_i is the prior probability of the ith component. Notably, Sigma_i is a low-rank matrix of the form

    Sigma_i = A_i @ A_i^T

    where A_i is a tall factor of shape (D, P).

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - covs_factors: A tensor of shape (K, D, P) containing the tall factors of the covariance matrices of the components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None

    Note:
        - The covariance matrices are not explicitly stored, but rather computed as Sigma_i = A_i @ A_i^T.
        - The time and memory complexity of this class is much lower than that of the full GMM class precisely when each covariance factor is tall (P << D).
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "covs_factors" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        covs_factors = possibly_batched_dist_params["covs_factors"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(covs_factors.shape) == 3
            assert len(priors.shape) == 1
            means = means[None, :, :]
            covs_factors = covs_factors[None, :, :, :]
            priors = priors[None, :]

        assert len(means.shape) == 3
        assert len(covs_factors.shape) == 4
        assert len(priors.shape) == 2

        N, K, D, P = covs_factors.shape
        assert means.shape[0] == N and means.shape[1] == K and means.shape[2] == D
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == covs_factors.device == priors.device

        assert torch.all(priors >= 0)
        sum_priors = torch.sum(priors, dim=-1)
        assert torch.allclose(sum_priors, torch.ones_like(sum_priors))

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for a low-rank GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - covs_factors: A tensor of shape (N, K, D, P) containing the tall factors of the covariance matrices.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
            The covariance matrices are implicitly defined as Sigma_i = A_i @ A_i^T, where A_i is the ith factor.
        """
        means = batched_dist_params["means"]  # (N, K, D)
        covs_factors = batched_dist_params["covs_factors"]  # (N, K, D, P)
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D, P = covs_factors.shape
        covs_factors_T = covs_factors.transpose(-1, -2)  # (N, K, P, D)

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )
        alpha_sigma_ratio_sq = (alpha / sigma) ** 2  # (N, )
        sigma_alpha_ratio_sq = 1 / alpha_sigma_ratio_sq  # (N, )

        internal_covs = covs_factors_T @ covs_factors  # (N, K, P, P)
        # Matrix determinant lemma: det(sigma^2 I_D + alpha^2 A A^T) = sigma^(2D) * det(I_P + (alpha/sigma)^2 A^T A).
        logdets_covs_t = 2 * D * torch.log(sigma[:, None]) + logdet_pd(
            torch.eye(P, device=covs_factors.device)[None, None, :, :]  # (1, 1, P, P)
            + alpha_sigma_ratio_sq[:, None, None, None] * internal_covs  # (N, K, P, P)
        )  # (N, K)

        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        # Woodbury identity: (sigma^2 I_D + alpha^2 A A^T)^(-1) x = (1 / sigma^2) * (x - A (A^T A + (sigma/alpha)^2 I_P)^(-1) A^T x).
        covs_t_inv_centered_x = (1 / sigma[:, None, None] ** 2) * (
            centered_x  # (N, K, D)
            - (
                covs_factors  # (N, K, D, P)
                @ torch.linalg.lstsq(  # (N, K, P, 1)
                    internal_covs  # (N, K, P, P)
                    + sigma_alpha_ratio_sq[:, None, None, None]  # (N, 1, 1, 1)
                    * torch.eye(P, device=internal_covs.device)[
                        None, None, :, :
                    ],  # (1, 1, P, P)
                    covs_factors_T @ centered_x[:, :, :, None],  # (N, K, P, 1)
                ).solution  # (N, K, P, 1)
            )[:, :, :, 0]  # (N, K, D, 1) -> (N, K, D)
        )  # (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * covs_t_inv_centered_x, dim=-1
        )  # (N, K)
        w = (
            torch.log(priors) - 1 / 2 * logdets_covs_t - 1 / 2 * mahalanobis_dists
        )  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * covs_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the low-rank GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - covs_factors: A tensor of shape (K, D, P) containing the tall factors of the covariance matrices.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        covs_factors = dist_params["covs_factors"]  # (K, D, P)
        priors = dist_params["priors"]  # (K, )

        K, D, P = covs_factors.shape

        device = priors.device
        y = torch.multinomial(priors, N, replacement=True)  # (N, )
        X = torch.empty((N, D), device=device)
        for k in range(K):
            idx = y == k
            X[idx] = (
                torch.randn((X[idx].shape[0], P), device=device) @ covs_factors[k].T
                + means[k][None, :]
            )
        return X.to("cpu"), y.to("cpu")
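

# Illustrative consistency sketch (same hypothetical `_StubProcess` stand-in as
# above): with Sigma_k = A_k @ A_k^T, LowRankGMMDistribution.x0 should agree with
# GMMDistribution.x0 applied to the dense covariances; the low-rank path just
# applies the Woodbury identity and the matrix determinant lemma in the
# P-dimensional factor space.
def _check_low_rank_matches_full_gmm() -> bool:
    torch.manual_seed(0)
    N, K, D, P = 4, 3, 5, 2
    means = torch.randn(N, K, D)
    covs_factors = torch.randn(N, K, D, P)
    covs = covs_factors @ covs_factors.transpose(-1, -2)  # (N, K, D, D)
    priors = torch.softmax(torch.randn(N, K), dim=-1)

    x_t = torch.randn(N, D)
    t = torch.full((N,), 0.6)
    process = _StubProcess()

    x0_low_rank = LowRankGMMDistribution.x0(
        x_t,
        t,
        process,
        {"means": means, "covs_factors": covs_factors, "priors": priors},
        {},
    )
    x0_full = GMMDistribution.x0(
        x_t, t, process, {"means": means, "covs": covs, "priors": priors}, {}
    )
    return torch.allclose(x0_low_rank, x0_full, atol=1e-4)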