Coverage for src/diffusionlab/distributions/empirical.py: 100%
32 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 21:37 -0700
1from typing import Any, Dict
3import torch
4from torch.utils.data import DataLoader
6from diffusionlab.diffusions import DiffusionProcess
7from diffusionlab.distributions.base import Distribution
8from diffusionlab.utils import pad_shape_back
class EmpiricalDistribution(Distribution):
    """
    An empirical distribution, i.e., the uniform distribution over a dataset.
    Formally, the distribution is defined as:

        mu(B) = (1/N) * sum_(i=1)^(N) delta(x_i in B)

    where x_i is the ith data point in the dataset, N is the number of data points,
    and delta(x_i in B) is 1 if x_i lies in the set B and 0 otherwise.

    Distribution Parameters:
        - None

    Distribution Hyperparameters:
        - labeled_data: A DataLoader over the data that defines the empirical
          distribution, where each data sample is a (data, label) tuple. Both data
          and label are PyTorch tensors. Only the data component is used.

    Note:
        - This class has no sample() method as it's difficult to sample randomly
          from a DataLoader. In practice, you can sample directly from the
          DataLoader and apply filtering there.
    """

    @classmethod
    def validate_hparams(cls, dist_hparams: Dict[str, Any]) -> None:
        """
        Validate the hyperparameters for the empirical distribution.

        Arguments:
            dist_hparams: A dictionary of hyperparameters for the distribution.
                Must contain 'labeled_data' which is a DataLoader with at least
                one batch.

        Returns:
            None

        Throws:
            AssertionError: If the parameters are invalid.
        """
        assert "labeled_data" in dist_hparams, "dist_hparams must contain 'labeled_data'"
        labeled_data = dist_hparams["labeled_data"]
        assert isinstance(labeled_data, DataLoader), (
            f"'labeled_data' must be a DataLoader, got {type(labeled_data).__name__}"
        )
        assert len(labeled_data) > 0, "'labeled_data' must yield at least one batch"

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for an empirical distribution.

        The denoiser is a weighted average of the dataset samples x_i with weights

            w_i proportional to exp(-||x_t - alpha(t) * x_i||^2 / (2 * sigma(t)^2)),

        i.e. a softmax over the dataset of the likelihood of x_t given each sample.
        The softmax is accumulated across DataLoader batches with a running maximum
        (streaming log-sum-exp). The result therefore stays accurate even when every
        individual exp(...) would underflow to zero. That happens whenever the
        squared distances are large compared with 2 * sigma^2, which is common for
        high-dimensional data or small noise levels.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Not used for empirical distribution.
            dist_hparams: A dictionary of hyperparameters for the distribution.
                Must contain 'labeled_data' which is a DataLoader.

        Returns:
            The prediction of x_0, of shape (N, *D), in the dtype of x_t.
            If every weight is zero (e.g. the DataLoader yields no data), the
            prediction is zero.
        """
        data = dist_hparams["labeled_data"]
        num_samples = x_t.shape[0]
        dtype, device = x_t.dtype, x_t.device

        x_flattened = torch.flatten(x_t, start_dim=1, end_dim=-1)  # (N, D)

        # Work entirely in x_t's dtype so cdist/matmul see matching dtypes.
        alpha = torch.as_tensor(diffusion_process.alpha(t), dtype=dtype, device=device)  # (N, )
        sigma = torch.as_tensor(diffusion_process.sigma(t), dtype=dtype, device=device)  # (N, )
        # Keep sigma^2 strictly positive: with sigma = 0 an exact match would
        # otherwise produce 0 / 0 = nan and poison the whole average.
        sigma_sq = (sigma**2).clamp(min=torch.finfo(dtype).tiny)  # (N, )

        # Streaming-softmax state: running max logit, and the denominator and
        # numerator both expressed relative to exp(running_max).
        running_max = torch.full(
            (num_samples,), float("-inf"), dtype=dtype, device=device
        )  # (N, )
        denom = torch.zeros(num_samples, dtype=dtype, device=device)  # (N, )
        numer = torch.zeros_like(x_flattened)  # (N, D)

        for X_batch, _ in data:
            if X_batch.shape[0] == 0:
                continue  # amax over an empty dimension would raise
            X_flat = torch.flatten(
                X_batch.to(device=device, dtype=dtype, non_blocking=True),
                start_dim=1,
                end_dim=-1,
            )  # (B, D)
            X_flat_batched = X_flat[None, ...]  # (1, B, D)
            alpha_X = pad_shape_back(alpha, X_flat_batched.shape) * X_flat_batched  # (N, B, D)

            # Exact (non-matmul) distances: the ||a||^2 + ||b||^2 - 2<a, b> expansion
            # loses precision to cancellation, and that error is amplified by
            # 1 / (2 sigma^2) at small noise levels.
            sq_dists = (
                torch.cdist(
                    x_flattened[:, None, :],
                    alpha_X,
                    compute_mode="donot_use_mm_for_euclid_dist",
                )[:, 0, :]
                ** 2
            )  # (N, B)
            logits = -sq_dists / (2 * pad_shape_back(sigma_sq, sq_dists.shape))  # (N, B)

            new_max = torch.maximum(running_max, logits.amax(dim=1))  # (N, )
            # Where every logit seen so far is -inf, use 0 as the reference to
            # avoid (-inf) - (-inf) = nan; all weights are then exactly 0.
            ref = torch.where(torch.isfinite(new_max), new_max, torch.zeros_like(new_max))  # (N, )
            rescale = torch.exp(running_max - ref)  # (N, ), re-bases the old sums
            weights = torch.exp(logits - pad_shape_back(ref, logits.shape))  # (N, B), each <= 1

            denom = denom * rescale + weights.sum(dim=1)  # (N, )
            numer = numer * pad_shape_back(rescale, numer.shape) + weights @ X_flat  # (N, D)
            running_max = new_max

        # denom >= 1 whenever any weight is non-zero; the clamp only guards the
        # all-zero case, where it makes the prediction 0 instead of 0 / 0.
        denom = denom.clamp(min=torch.finfo(dtype).eps)
        x0_hat = numer / pad_shape_back(denom, numer.shape)  # (N, D)
        return x0_hat.reshape(x_t.shape)  # (N, *D)