Coverage for src/diffusionlab/models.py: 100%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-21 15:33 -0700

1from typing import Any, Callable, Dict 

2import torch 

3from torch import nn, optim 

4from lightning import LightningModule 

5from lightning.pytorch.utilities.types import OptimizerLRScheduler 

6 

7from diffusionlab.losses import SamplewiseDiffusionLoss 

8from diffusionlab.diffusions import DiffusionProcess 

9from diffusionlab.schedulers import Scheduler 

10from diffusionlab.vector_fields import VectorField, VectorFieldType 

11 

12 

class DiffusionModel(LightningModule, VectorField):
    """
    A PyTorch Lightning module for training and evaluating diffusion models.

    This class implements a diffusion model that can be trained using various vector field types
    (score, x0, eps, v) and diffusion processes. It handles the training loop, loss computation,
    and evaluation metrics.

    The model inherits from both LightningModule (for training) and VectorField (for sampling),
    making it compatible with both the Lightning training framework and the diffusion sampling
    algorithms.

    Attributes:
        net (nn.Module): The neural network that predicts the vector field.
        vector_field_type (VectorFieldType): The type of vector field the model predicts.
        diffusion_process (DiffusionProcess): The diffusion process used for training.
        train_scheduler (Scheduler): The scheduler for generating training time steps.
        optimizer (optim.Optimizer): The optimizer for training the model.
        lr_scheduler (optim.lr_scheduler.LRScheduler): The learning rate scheduler, or None for
            no scheduler.
        batchwise_metrics (nn.ModuleDict): Metrics computed on each batch during validation.
        batchfree_metrics (nn.ModuleDict): Metrics computed at the end of validation epoch.
        t_loss_weights (Callable): Function that weights loss at different time steps.
        t_loss_probs (Callable): Function that determines sampling probability of time steps.
        N_noise_draws_per_sample (int): Number of noise samples per data point.
        samplewise_loss (SamplewiseDiffusionLoss): Loss function for each sample.
        batchwise_loss (Callable): Factory-generated function that computes loss for a batch.
        train_ts (torch.Tensor): Precomputed time steps for training.
        train_ts_loss_weights (torch.Tensor): Precomputed weights for each time step.
        train_ts_loss_probs (torch.Tensor): Precomputed sampling probabilities for each time step.
        LOG_ON_STEP_TRAIN_LOSS (bool): Whether to log training loss on each step. Default is True.
        LOG_ON_EPOCH_TRAIN_LOSS (bool): Whether to log training loss on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_TRAIN_LOSS (bool): Whether to display training loss on the progress bar. Default is True.
        LOG_ON_STEP_BATCHWISE_METRICS (bool): Whether to log batchwise metrics on each step. Default is False.
        LOG_ON_EPOCH_BATCHWISE_METRICS (bool): Whether to log batchwise metrics on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS (bool): Whether to display batchwise metrics on the progress bar. Default is False.
        LOG_ON_STEP_BATCHFREE_METRICS (bool): Whether to log batchfree metrics on each step. Default is False.
        LOG_ON_EPOCH_BATCHFREE_METRICS (bool): Whether to log batchfree metrics on each epoch. Default is True.
        LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS (bool): Whether to display batchfree metrics on the progress bar. Default is False.

    Note:
        Batch-free metrics are logged from ``on_validation_epoch_end``. Lightning's hook-logging
        rules only accept ``on_step=False`` / ``on_epoch=True`` from that hook, so the
        ``LOG_ON_STEP_BATCHFREE_METRICS`` / ``LOG_ON_EPOCH_BATCHFREE_METRICS`` flags should stay at
        their defaults. Other values are expected to be rejected by Lightning at log time.
    """

    # Logging flags for the training loss (see class docstring).
    LOG_ON_STEP_TRAIN_LOSS = True
    LOG_ON_EPOCH_TRAIN_LOSS = True
    LOG_ON_PROGRESS_BAR_TRAIN_LOSS = True

    # Logging flags for metrics computed on each validation batch.
    LOG_ON_STEP_BATCHWISE_METRICS = False
    LOG_ON_EPOCH_BATCHWISE_METRICS = True
    LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS = False

    # Logging flags for metrics computed once at the end of the validation epoch.
    LOG_ON_STEP_BATCHFREE_METRICS = False
    LOG_ON_EPOCH_BATCHFREE_METRICS = True
    LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS = False

    def __init__(
        self,
        net: nn.Module,
        diffusion_process: DiffusionProcess,
        train_scheduler: Scheduler,
        vector_field_type: VectorFieldType,
        optimizer: optim.Optimizer,
        lr_scheduler: optim.lr_scheduler.LRScheduler,
        batchwise_metrics: Dict[str, nn.Module],
        batchfree_metrics: Dict[str, nn.Module],
        train_ts_hparams: Dict[str, Any],
        t_loss_weights: Callable[[torch.Tensor], torch.Tensor],
        t_loss_probs: Callable[[torch.Tensor], torch.Tensor],
        N_noise_draws_per_sample: int,
    ):
        """
        Initialize the diffusion model.

        Args:
            net (nn.Module): Neural network that predicts the vector field; called as net(x, t).
            diffusion_process (DiffusionProcess): The diffusion process used for training.
            train_scheduler (Scheduler): Scheduler for generating training time steps.
            vector_field_type (VectorFieldType): Type of vector field the model predicts.
            optimizer (optim.Optimizer): Optimizer for training the model.
            lr_scheduler (optim.lr_scheduler.LRScheduler): Learning rate scheduler. None is also
                accepted and means "no scheduler" (see configure_optimizers).
            batchwise_metrics (Dict[str, nn.Module]): Metrics computed on each batch during validation. Each metric takes in (x, metadata, model) and returns a dictionary of metric (name, value) pairs.
            batchfree_metrics (Dict[str, nn.Module]): Metrics computed at the end of validation epoch. Each metric takes in (model) and returns a dictionary of metric (name, value) pairs.
            train_ts_hparams (Dict[str, Any]): Keyword arguments forwarded to train_scheduler.get_ts.
            t_loss_weights (Callable[[torch.Tensor], torch.Tensor]): Function that weights loss at different time steps.
            t_loss_probs (Callable[[torch.Tensor], torch.Tensor]): Function that determines sampling probability of time steps.
            N_noise_draws_per_sample (int): Number of noise draws per data point.
        """
        super().__init__()
        # nn.Module.__init__ does not chain to VectorField.__init__ through the MRO,
        # so the second base is initialized explicitly with this instance's forward.
        VectorField.__init__(self, self.forward, vector_field_type)

        self.net: nn.Module = net
        self.vector_field_type: VectorFieldType = vector_field_type
        self.diffusion_process: DiffusionProcess = diffusion_process
        self.train_scheduler: Scheduler = train_scheduler
        self.optimizer: optim.Optimizer = optimizer
        self.lr_scheduler: optim.lr_scheduler.LRScheduler = lr_scheduler
        # ModuleDict registers the metrics as submodules (device moves, state_dict).
        self.batchwise_metrics: nn.ModuleDict = nn.ModuleDict(batchwise_metrics)
        self.batchfree_metrics: nn.ModuleDict = nn.ModuleDict(batchfree_metrics)

        self.t_loss_weights: Callable[[torch.Tensor], torch.Tensor] = t_loss_weights
        self.t_loss_probs: Callable[[torch.Tensor], torch.Tensor] = t_loss_probs
        self.N_noise_draws_per_sample: int = N_noise_draws_per_sample

        # Per-sample loss for the configured diffusion process / vector field type.
        self.samplewise_loss: SamplewiseDiffusionLoss = SamplewiseDiffusionLoss(
            diffusion_process, vector_field_type
        )

        # Batch-level loss built by the factory; called as batchwise_loss(model, x, t, weights).
        self.batchwise_loss = self.samplewise_loss.batchwise_loss_factory(
            N_noise_draws_per_sample=N_noise_draws_per_sample
        )

        # Register placeholders as buffers so the precomputed schedule follows the
        # module across devices; they are filled in by precompute_train_schedule.
        self.register_buffer("train_ts", torch.zeros((0,)))
        self.register_buffer("train_ts_loss_weights", torch.zeros((0,)))
        self.register_buffer("train_ts_loss_probs", torch.zeros((0,)))
        self.precompute_train_schedule(train_ts_hparams)

    def precompute_train_schedule(self, train_ts_hparams: Dict[str, Any]) -> None:
        """
        Precompute time steps and their associated weights for training.

        This method generates the time steps used during training and computes
        the loss weights and sampling probabilities for each time step. The results
        replace the train_ts / train_ts_loss_weights / train_ts_loss_probs buffers.

        Args:
            train_ts_hparams (Dict[str, Any]): Keyword arguments for the training time step
                scheduler's get_ts. Typically includes t_min, t_max, and the number of steps L.
        """
        ts = self.train_scheduler.get_ts(**train_ts_hparams)
        self.train_ts = ts.to(self.device, non_blocking=True)
        self.train_ts_loss_weights = self.t_loss_weights(self.train_ts)
        self.train_ts_loss_probs = self.t_loss_probs(self.train_ts)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Passes the input through the neural network to predict the vector field.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, *data_dims).
            t (torch.Tensor): Time tensor of shape (batch_size,).

        Returns:
            torch.Tensor: Predicted vector field of shape (batch_size, *data_dims).
        """
        return self.net(x, t)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """
        Configure optimizers and learning rate schedulers for training.

        This method is called by PyTorch Lightning to set up the optimization process.

        Returns:
            OptimizerLRScheduler: Dictionary containing the optimizer and, when one was
            provided, the learning rate scheduler.
        """
        # Lightning rejects a None "lr_scheduler" entry, so omit the key entirely when no
        # scheduler is configured.
        if self.lr_scheduler is None:
            return {"optimizer": self.optimizer}
        return {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler}

    def loss(
        self, x: torch.Tensor, t: torch.Tensor, sample_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the loss for a batch of data at specified time steps.

        Uses the batchwise_loss function created from the SamplewiseDiffusionLoss factory
        to compute the loss for the batch.

        Args:
            x (torch.Tensor): Input data of shape (batch_size, *data_dims).
            t (torch.Tensor): Time steps of shape (batch_size,).
            sample_weights (torch.Tensor): Weights for each sample of shape (batch_size,).

        Returns:
            torch.Tensor: Scalar loss value.
        """
        return self.batchwise_loss(self, x, t, sample_weights)

    def aggregate_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss for a batch of data with randomly sampled time steps.

        This method:
        1. Samples one time-step index per sample (with replacement) from
           train_ts_loss_probs via torch.multinomial
        2. Computes the loss at those time steps, weighted by train_ts_loss_weights

        Args:
            x (torch.Tensor): Input data of shape (batch_size, *data_dims).

        Returns:
            torch.Tensor: Scalar loss value.
        """
        t_idx = torch.multinomial(
            self.train_ts_loss_probs, x.shape[0], replacement=True
        ).to(self.device, non_blocking=True)
        t = self.train_ts[t_idx]
        t_weights = self.train_ts_loss_weights[t_idx]
        return self.loss(x, t, t_weights)

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        Perform a single training step.

        This method is called by PyTorch Lightning during training.

        Args:
            batch (Any): Pair (x, metadata); only x is used for the loss.
            batch_idx (int): Index of the current batch.

        Returns:
            torch.Tensor: Loss value for the batch.
        """
        x, metadata = batch
        loss = self.aggregate_loss(x)
        self.log(
            "train_loss",
            loss,
            on_step=self.LOG_ON_STEP_TRAIN_LOSS,
            on_epoch=self.LOG_ON_EPOCH_TRAIN_LOSS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_TRAIN_LOSS,
            # Explicit batch size: stops Lightning from inferring it by walking the
            # (x, metadata) collection, which is ambiguous when metadata holds tensors.
            batch_size=x.shape[0],
        )
        return loss

    def validation_step(
        self, batch: Any, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        """
        Perform a single validation step.

        This method is called by PyTorch Lightning during validation.
        It computes the loss and any batch-wise metrics.

        Args:
            batch (Any): Pair (x, metadata), as in training_step.
            batch_idx (int): Index of the current batch.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of metric values, including "val_loss".
        """
        x, metadata = batch
        loss = self.aggregate_loss(x)
        metric_values = {"val_loss": loss}
        for metric_name, metric in self.batchwise_metrics.items():
            metric_values_dict = metric(x, metadata, self)
            for key, value in metric_values_dict.items():
                metric_label = self._get_metric_label(metric_name, key)
                metric_values[metric_label] = value
        self.log_dict(
            metric_values,
            on_step=self.LOG_ON_STEP_BATCHWISE_METRICS,
            on_epoch=self.LOG_ON_EPOCH_BATCHWISE_METRICS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_BATCHWISE_METRICS,
            # See training_step: avoid ambiguous batch-size inference from metadata.
            batch_size=x.shape[0],
        )
        return metric_values

    def on_validation_epoch_end(self) -> None:
        """
        Perform operations at the end of a validation epoch.

        This method is called by PyTorch Lightning at the end of each validation epoch.
        It computes and logs any batch-free metrics that require the entire validation set.
        """
        metric_values = {}
        for metric_name, metric in self.batchfree_metrics.items():
            metric_values_dict = metric(self)
            for key, value in metric_values_dict.items():
                metric_label = self._get_metric_label(metric_name, key)
                metric_values[metric_label] = value
        self.log_dict(
            metric_values,
            on_step=self.LOG_ON_STEP_BATCHFREE_METRICS,
            on_epoch=self.LOG_ON_EPOCH_BATCHFREE_METRICS,
            prog_bar=self.LOG_ON_PROGRESS_BAR_BATCHFREE_METRICS,
        )

    def _get_metric_label(self, metric_name: str, key: str) -> str:
        """
        Get the label for a metric's values.

        Both parts are stripped of surrounding whitespace. If both are non-empty they are
        joined with an underscore; otherwise the non-empty one is returned alone (or "" if
        both are empty).

        Args:
            metric_name (str): The name of the metric.
            key (str): The key of the metric.

        Returns:
            str: The label for the metric's values.
        """
        parts = (metric_name.strip(), key.strip())
        return "_".join(part for part in parts if part)