Coverage for src/diffusionlab/distributions/gmm.py: 100%
212 statements
from typing import Any, Dict, Tuple

import torch

from diffusionlab.diffusions import DiffusionProcess
from diffusionlab.distributions.base import Distribution
from diffusionlab.utils import logdet_pd, sqrt_psd


class GMMDistribution(Distribution):
    """
    A Gaussian Mixture Model (GMM) with K components.
    Formally, the distribution is defined as:

        mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, Sigma_i)(B)

    where mu_i is the mean of the ith component, Sigma_i is the covariance matrix of the ith component,
    and pi_i is the prior probability of the ith component.

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - covs: A tensor of shape (K, D, D) containing the covariance matrices of the components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "covs" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        covs = possibly_batched_dist_params["covs"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(covs.shape) == 3
            assert len(priors.shape) == 1
            means = means[None, :, :]
            covs = covs[None, :, :, :]
            priors = priors[None, :]

        assert len(means.shape) == 3
        assert len(covs.shape) == 4
        assert len(priors.shape) == 2

        N, K, D = means.shape
        assert (
            len(covs.shape) == 4
            and covs.shape[0] == N
            and covs.shape[1] == K
            and covs.shape[2] == D
            and covs.shape[3] == D
        )
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == covs.device == priors.device

        assert torch.all(priors >= 0)
        sum_priors = torch.sum(priors, dim=-1)
        assert torch.allclose(sum_priors, torch.ones_like(sum_priors))

        evals = torch.linalg.eigvalsh(covs)
        assert torch.all(
            evals >= -D * torch.finfo(evals.dtype).eps
        )  # Allow for numerical errors

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for a GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - covs: A tensor of shape (N, K, D, D) containing the covariance matrices of the components.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        # For x_t = alpha(t) x_0 + sigma(t) eps, each component k induces a Gaussian marginal
        # N(alpha mu_k, covs_t_k) with covs_t_k = alpha^2 Sigma_k + sigma^2 I. Tweedie's formula
        # gives E[x_0 | x_t] = (x_t + sigma^2 * score(x_t)) / alpha, where the score is a
        # softmax-weighted sum of -covs_t_k^{-1} (x_t - alpha mu_k) over components.
        means = batched_dist_params["means"]  # (N, K, D)
        covs = batched_dist_params["covs"]  # (N, K, D, D)
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D = means.shape

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )

        covs_t = (alpha[:, None, None, None] ** 2) * covs + (
            sigma[:, None, None, None] ** 2
        ) * torch.eye(D, device=x_t.device)[None, None, :, :]  # (N, K, D, D)
        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        covs_t_inv_centered_x = torch.linalg.lstsq(
            covs_t,  # (N, K, D, D)
            centered_x[..., None],  # (N, K, D, 1)
        ).solution[..., 0]  # (N, K, D, 1) -> (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * covs_t_inv_centered_x, dim=-1
        )  # (N, K)
        logdets_covs_t = logdet_pd(covs_t)  # (N, K)
        w = (
            torch.log(priors) - 1 / 2 * logdets_covs_t - 1 / 2 * mahalanobis_dists
        )  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * covs_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - covs: A tensor of shape (K, D, D) containing the covariance matrices of the components.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        covs = dist_params["covs"]  # (K, D, D)
        priors = dist_params["priors"]  # (K, )

        K, D = means.shape

        device = priors.device
        y = torch.multinomial(priors, N, replacement=True)  # (N, )
        X = torch.empty((N, D), device=device)
        # If z ~ N(0, I_D) and S = sqrt_psd(cov) is symmetric PSD, then z @ S + mean ~ N(mean, cov).
        for k in range(K):
            idx = y == k
            X[idx] = (
                torch.randn((X[idx].shape[0], D), device=device) @ sqrt_psd(covs[k])
                + means[k][None, :]
            )
        return X.to("cpu"), y.to("cpu")
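

# A minimal usage sketch for GMMDistribution (illustrative only, not part of the library).
# It assumes a stand-in object exposing alpha(t) and sigma(t) callables in place of a real
# DiffusionProcess, with the illustrative schedule alpha(t) = 1 - t, sigma(t) = t; the real
# DiffusionProcess construction may differ.
def _example_gmm_denoiser() -> torch.Tensor:
    from types import SimpleNamespace

    K, D, N = 3, 2, 5
    dist_params = {
        "means": torch.randn(K, D),
        "covs": torch.eye(D).expand(K, D, D).clone(),
        "priors": torch.full((K,), 1.0 / K),
    }
    GMMDistribution.validate_params(dist_params)
    X, y = GMMDistribution.sample(N, dist_params, {})  # X: (N, D), y: (N, )

    # Batch the parameters so the leading dimension matches the batch of noisy inputs.
    batched = {k: v[None].expand(N, *v.shape).clone() for k, v in dist_params.items()}
    process = SimpleNamespace(alpha=lambda t: 1 - t, sigma=lambda t: t)  # hypothetical stand-in
    t = torch.full((N,), 0.1)
    x_t = process.alpha(t)[:, None] * X + process.sigma(t)[:, None] * torch.randn(N, D)
    return GMMDistribution.x0(x_t, t, process, batched, {})  # (N, D)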


class IsoGMMDistribution(Distribution):
    """
    An isotropic (i.e., spherical covariances) Gaussian Mixture Model (GMM) with K components.
    Formally, the distribution is defined as:

        mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, tau_i^2 * I_D)(B)

    where mu_i is the mean of the ith component, tau_i^2 is the (spherical) variance of the ith component,
    and pi_i is the prior probability of the ith component.

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - vars: A tensor of shape (K, ) containing the variances of the components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "vars" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        vars_ = possibly_batched_dist_params["vars"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(vars_.shape) == 1
            assert len(priors.shape) == 1
            means = means[None, :, :]
            vars_ = vars_[None, :]
            priors = priors[None, :]

        assert len(means.shape) == 3
        N, K, D = means.shape
        assert len(vars_.shape) == 2 and vars_.shape[0] == N and vars_.shape[1] == K
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == vars_.device == priors.device

        assert torch.all(priors >= 0)
        priors_sum = torch.sum(priors, dim=-1)
        assert torch.allclose(priors_sum, torch.ones_like(priors_sum))
        assert torch.all(
            vars_ >= -D * torch.finfo(vars_.dtype).eps
        )  # Allow for numerical errors

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for an isotropic GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - vars: A tensor of shape (N, K) containing the variances of the components.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        means = batched_dist_params["means"]  # (N, K, D)
        vars_ = batched_dist_params["vars"]  # (N, K)
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D = means.shape

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )

        vars_t = (alpha[:, None] ** 2) * vars_ + (sigma[:, None] ** 2)  # (N, K)
        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        vars_t_inv_centered_x = centered_x / vars_t[:, :, None]  # (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * vars_t_inv_centered_x, dim=-1
        )  # (N, K)
        w = (
            torch.log(priors) - D / 2 * torch.log(vars_t) - 1 / 2 * mahalanobis_dists
        )  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * vars_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the isotropic GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - vars: A tensor of shape (K, ) containing the variances of the components.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        vars_ = dist_params["vars"]  # (K, )
        priors = dist_params["priors"]  # (K, )

        K, D = means.shape
        covs = (
            torch.eye(D, device=vars_.device)[None, :, :].expand(K, -1, -1)
            * vars_[:, None, None]
        )
        return GMMDistribution.sample(
            N, {"means": means, "covs": covs, "priors": priors}, dict()
        )
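

# A minimal consistency sketch (illustrative only, not part of the library): the isotropic
# denoiser should agree with the full GMM denoiser when each covariance is vars_[k] * I.
# A SimpleNamespace with alpha/sigma callables stands in for a real DiffusionProcess here.
def _example_iso_gmm_matches_full_gmm() -> bool:
    from types import SimpleNamespace

    N, K, D = 4, 3, 2
    means = torch.randn(N, K, D)
    vars_ = torch.rand(N, K) + 0.1
    priors = torch.softmax(torch.randn(N, K), dim=-1)
    covs = vars_[:, :, None, None] * torch.eye(D)[None, None, :, :]  # (N, K, D, D)

    process = SimpleNamespace(alpha=lambda t: 1 - t, sigma=lambda t: t)  # hypothetical stand-in
    t = torch.full((N,), 0.3)
    x_t = torch.randn(N, D)

    x0_iso = IsoGMMDistribution.x0(
        x_t, t, process, {"means": means, "vars": vars_, "priors": priors}, {}
    )
    x0_full = GMMDistribution.x0(
        x_t, t, process, {"means": means, "covs": covs, "priors": priors}, {}
    )
    return torch.allclose(x0_iso, x0_full, atol=1e-5)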


class IsoHomoGMMDistribution(Distribution):
    """
    An isotropic homoscedastic (i.e., equal spherical variances) Gaussian Mixture Model (GMM) with K components.
    Formally, the distribution is defined as:

        mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, tau^2 * I_D)(B)

    where mu_i is the mean of the ith component, tau^2 is the shared (spherical) variance of all components,
    and pi_i is the prior probability of the ith component.

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - var: A tensor of shape () containing the shared variance of all components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "var" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        var = possibly_batched_dist_params["var"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(var.shape) == 0
            assert len(priors.shape) == 1
            means = means[None, :, :]
            var = var[None]
            priors = priors[None, :]

        assert len(means.shape) == 3
        N, K, D = means.shape
        assert len(var.shape) == 1 and var.shape[0] == N
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == var.device == priors.device

        assert torch.all(priors >= 0)
        priors_sum = torch.sum(priors, dim=-1)
        assert torch.allclose(priors_sum, torch.ones_like(priors_sum))
        assert torch.all(
            var >= -D * torch.finfo(var.dtype).eps
        )  # Allow for numerical errors

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for an isotropic homoscedastic GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - var: A tensor of shape (N, ) containing the shared variance of all components.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
        """
        means = batched_dist_params["means"]  # (N, K, D)
        var = batched_dist_params["var"]  # (N, )
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D = means.shape

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )

        var_t = (alpha**2) * var + (sigma**2)  # (N, )
        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        vars_t_inv_centered_x = centered_x / var_t[:, None, None]  # (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * vars_t_inv_centered_x, dim=-1
        )  # (N, K)
        # The log-determinant term is constant across components here, so it cancels in the softmax.
        w = torch.log(priors) - 1 / 2 * mahalanobis_dists  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * vars_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the isotropic homoscedastic GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - var: A tensor of shape () containing the shared variance of all components.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        var = dist_params["var"]  # ()
        priors = dist_params["priors"]  # (K, )

        K, D = means.shape
        covs = torch.eye(D, device=var.device)[None, :, :].expand(K, -1, -1) * var
        return GMMDistribution.sample(
            N, {"means": means, "covs": covs, "priors": priors}, dict()
        )
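

# A minimal sanity-check sketch (illustrative only, not part of the library): with a single
# component and var = 0, x_0 is deterministic, so the denoiser should return the component
# mean regardless of x_t. A SimpleNamespace with alpha/sigma callables stands in for a real
# DiffusionProcess.
def _example_iso_homo_gmm_point_mass() -> bool:
    from types import SimpleNamespace

    N, D = 4, 3
    mean = torch.randn(D)
    batched_params = {
        "means": mean[None, None, :].expand(N, 1, D),  # (N, 1, D)
        "var": torch.zeros(N),  # (N, )
        "priors": torch.ones(N, 1),  # (N, 1)
    }
    process = SimpleNamespace(alpha=lambda t: 1 - t, sigma=lambda t: t)  # hypothetical stand-in
    t = torch.full((N,), 0.5)
    x_t = torch.randn(N, D)

    x0_hat = IsoHomoGMMDistribution.x0(x_t, t, process, batched_params, {})
    # Algebraically, x0_hat = (x_t - sigma^2 * (x_t - alpha * mean) / sigma^2) / alpha = mean.
    return torch.allclose(x0_hat, mean.expand(N, D), atol=1e-5)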


class LowRankGMMDistribution(Distribution):
    """
    A Gaussian Mixture Model (GMM) with K low-rank components.
    Formally, the distribution is defined as:

        mu(B) = sum_(i=1)^(K) pi_i * N(mu_i, Sigma_i)(B)

    where mu_i is the mean of the ith component, Sigma_i is the covariance matrix of the ith component,
    and pi_i is the prior probability of the ith component. Notably, Sigma_i is a low-rank matrix of the form

        Sigma_i = A_i @ A_i^T

    Distribution Parameters:
        - means: A tensor of shape (K, D) containing the means of the components.
        - covs_factors: A tensor of shape (K, D, P) containing the tall factors of the covariance matrices of the components.
        - priors: A tensor of shape (K, ) containing the prior probabilities of the components.

    Distribution Hyperparameters:
        - None

    Note:
        - The covariance matrices are not explicitly stored, but rather computed as Sigma_i = A_i @ A_i^T.
        - The time and memory complexity of this class is much lower than that of the full GMM class when each covariance is genuinely low-rank (P << D).
    """

    @classmethod
    def validate_params(
        cls, possibly_batched_dist_params: Dict[str, torch.Tensor]
    ) -> None:
        assert (
            "means" in possibly_batched_dist_params
            and "covs_factors" in possibly_batched_dist_params
            and "priors" in possibly_batched_dist_params
        )
        means = possibly_batched_dist_params["means"]
        covs_factors = possibly_batched_dist_params["covs_factors"]
        priors = possibly_batched_dist_params["priors"]

        if len(means.shape) == 2:
            assert len(covs_factors.shape) == 3
            assert len(priors.shape) == 1
            means = means[None, :, :]
            covs_factors = covs_factors[None, :, :, :]
            priors = priors[None, :]

        assert len(means.shape) == 3
        assert len(covs_factors.shape) == 4
        assert len(priors.shape) == 2

        N, K, D, P = covs_factors.shape
        assert means.shape[0] == N and means.shape[1] == K and means.shape[2] == D
        assert len(priors.shape) == 2 and priors.shape[0] == N and priors.shape[1] == K
        assert means.device == covs_factors.device == priors.device

        assert torch.all(priors >= 0)
        sum_priors = torch.sum(priors, dim=-1)
        assert torch.allclose(sum_priors, torch.ones_like(sum_priors))

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for a low-rank GMM distribution.

        Arguments:
            x_t: The input tensor, of shape (N, D).
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process whose forward and reverse dynamics determine
                the time-evolution of the vector fields corresponding to the distribution.
            batched_dist_params: A dictionary containing the batched parameters of the distribution.
                - means: A tensor of shape (N, K, D) containing the means of the components.
                - covs_factors: A tensor of shape (N, K, D, P) containing the tall factors of the covariance matrices.
                - priors: A tensor of shape (N, K) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            The prediction of x_0, of shape (N, D).

        Note:
            The batched_dist_params dictionary contains BATCHED tensors, i.e., the first dimension is the batch dimension.
            The covariance matrices are implicitly defined as Sigma_i = A_i @ A_i^T, where A_i is the ith factor.
        """
        means = batched_dist_params["means"]  # (N, K, D)
        covs_factors = batched_dist_params["covs_factors"]  # (N, K, D, P)
        priors = batched_dist_params["priors"]  # (N, K)

        N, K, D, P = covs_factors.shape
        covs_factors_T = covs_factors.transpose(-1, -2)  # (N, K, P, D)

        alpha = diffusion_process.alpha(t)  # (N, )
        sigma = diffusion_process.sigma(t)  # (N, )
        alpha_sigma_ratio_sq = (alpha / sigma) ** 2  # (N, )
        sigma_alpha_ratio_sq = 1 / alpha_sigma_ratio_sq  # (N, )

        # By the matrix determinant lemma, det(alpha^2 A A^T + sigma^2 I_D)
        # = sigma^(2D) * det(I_P + (alpha/sigma)^2 A^T A), so only a P x P log-determinant is needed.
        internal_covs = covs_factors_T @ covs_factors  # (N, K, P, P)
        logdets_covs_t = 2 * D * torch.log(sigma[:, None]) + logdet_pd(
            torch.eye(P, device=covs_factors.device)[None, None, :, :]  # (1, 1, P, P)
            + alpha_sigma_ratio_sq[:, None, None, None] * internal_covs  # (N, K, P, P)
        )  # (N, K)

        # The Woodbury identity reduces inverting the D x D matrix alpha^2 A A^T + sigma^2 I_D
        # to solving a P x P system.
        centered_x = x_t[:, None, :] - alpha[:, None, None] * means  # (N, K, D)
        covs_t_inv_centered_x = (1 / sigma[:, None, None] ** 2) * (
            centered_x  # (N, K, D)
            - (
                covs_factors  # (N, K, D, P)
                @ torch.linalg.lstsq(  # (N, K, P, 1)
                    internal_covs  # (N, K, P, P)
                    + sigma_alpha_ratio_sq[:, None, None, None]  # (N, K, 1, 1)
                    * torch.eye(P, device=internal_covs.device)[
                        None, None, :, :
                    ],  # (1, 1, P, P)
                    covs_factors_T @ centered_x[:, :, :, None],  # (N, K, P, 1)
                ).solution  # (N, K, P, 1)
            )[:, :, :, 0]  # (N, K, D, 1) -> (N, K, D)
        )  # (N, K, D)

        mahalanobis_dists = torch.sum(
            centered_x * covs_t_inv_centered_x, dim=-1
        )  # (N, K)
        w = (
            torch.log(priors) - 1 / 2 * logdets_covs_t - 1 / 2 * mahalanobis_dists
        )  # (N, K)
        softmax_w = torch.softmax(w, dim=-1)  # (N, K)

        weighted_normalized_x = torch.sum(
            softmax_w[:, :, None] * covs_t_inv_centered_x, dim=-2
        )  # (N, D)
        x0_hat = (1 / alpha[:, None]) * (
            x_t - (sigma[:, None] ** 2) * weighted_normalized_x
        )  # (N, D)

        return x0_hat

    @classmethod
    def sample(
        cls,
        N: int,
        dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draws N i.i.d. samples from the low-rank GMM distribution.

        Arguments:
            N: The number of samples to draw.
            dist_params: A dictionary of parameters for the distribution.
                - means: A tensor of shape (K, D) containing the means of the components.
                - covs_factors: A tensor of shape (K, D, P) containing the tall factors of the covariance matrices.
                - priors: A tensor of shape (K, ) containing the prior probabilities of the components.
            dist_hparams: A dictionary of hyperparameters for the distribution.

        Returns:
            A tuple (samples, labels), where samples is a tensor of shape (N, D) and labels is a tensor of shape (N, )
            containing the component indices from which each sample was drawn.
            Note that the samples are always placed on the CPU.
        """
        means = dist_params["means"]  # (K, D)
        covs_factors = dist_params["covs_factors"]  # (K, D, P)
        priors = dist_params["priors"]  # (K, )

        K, D, P = covs_factors.shape

        device = priors.device
        y = torch.multinomial(priors, N, replacement=True)  # (N, )
        X = torch.empty((N, D), device=device)
        # If z ~ N(0, I_P), then z @ A^T + mean ~ N(mean, A A^T).
        for k in range(K):
            idx = y == k
            X[idx] = (
                torch.randn((X[idx].shape[0], P), device=device) @ covs_factors[k].T
                + means[k][None, :]
            )
        return X.to("cpu"), y.to("cpu")
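

# A minimal consistency sketch (illustrative only, not part of the library): the low-rank
# denoiser should agree with the full GMM denoiser when covs = A @ A^T is formed explicitly.
# A SimpleNamespace with alpha/sigma callables stands in for a real DiffusionProcess.
def _example_low_rank_gmm_matches_full_gmm() -> bool:
    from types import SimpleNamespace

    N, K, D, P = 4, 3, 5, 2
    means = torch.randn(N, K, D)
    covs_factors = torch.randn(N, K, D, P)
    priors = torch.softmax(torch.randn(N, K), dim=-1)
    covs = covs_factors @ covs_factors.transpose(-1, -2)  # (N, K, D, D)

    process = SimpleNamespace(alpha=lambda t: 1 - t, sigma=lambda t: t)  # hypothetical stand-in
    t = torch.full((N,), 0.3)
    x_t = torch.randn(N, D)

    x0_low_rank = LowRankGMMDistribution.x0(
        x_t, t, process, {"means": means, "covs_factors": covs_factors, "priors": priors}, {}
    )
    x0_full = GMMDistribution.x0(
        x_t, t, process, {"means": means, "covs": covs, "priors": priors}, {}
    )
    return torch.allclose(x0_low_rank, x0_full, atol=1e-4)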