Coverage for src/shephex/experiment/chain_iterator.py: 100%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-03-29 18:45 +0100

1""" 

2Defines the ChainableExperimentIterator class, which can be used to define sets of experiments. 

3""" 

4from collections.abc import Iterator 

5from pathlib import Path 

6from typing import Callable, Self, Union 

7 

8from shephex.experiment import Experiment, Options 

9from shephex.study import Study 

10 

11 

12class ChainableExperimentIterator(Iterator): 

13 """ 

14 An iterator that can be used to define sets of experiments with a chainable interface. 

15 """ 

16 def __init__(self, function: Union[Callable, str, Path], directory: Union[Path, str]) -> None: 

17 """ 

18 

19 Parameters 

20 ---------- 

21 function : Union[Callable, str, Path] 

22 The function to be executed 

23 directory : Union[Path, str] 

24 The directory where the experiments will be saved.  

25 """ 

26 self.function = function 

27 self.directory = Path(directory) 

28 self.options = [] 

29 self.index = -1 

30 

31 def add(self, *args, zipped: bool = False, permute: bool = True, **kwargs) -> Self: 

32 """ 

33 Add one or more options to the iterator. 

34 

35 Parameters 

36 ---------- 

37 args : Iterable 

38 Positional arguments to be added. 

39 zipped : bool, optional 

40 If True, arguments are added in order - no permutation of arguments is done, by default False 

41 This is analogous to Python's zip function. 

42 permute : bool, optional 

43 If True, arguments are permuted with other arguments and previous options to yield  

44 all possible combinations, by default True 

45 kwargs : Dict 

46 Key-word arguments to be added.  

47 """ 

48 if zipped or not permute: 

49 self._zipped_add(*args, **kwargs) 

50 else: 

51 self._permute_add(*args, **kwargs) 

52 

53 return self 

54 

55 def zip(self, *args, **kwargs) -> Self: 

56 """ 

57 Add arguments in order.  

58 

59 All positional arguments and key-word arguments must have the same number  

60 of elements. If some options are already configured, the number of elements  

61 must match the number of elements in the previously added options. 

62 

63 Parameters 

64 ---------- 

65 args : Iterable 

66 Positional arguments to be added. 

67 kwargs : Dict 

68 Key-word arguments to be added.  

69 """ 

70 return self.add(*args, zipped=True, permute=False, **kwargs) 

71 

72 def permute(self, *args, **kwargs) -> Self: 

73 """ 

74 Add arguments permutationally. 

75 

76 Arguments are permuted with other arguments and previous options to yield 

77 all possible combinations. 

78 

79 Parameters 

80 ---------- 

81 args : Iterable 

82 Positional arguments to be added. 

83 kwargs : Dict  

84 Key-word arguments to be added. 

85 """ 

86 return self.add(*args, zipped=False, permute=True, **kwargs) 

87 

88 def _zipped_add(self, *args, **kwargs): 

89 """ 

90 Add options 'strictly' meaning that all arguments and keyword arguments  

91 must have the same number of elements in their iterables. 

92 """ 

93 

94 arg_elements = [] 

95 for iterable in args: 

96 arg_elements.append(len(iterable)) 

97 

98 for iterable in kwargs.values(): 

99 arg_elements.append(len(iterable)) 

100 

101 n_elements = arg_elements[0] 

102 

103 if not all([arg_element == n_elements for arg_element in arg_elements]): 

104 raise ValueError("Number of elements in passed arguments and key-word arguments is not equal") 

105 

106 if len(self.options) > 0 and n_elements != len(self.options): 

107 raise ValueError("When adding strict options the number of values for added options must equal the number for previously added values.") 

108 

109 if len(self.options) == 0: 

110 self.options = [Options() for _ in range(n_elements)] 

111 

112 for kwarg, values in kwargs.items(): 

113 for options, value in zip(self.options, values): 

114 options.add_kwarg(kwarg, value) 

115 

116 for arg_values in args: 

117 for option, value in zip(self.options, arg_values): 

118 option.add_arg(value) 

119 

120 def _permute_add(self, *args, **kwargs): 

121 """ 

122 Add options permutationally.  

123 """ 

124 

125 # Deal with the case where this is the first options added. 

126 if len(self.options) == 0: 

127 if len(args) != 0: 

128 self._zipped_add(args[0]) 

129 args = args[1:] 

130 else: 

131 key = list(kwargs.keys())[0] 

132 self._zipped_add(**{key: kwargs.pop(key)}) 

133 

134 if len(args) == 0 and len(kwargs) == 0: 

135 return 

136 

137 # Make permutations 

138 new_options = [] 

139 for kwarg, kwarg_values in kwargs.items(): 

140 for value in kwarg_values: 

141 for option in self.options: 

142 option_new = option.copy() 

143 option_new.add_kwarg(key=kwarg, value=value) 

144 new_options.append(option_new) 

145 

146 for arg_values in args: 

147 for value in arg_values: 

148 for option in self.options: 

149 option_new = option.copy() 

150 option_new.add_arg(value=value) 

151 new_options.append(option_new) 

152 

153 self.options = new_options 

154 

155 def __iter__(self) -> Self: 

156 self.index = 0 

157 self.study = Study(path=self.directory, refresh=True) 

158 experiments = [] 

159 for options in self.options: 

160 experiment = Experiment(*options.args, function=self.function, **options.kwargs, root_path=self.directory) 

161 self.study.add_experiment(experiment=experiment) 

162 experiments.append(experiment) 

163 self._experiments = self.study.get_experiments(status='pending', loaded_experiments=experiments) 

164 

165 return self 

166 

167 def __next__(self) -> Experiment: 

168 if self.index < len(self._experiments): 

169 experiment = self._experiments[self.index] 

170 self.index += 1 

171 return experiment 

172 else: 

173 raise StopIteration 

174 

175 def __len__(self) -> int: 

176 return len(self.options)