Coverage for src/diffusionlab/schedulers.py: 100% (19 statements)
coverage.py v7.6.12, created at 2025-03-14 21:37 -0700
from typing import Any

import torch


class Scheduler:
    """
    Base class for time step schedulers used in diffusion, denoising, and sampling.

    A scheduler determines the sequence of time steps used during the sampling process.
    Different scheduling strategies can affect the quality and efficiency of the
    generative process.

    The scheduler generates a sequence of time values, typically in the range [0, 1],
    which are used to control the noise level at each step of the sampling process.
    """

    def __init__(self, **schedule_hparams: Any) -> None:
        """
        Initialize the scheduler.

        This base implementation does not store any variables.
        Subclasses may override this method to initialize specific parameters.

        Args:
            **schedule_hparams: Keyword arguments containing scheduler parameters.
                Not used in the base class but available for subclasses.
        """
        pass

    def get_ts(self, **ts_hparams: Any) -> torch.Tensor:
        """
        Generate the sequence of time steps.

        This is an abstract method that must be implemented by subclasses.

        Args:
            **ts_hparams: Keyword arguments containing parameters for generating
                time steps. The specific parameters depend on the scheduler
                implementation. Typically includes:
                    - t_min (float): The minimum time value.
                    - t_max (float): The maximum time value.
                    - L (int): The number of time steps to generate.

        Returns:
            torch.Tensor: A tensor of shape (L,) containing the sequence of time
                steps in descending order, where L is the number of time steps.

        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """
        raise NotImplementedError
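

# Illustrative sketch (not part of the measured module): a subclass only needs to
# implement get_ts. The CosineScheduler below is a hypothetical example, not a
# diffusionlab class; it spaces steps along a cosine curve so they cluster near
# t_min, while keeping the same descending-order contract as the base class.
class CosineScheduler(Scheduler):
    def get_ts(self, **ts_hparams: Any) -> torch.Tensor:
        t_min = ts_hparams["t_min"]
        t_max = ts_hparams["t_max"]
        L = ts_hparams["L"]
        # 1 - cos(u * pi / 2) rises slowly near u = 0, so the mapped steps are
        # denser near t_min; flipping then yields descending order from t_max.
        u = torch.linspace(0.0, 1.0, L)
        ts = t_min + (t_max - t_min) * (1.0 - torch.cos(u * torch.pi / 2))
        return ts.flip(0)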


class UniformScheduler(Scheduler):
    """
    A scheduler that generates uniformly spaced time steps.

    This scheduler creates a sequence of time steps that are uniformly distributed
    between a minimum and a maximum time value. The time steps are returned in
    descending order (from t_max to t_min).

    This is the simplest scheduling strategy and is often used as a baseline.
    """

    def __init__(self, **schedule_hparams: Any) -> None:
        """
        Initialize the uniform scheduler.

        This implementation does not store any variables, following the base class design.

        Args:
            **schedule_hparams: Keyword arguments containing scheduler parameters.
                Not used but passed to the parent class.
        """
        super().__init__(**schedule_hparams)

    def get_ts(self, **ts_hparams: Any) -> torch.Tensor:
        """
        Generate uniformly spaced time steps.

        Args:
            **ts_hparams: Keyword arguments containing:
                - t_min (float): The minimum time value, typically close to 0.
                - t_max (float): The maximum time value, typically close to 1.
                - L (int): The number of time steps to generate; must be at least 2.

        Returns:
            torch.Tensor: A tensor of shape (L,) containing uniformly spaced time
                steps in descending order (from t_max to t_min), where L is the
                number of time steps specified in ts_hparams.

        Raises:
            AssertionError: If t_min or t_max are outside the range [0, 1],
                if t_min > t_max, or if L is less than 2.
        """
        t_min = ts_hparams["t_min"]
        t_max = ts_hparams["t_max"]
        L = ts_hparams["L"]
        assert 0 <= t_min <= t_max <= 1, "t_min and t_max must be in the range [0, 1]"
        assert L >= 2, "L must be at least 2"

        ts = torch.linspace(t_min, t_max, L)
        ts = ts.flip(0)
        return ts
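

# Usage sketch (illustrative; the output values follow from torch.linspace + flip):
scheduler = UniformScheduler()
ts = scheduler.get_ts(t_min=0.0, t_max=1.0, L=5)
print(ts)  # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])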