Coverage for src/diffusionlab/schedulers.py: 100% (19 statements)


from typing import Any

import torch


class Scheduler:
    """
    Base class for time step schedulers used in diffusion, denoising, and sampling.

    A scheduler determines the sequence of time steps used during the sampling process.
    Different scheduling strategies can affect the quality and efficiency of the
    generative process.

    The scheduler generates a sequence of time values, typically in the range [0, 1],
    which are used to control the noise level at each step of the sampling process.
    """

    def __init__(self, **schedule_hparams: Any) -> None:
        """
        Initialize the scheduler.

        This base implementation does not store any variables.
        Subclasses may override this method to initialize specific parameters.

        Args:
            **schedule_hparams: Keyword arguments containing scheduler parameters.
                Not used in the base class but available for subclasses.
        """
        pass

    def get_ts(self, **ts_hparams: Any) -> torch.Tensor:
        """
        Generate the sequence of time steps.

        This is an abstract method that must be implemented by subclasses.

        Args:
            **ts_hparams: Keyword arguments containing parameters for generating time steps.
                The specific parameters depend on the scheduler implementation.
                Typically includes:
                - t_min (float): The minimum time value
                - t_max (float): The maximum time value
                - L (int): The number of time steps to generate

        Returns:
            torch.Tensor: A tensor of shape (L,) containing the sequence of time steps
                in descending order, where L is the number of time steps.

        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """
        raise NotImplementedError


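# Illustrative sketch, not part of the original schedulers.py: the docstrings
# above define the subclassing contract (get_ts returns a descending (L,)
# tensor of times in [0, 1]). A hypothetical non-uniform scheduler that
# concentrates steps near t_min via quadratic spacing might look like this:
class QuadraticScheduler(Scheduler):
    def get_ts(self, **ts_hparams: Any) -> torch.Tensor:
        t_min = ts_hparams["t_min"]
        t_max = ts_hparams["t_max"]
        L = ts_hparams["L"]
        base = torch.linspace(0, 1, L) ** 2  # spacing shrinks toward 0
        ts = t_min + (t_max - t_min) * base  # rescale to [t_min, t_max]
        return ts.flip(0)  # descending order, matching the base class contract

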

class UniformScheduler(Scheduler):
    """
    A scheduler that generates uniformly spaced time steps.

    This scheduler creates a sequence of time steps that are uniformly distributed
    between a minimum and maximum time value. The time steps are returned in
    descending order (from t_max to t_min).

    This is the simplest scheduling strategy and is often used as a baseline.
    """

    def __init__(self, **schedule_hparams: Any) -> None:
        """
        Initialize the uniform scheduler.

        This implementation does not store any variables, following the base class design.

        Args:
            **schedule_hparams: Keyword arguments containing scheduler parameters.
                Not used but passed to the parent class.
        """
        super().__init__(**schedule_hparams)

    def get_ts(self, **ts_hparams: Any) -> torch.Tensor:
        """
        Generate uniformly spaced time steps.

        Args:
            **ts_hparams: Keyword arguments containing:
                - t_min (float): The minimum time value, typically close to 0.
                - t_max (float): The maximum time value, typically close to 1.
                - L (int): The number of time steps to generate.

        Returns:
            torch.Tensor: A tensor of shape (L,) containing uniformly spaced time steps
                in descending order (from t_max to t_min), where L is the number
                of time steps specified in ts_hparams.

        Raises:
            AssertionError: If t_min or t_max are outside the range [0, 1],
                if t_min > t_max, or if L < 2.
        """
        t_min = ts_hparams["t_min"]
        t_max = ts_hparams["t_max"]
        L = ts_hparams["L"]
        assert 0 <= t_min <= t_max <= 1, "t_min and t_max must satisfy 0 <= t_min <= t_max <= 1"
        assert L >= 2, "L must be at least 2"

        ts = torch.linspace(t_min, t_max, L)
        ts = ts.flip(0)
        return ts
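

# Minimal usage sketch (assumed demo code, not part of the covered module):
# generate ten uniformly spaced time steps running from t_max down to t_min.
if __name__ == "__main__":
    scheduler = UniformScheduler()
    ts = scheduler.get_ts(t_min=0.0, t_max=1.0, L=10)
    print(ts)  # descending from 1.0 to 0.0 in steps of 1/9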