Coverage for src/diffusionlab/distributions/empirical.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 21:37 -0700

1from typing import Any, Dict 

2 

3import torch 

4from torch.utils.data import DataLoader 

5 

6from diffusionlab.diffusions import DiffusionProcess 

7from diffusionlab.distributions.base import Distribution 

8from diffusionlab.utils import pad_shape_back 

9 

10 

class EmpiricalDistribution(Distribution):
    """
    An empirical distribution, i.e., the uniform distribution over a dataset.
    Formally, the distribution is defined as:

    mu(B) = (1/N) * sum_(i=1)^(N) 1[x_i in B]

    where x_i is the ith data point in the dataset, N is the number of data points,
    and 1[.] is the indicator function (so mu is the average of the Dirac measures
    delta_(x_i) placed at the data points).

    Distribution Parameters:
    - None

    Distribution Hyperparameters:
    - labeled_data: A DataLoader over the data that defines the empirical distribution,
      where each data sample is a (data, label) tuple. Both data and label are PyTorch tensors.

    Note:
    - This class has no sample() method as it's difficult to sample randomly from a DataLoader.
      In practice, you can sample directly from the DataLoader and apply filtering there.
    """

    @classmethod
    def validate_hparams(cls, dist_hparams: Dict[str, Any]) -> None:
        """
        Validate the hyperparameters for the empirical distribution.

        Arguments:
            dist_hparams: A dictionary of hyperparameters for the distribution.
                Must contain 'labeled_data' which is a non-empty DataLoader.

        Returns:
            None

        Throws:
            AssertionError: If the parameters are invalid.
        """
        assert "labeled_data" in dist_hparams, "dist_hparams must contain 'labeled_data'"
        labeled_data = dist_hparams["labeled_data"]
        assert isinstance(
            labeled_data, DataLoader
        ), "'labeled_data' must be a torch.utils.data.DataLoader"
        assert len(labeled_data) > 0, "'labeled_data' must yield at least one batch"

    @classmethod
    def x0(
        cls,
        x_t: torch.Tensor,
        t: torch.Tensor,
        diffusion_process: DiffusionProcess,
        batched_dist_params: Dict[str, torch.Tensor],
        dist_hparams: Dict[str, Any],
    ) -> torch.Tensor:
        """
        Computes the denoiser E[x_0 | x_t] for an empirical distribution.

        Each dataset sample x_i receives the (unnormalized) weight
            w_i = exp(-||x_t - alpha(t) * x_i||^2 / (2 * sigma(t)^2)),
        i.e. the Gaussian likelihood of x_t given x_0 = x_i. The prediction is the
        weighted average sum_i w_i x_i / sum_i w_i.

        The weights are accumulated batch-by-batch over the DataLoader with a streaming
        log-sum-exp. A running per-row maximum exponent is subtracted before
        exponentiating, and the sums accumulated so far are rescaled whenever that
        maximum grows. Without this shift, small sigma(t) or high-dimensional data make
        every exponent so negative that all weights underflow to zero, and the
        prediction collapses to zero.

        Arguments:
            x_t: The input tensor, of shape (N, *D), where *D is the shape of each data.
            t: The time tensor, of shape (N, ).
            diffusion_process: The diffusion process.
            batched_dist_params: A dictionary of batched parameters for the distribution.
                Not used for empirical distribution.
            dist_hparams: A dictionary of hyperparameters for the distribution.
                Must contain 'labeled_data' which is a DataLoader.

        Returns:
            The prediction of x_0, of shape (N, *D) and dtype of x_t. If no weight is
            accumulated at all (e.g. the DataLoader yields no data), the result is zero.
        """
        data = dist_hparams["labeled_data"]

        n = x_t.shape[0]
        x_flattened = torch.flatten(x_t, start_dim=1, end_dim=-1)  # (N, D)

        # Work in x_t's dtype so that cdist / matmul always see matching dtypes.
        alpha = diffusion_process.alpha(t).to(x_t.dtype)  # (N, )
        sigma = diffusion_process.sigma(t).to(x_t.dtype)  # (N, )

        # Streaming log-sum-exp state, one entry per row of x_t.
        running_max = torch.full(
            (n,), float("-inf"), dtype=x_t.dtype, device=x_t.device
        )  # (N, )
        softmax_denom = torch.zeros((n,), dtype=x_t.dtype, device=x_t.device)  # (N, )
        x0_numer = torch.zeros_like(x_flattened)  # (N, D)

        for X_batch, _ in data:
            X_batch = X_batch.to(
                device=x_t.device, dtype=x_t.dtype, non_blocking=True
            )  # (B, *D)
            if X_batch.shape[0] == 0:
                # An empty batch contributes nothing (and amax below would fail on it).
                continue
            X_batch_flattened = torch.flatten(X_batch, start_dim=1, end_dim=-1)  # (B, D)
            alpha_X_batch_flattened = (
                pad_shape_back(alpha, X_batch_flattened[None, ...].shape)
                * X_batch_flattened[None, ...]
            )  # (N, B, D)
            # Use the direct distance kernel: the matmul formula ||a||^2 + ||b||^2 - 2<a, b>
            # cancels catastrophically when x_t is close to alpha * x_i, and that error is
            # then amplified by 1 / (2 sigma^2) at small noise levels.
            dists = (
                torch.cdist(
                    x_flattened[:, None, ...],
                    alpha_X_batch_flattened,
                    compute_mode="donot_use_mm_for_euclid_dist",
                )[:, 0, ...]
                ** 2
            )  # (N, B)
            logits = -dists / (2 * pad_shape_back(sigma, dists.shape) ** 2)  # (N, B)

            new_max = torch.maximum(running_max, torch.amax(logits, dim=-1))  # (N, )
            # Shift by the running max; use 0 where it is not finite so that
            # (-inf) - (-inf) never produces NaN.
            shift = torch.where(
                torch.isfinite(new_max), new_max, torch.zeros_like(new_max)
            )  # (N, )
            # Factor that re-expresses sums from earlier batches relative to the new shift.
            rescale = torch.exp(running_max - shift)  # (N, )
            # Exponents are <= 0 after the shift, so weights cannot overflow.
            weights = torch.exp(logits - shift[:, None])  # (N, B)

            softmax_denom = softmax_denom * rescale + torch.sum(weights, dim=-1)
            x0_numer = x0_numer * rescale[:, None] + weights @ X_batch_flattened
            running_max = new_max

        # Guard the degenerate case where no weight was accumulated (e.g. no data).
        softmax_denom = torch.clamp(
            softmax_denom, min=torch.finfo(softmax_denom.dtype).eps
        )
        x0_hat = x0_numer / softmax_denom[:, None]  # (N, D)
        return x0_hat.reshape(x_t.shape)  # (N, *D)