Coverage for src/diffusionlab/samplers.py: 100%

179 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 21:37 -0700

1from typing import Callable, Tuple 

2 

3import torch 

4 

5from diffusionlab.diffusions import DiffusionProcess 

6from diffusionlab.utils import pad_shape_back 

7from diffusionlab.vector_fields import ( 

8 VectorField, 

9 VectorFieldType, 

10 convert_vector_field_type, 

11) 

12 

13 

class Sampler:
    """
    Base class for sampling from diffusion models using various vector field types.

    A Sampler pairs a diffusion process with a sampling strategy (stochastic or
    deterministic) and runs the reverse (denoising) process. Starting from
    ``x_init``, it repeatedly applies a single-step update along the time
    schedule ``ts``. Subclasses implement one step function for each combination
    of vector field type (SCORE, X0, EPS, V) and strategy.

    Attributes:
        diffusion_process (DiffusionProcess): The diffusion process whose coefficients
            (alpha, sigma and their derivatives) the step functions use
        is_stochastic (bool): Whether the reverse process is stochastic or deterministic
    """

    def __init__(
        self,
        diffusion_process: DiffusionProcess,
        is_stochastic: bool,
    ):
        """
        Initialize a sampler with a diffusion process and sampling strategy.

        Args:
            diffusion_process (DiffusionProcess): The diffusion process to use for sampling
            is_stochastic (bool): Whether the reverse process should be stochastic
        """
        self.diffusion_process: DiffusionProcess = diffusion_process
        self.is_stochastic: bool = is_stochastic

    def sample(
        self,
        vector_field: VectorField,
        x_init: torch.Tensor,
        zs: torch.Tensor,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sample from the model using the reverse diffusion process.

        The step function matching ``vector_field.vector_field_type`` is applied once per
        consecutive pair of entries in ``ts`` (i.e. L-1 times).

        Args:
            vector_field (VectorField): The vector field model to use for sampling
            x_init (torch.Tensor): The initial noisy tensor to start sampling from, of shape (N, *D)
                where N is the batch size and D represents the data dimensions
            zs (torch.Tensor): The noise tensors for stochastic sampling, of shape (L-1, N, *D)
                where L is the number of time steps
            ts (torch.Tensor): The time schedule for sampling, of shape (L,)
                where L is the number of time steps

        Returns:
            torch.Tensor: The generated sample, of shape (N, *D). If L < 2, ``x_init`` is
            returned unchanged.

        Raises:
            ValueError: If the vector field type is not supported.
        """
        sample_step_function = self.get_sample_step_function(
            vector_field.vector_field_type
        )
        x = x_init
        for i in range(ts.shape[0] - 1):
            x = sample_step_function(vector_field, x, zs, i, ts)
        return x

    def sample_trajectory(
        self,
        vector_field: VectorField,
        x_init: torch.Tensor,
        zs: torch.Tensor,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sample a trajectory from the model using the reverse diffusion process.

        This is like sample(), but it returns every intermediate state, starting with
        ``x_init``, instead of only the final one.

        Args:
            vector_field (VectorField): The vector field model to use for sampling
            x_init (torch.Tensor): The initial noisy tensor to start sampling from, of shape (N, *D)
                where N is the batch size and D represents the data dimensions
            zs (torch.Tensor): The noise tensors for stochastic sampling, of shape (L-1, N, *D)
                where L is the number of time steps
            ts (torch.Tensor): The time schedule for sampling, of shape (L,)
                where L is the number of time steps

        Returns:
            torch.Tensor: The generated trajectory, of shape (L, N, *D)
                where L is the number of time steps

        Raises:
            ValueError: If the vector field type is not supported.
        """
        sample_step_function = self.get_sample_step_function(
            vector_field.vector_field_type
        )
        xs = [x_init]
        x = x_init
        for i in range(ts.shape[0] - 1):
            x = sample_step_function(vector_field, x, zs, i, ts)
            xs.append(x)
        return torch.stack(xs)

    def get_sample_step_function(
        self, vector_field_type: VectorFieldType
    ) -> Callable[
        [VectorField, torch.Tensor, torch.Tensor, int, torch.Tensor], torch.Tensor
    ]:
        """
        Get the appropriate sampling step function based on the vector field type.

        The function is selected from the vector field type and from whether
        sampling is stochastic or deterministic.

        Args:
            vector_field_type (VectorFieldType): The type of vector field being used
                (SCORE, X0, EPS, or V)

        Returns:
            Callable: A function that performs one step of the sampling process with signature
                (vector_field, x, zs, idx, ts) -> next_x, where:
                - vector_field is the model
                - x is the current state tensor of shape (N, *D)
                  where N is the batch size and D represents the data dimensions
                - zs is the noise tensors of shape (L-1, N, *D)
                  where L is the number of time steps
                - idx is the current step index
                - ts is the time steps tensor of shape (L,)
                  where L is the number of time steps
                - next_x is the next state tensor of shape (N, *D)

        Raises:
            ValueError: If ``vector_field_type`` is not one of SCORE, X0, EPS, V.
        """
        if self.is_stochastic:
            step_functions = {
                VectorFieldType.SCORE: self.sample_step_stochastic_score,
                VectorFieldType.X0: self.sample_step_stochastic_x0,
                VectorFieldType.EPS: self.sample_step_stochastic_eps,
                VectorFieldType.V: self.sample_step_stochastic_v,
            }
        else:
            step_functions = {
                VectorFieldType.SCORE: self.sample_step_deterministic_score,
                VectorFieldType.X0: self.sample_step_deterministic_x0,
                VectorFieldType.EPS: self.sample_step_deterministic_eps,
                VectorFieldType.V: self.sample_step_deterministic_v,
            }
        step_function = step_functions.get(vector_field_type)
        if step_function is None:
            # Fail here with a clear message. Returning None would only surface
            # later as "'NoneType' object is not callable" inside the sampling loop.
            raise ValueError(
                f"Unsupported vector field type: {vector_field_type!r}"
            )
        return step_function

    def _fix_t_shape(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Broadcast a single time value across the batch dimension of x.

        Args:
            x (torch.Tensor): The data tensor of shape (N, *D)
                where N is the batch size and D represents the data dimensions
            t (torch.Tensor): The time tensor to reshape. It must contain exactly one
                element (e.g. shape (), (1,), or (1, 1, ..., 1)), because it is viewed
                as shape (1,).

        Returns:
            torch.Tensor: The time tensor expanded (without copying) to shape (N,)
                where N is the batch size of x
        """
        t = t.view((1,)).expand(x.shape[0])
        return t

    def sample_step_stochastic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of stochastic sampling using the score vector field.

        Args:
            score (VectorField): The score vector field model
            x (torch.Tensor): The current state tensor, of shape (N, *D)
                where N is the batch size and D represents the data dimensions
            zs (torch.Tensor): The noise tensors for stochastic sampling, of shape (L-1, N, *D)
                where L is the number of time steps
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)
                where L is the number of time steps

        Returns:
            torch.Tensor: The next state tensor, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def sample_step_deterministic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of deterministic sampling using the score vector field.

        Args:
            score (VectorField): The score vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def sample_step_stochastic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of stochastic sampling using the x0 vector field.

        Args:
            x0 (VectorField): The x0 vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def sample_step_deterministic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of deterministic sampling using the x0 vector field.

        Args:
            x0 (VectorField): The x0 vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def sample_step_stochastic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of stochastic sampling using the eps vector field.

        Args:
            eps (VectorField): The eps vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def sample_step_deterministic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of deterministic sampling using the eps vector field.

        Args:
            eps (VectorField): The eps vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def sample_step_stochastic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of stochastic sampling using the v vector field.

        Args:
            v (VectorField): The velocity vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def sample_step_deterministic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one step of deterministic sampling using the v vector field.

        Args:
            v (VectorField): The velocity vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused in deterministic sampling), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

368 

369 

class EulerMaruyamaSampler(Sampler):
    """
    Sampler that discretizes the reverse process with Euler (deterministic) or
    Euler-Maruyama (stochastic) steps.

    The drift and diffusion formulas below are written for marginals
    x_t = alpha_t * x_0 + sigma_t * eps. With f_t = alpha'_t / alpha_t and
    g_t^2 = 2 * sigma_t^2 * (sigma'_t / sigma_t - alpha'_t / alpha_t), they are:

    - deterministic (probability-flow ODE): dx = [f_t x - 0.5 g_t^2 score] dt
    - stochastic (reverse-time SDE): dx = [f_t x - g_t^2 score] dt + g_t dW

    Each vector field type (SCORE, X0, EPS, V) is substituted into the
    corresponding equation.

    ``ts`` is expected to be decreasing (from noise toward data). Each step uses
    ``dt = ts[idx + 1] - ts[idx]`` and scales the noise by ``sqrt(-dt)``, so an
    increasing schedule yields NaNs.
    """

    def _get_step_quantities(
        self, zs: torch.Tensor, idx: int, ts: torch.Tensor
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Calculate various quantities needed for a sampling step.

        Args:
            zs (torch.Tensor): The noise tensors for stochastic sampling, of shape (L-1, N, *D)
                where L is the number of time steps, N is the batch size, and D represents the data dimensions
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)
                where L is the number of time steps

        Returns:
            Tuple: A tuple containing various time-dependent quantities:
                - t (torch.Tensor): Current time, of shape (1*), where 1* is a tuple with the same number of dimensions as (N, D*)
                - t1 (torch.Tensor): Next time, of shape (1*)
                - alpha_t (torch.Tensor): Alpha at current time, of shape (1*)
                - sigma_t (torch.Tensor): Sigma at current time, of shape (1*)
                - alpha_prime_t (torch.Tensor): Derivative of alpha at current time, of shape (1*)
                - sigma_prime_t (torch.Tensor): Derivative of sigma at current time, of shape (1*)
                - dt (torch.Tensor): Time difference t1 - t, of shape (1*)
                - dwt (torch.Tensor): Scaled noise zs[idx] * sqrt(-dt), of shape (N, *D)
                - alpha_ratio_t (torch.Tensor): alpha_prime_t / alpha_t, of shape (1*)
                - sigma_ratio_t (torch.Tensor): sigma_prime_t / sigma_t, of shape (1*)
                - diff_ratio_t (torch.Tensor): sigma_ratio_t - alpha_ratio_t, of shape (1*)
        """
        x_shape = zs.shape[1:]
        t = pad_shape_back(ts[idx], x_shape)
        t1 = pad_shape_back(ts[idx + 1], x_shape)
        dt = t1 - t
        # dt is negative for a decreasing schedule; sqrt(-dt) is the Brownian increment scale.
        dwt = zs[idx] * torch.sqrt(-dt)

        alpha_t = pad_shape_back(self.diffusion_process.alpha(ts[idx]), x_shape)
        sigma_t = pad_shape_back(self.diffusion_process.sigma(ts[idx]), x_shape)
        alpha_prime_t = pad_shape_back(
            self.diffusion_process.alpha_prime(ts[idx]), x_shape
        )
        sigma_prime_t = pad_shape_back(
            self.diffusion_process.sigma_prime(ts[idx]), x_shape
        )
        alpha_ratio_t = alpha_prime_t / alpha_t
        sigma_ratio_t = sigma_prime_t / sigma_t
        diff_ratio_t = sigma_ratio_t - alpha_ratio_t
        return (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        )

    def sample_step_deterministic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler step of the probability-flow ODE using the score vector field.

        Args:
            score (VectorField): The score vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (only their shape is used here), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        # 0.5 * g_t^2 = sigma_t^2 * diff_ratio_t
        drift_t = alpha_ratio_t * x - (sigma_t**2) * diff_ratio_t * score(
            x, self._fix_t_shape(x, t)
        )
        return x + drift_t * dt

    def sample_step_stochastic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler-Maruyama step of the reverse-time SDE using the score vector field.

        Args:
            score (VectorField): The score vector field model
            x (torch.Tensor): The current state tensor, of shape (N, *D)
                where N is the batch size and D represents the data dimensions
            zs (torch.Tensor): The noise tensors for stochastic sampling, of shape (L-1, N, *D)
                where L is the number of time steps
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)
                where L is the number of time steps

        Returns:
            torch.Tensor: The next state tensor, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        # Evaluate the model with a per-sample time vector, like every other step.
        score_value = score(x, self._fix_t_shape(x, t))
        # Reverse-time SDE: drift = f_t x - g_t^2 * score with g_t^2 = 2 sigma_t^2 diff_ratio_t,
        # and diffusion = g_t. Substituting the x0/eps/v parametrizations into these
        # gives exactly the x0/eps/v stochastic steps of this class.
        drift_t = alpha_ratio_t * x - 2 * (sigma_t**2) * diff_ratio_t * score_value
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt

    def sample_step_deterministic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler step of the probability-flow ODE using the x0 vector field.

        Args:
            x0 (VectorField): The x0 vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (only their shape is used here), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = sigma_ratio_t * x - alpha_t * diff_ratio_t * x0(
            x, self._fix_t_shape(x, t)
        )
        return x + drift_t * dt

    def sample_step_stochastic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler-Maruyama step of the reverse-time SDE using the x0 vector field.

        Args:
            x0 (VectorField): The x0 vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = (
            alpha_ratio_t + 2 * diff_ratio_t
        ) * x - 2 * alpha_t * diff_ratio_t * x0(x, self._fix_t_shape(x, t))
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt

    def sample_step_deterministic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler step of the probability-flow ODE using the eps vector field.

        Args:
            eps (VectorField): The eps vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (only their shape is used here), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = alpha_ratio_t * x + sigma_t * diff_ratio_t * eps(
            x, self._fix_t_shape(x, t)
        )
        return x + drift_t * dt

    def sample_step_stochastic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler-Maruyama step of the reverse-time SDE using the eps vector field.

        Args:
            eps (VectorField): The eps vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = alpha_ratio_t * x + 2 * sigma_t * diff_ratio_t * eps(
            x, self._fix_t_shape(x, t)
        )
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt

    def sample_step_deterministic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler step of the probability-flow ODE using the v vector field.

        The model output is used directly as the ODE drift.

        Args:
            v (VectorField): The velocity vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (only their shape is used here), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        drift_t = v(x, self._fix_t_shape(x, t))
        return x + drift_t * dt

    def sample_step_stochastic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Perform one Euler-Maruyama step of the reverse-time SDE using the v vector field.

        Args:
            v (VectorField): The velocity vector field model
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state after one sampling step, of shape (N, *D)
        """
        (
            t,
            t1,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            dt,
            dwt,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        ) = self._get_step_quantities(zs, idx, ts)
        # Since v = alpha_ratio_t * x + sigma_t * diff_ratio_t * eps, the eps stochastic drift
        # alpha_ratio_t * x + 2 * sigma_t * diff_ratio_t * eps equals 2 * v - alpha_ratio_t * x.
        drift_t = -alpha_ratio_t * x + 2 * v(x, self._fix_t_shape(x, t))
        diffusion_t = torch.sqrt(2 * diff_ratio_t) * sigma_t
        return x + drift_t * dt + diffusion_t * dwt

768 

769 

class DDMSampler(Sampler):
    """
    Class for sampling from diffusion models using the DDPM/DDIM sampler.

    Every vector field type is first converted into an x0 (clean-data) prediction.
    Stochastic steps then sample from the DDPM posterior q(x_{t1} | x_t, x0), and
    deterministic steps take the DDIM (eta = 0) update. Both updates are written for
    marginals x_t = alpha_t * x0 + sigma_t * eps.
    """

    def _convert_to_x0(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        fx: torch.Tensor,
        fx_type: VectorFieldType,
    ) -> torch.Tensor:
        """
        Convert a model output of type ``fx_type`` into an x0 prediction.

        Args:
            x (torch.Tensor): The current state, of shape (N, *D)
            t (torch.Tensor): The current time, padded so it broadcasts against x
            fx (torch.Tensor): The model output at (x, t), of shape (N, *D)
            fx_type (VectorFieldType): The type of ``fx`` (SCORE, EPS, V, ...)

        Returns:
            torch.Tensor: The x0 prediction, of shape (N, *D)
        """
        x0 = convert_vector_field_type(
            x,
            fx,
            self.diffusion_process.alpha(t),
            self.diffusion_process.sigma(t),
            self.diffusion_process.alpha_prime(t),
            self.diffusion_process.sigma_prime(t),
            fx_type,
            VectorFieldType.X0,
        )
        return x0

    def _predict_x0(
        self,
        vector_field: VectorField,
        vector_field_type: VectorFieldType,
        x: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate the model at time ts[idx] and express its output as an x0 prediction.

        Args:
            vector_field (VectorField): The model to evaluate
            vector_field_type (VectorFieldType): The type of the model's output
            x (torch.Tensor): The current state, of shape (N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The x0 prediction, of shape (N, *D)
        """
        t = pad_shape_back(ts[idx], x.shape)
        fx = vector_field(x, self._fix_t_shape(x, t))
        if vector_field_type == VectorFieldType.X0:
            # Already an x0 prediction; no conversion needed.
            return fx
        return self._convert_to_x0(x, t, fx, vector_field_type)

    def _ddpm_step_x0_tensor(
        self,
        x0: torch.Tensor,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Take one stochastic DDPM step from ts[idx] to ts[idx + 1] given an x0 prediction.

        The next state is drawn from the Gaussian posterior q(x_{t1} | x_t, x0), with the
        noise zs[idx] as the standard-normal draw.

        Args:
            x0 (torch.Tensor): The x0 prediction, of shape (N, *D)
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors, of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state, of shape (N, *D)
        """
        t = pad_shape_back(ts[idx], x.shape)
        t1 = pad_shape_back(ts[idx + 1], x.shape)
        alpha_t = self.diffusion_process.alpha(t)
        sigma_t = self.diffusion_process.sigma(t)
        alpha_t1 = self.diffusion_process.alpha(t1)
        sigma_t1 = self.diffusion_process.sigma(t1)

        r11 = (alpha_t / alpha_t1) * (sigma_t1 / sigma_t)
        r12 = r11 * (sigma_t1 / sigma_t)
        r22 = r12 * (alpha_t / alpha_t1)  # equals r11 ** 2

        # Posterior mean/std of x_{t1} given x_t and x0.
        mean = r12 * x + alpha_t1 * (1 - r22) * x0
        std = sigma_t1 * (1 - r11**2) ** (1 / 2)
        return mean + std * zs[idx]

    def _ddim_step_x0_tensor(
        self,
        x0: torch.Tensor,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """
        Take one deterministic DDIM step from ts[idx] to ts[idx + 1] given an x0 prediction.

        The update is algebraically alpha_t1 * x0 + sigma_t1 * eps_hat, where
        eps_hat = (x - alpha_t * x0) / sigma_t. No noise is injected.

        Args:
            x0 (torch.Tensor): The x0 prediction, of shape (N, *D)
            x (torch.Tensor): The current state, of shape (N, *D)
            zs (torch.Tensor): The noise tensors (unused), of shape (L-1, N, *D)
            idx (int): The current step index
            ts (torch.Tensor): The time steps tensor, of shape (L,)

        Returns:
            torch.Tensor: The next state, of shape (N, *D)
        """
        t = pad_shape_back(ts[idx], x.shape)
        t1 = pad_shape_back(ts[idx + 1], x.shape)
        alpha_t = self.diffusion_process.alpha(t)
        sigma_t = self.diffusion_process.sigma(t)
        alpha_t1 = self.diffusion_process.alpha(t1)
        sigma_t1 = self.diffusion_process.sigma(t1)

        r01 = sigma_t1 / sigma_t
        r11 = (alpha_t / alpha_t1) * r01

        mean = r01 * x + alpha_t1 * (1 - r11) * x0
        return mean

    def sample_step_deterministic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one deterministic (DDIM) step using the x0 vector field."""
        x0_value = self._predict_x0(x0, VectorFieldType.X0, x, idx, ts)
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_x0(
        self,
        x0: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one stochastic (DDPM) step using the x0 vector field."""
        x0_value = self._predict_x0(x0, VectorFieldType.X0, x, idx, ts)
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_deterministic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one deterministic (DDIM) step using the score vector field."""
        x0_value = self._predict_x0(score, VectorFieldType.SCORE, x, idx, ts)
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_score(
        self,
        score: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one stochastic (DDPM) step using the score vector field."""
        x0_value = self._predict_x0(score, VectorFieldType.SCORE, x, idx, ts)
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_deterministic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one deterministic (DDIM) step using the eps vector field."""
        x0_value = self._predict_x0(eps, VectorFieldType.EPS, x, idx, ts)
        # Deterministic sampling must use the noise-free DDIM update (the DDPM step adds zs noise).
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_eps(
        self,
        eps: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one stochastic (DDPM) step using the eps vector field."""
        x0_value = self._predict_x0(eps, VectorFieldType.EPS, x, idx, ts)
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_deterministic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one deterministic (DDIM) step using the v vector field."""
        x0_value = self._predict_x0(v, VectorFieldType.V, x, idx, ts)
        return self._ddim_step_x0_tensor(x0_value, x, zs, idx, ts)

    def sample_step_stochastic_v(
        self,
        v: VectorField,
        x: torch.Tensor,
        zs: torch.Tensor,
        idx: int,
        ts: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one stochastic (DDPM) step using the v vector field."""
        x0_value = self._predict_x0(v, VectorFieldType.V, x, idx, ts)
        return self._ddpm_step_x0_tensor(x0_value, x, zs, idx, ts)