Coverage for src/shephex/experiment/chain_iterator.py: 100%
80 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-03-29 18:45 +0100
« prev ^ index » next coverage.py v7.6.1, created at 2025-03-29 18:45 +0100
1"""
2Defines the ChainableExperimentIterator class, which can be used to define sets of experiments.
3"""
4from collections.abc import Iterator
5from pathlib import Path
6from typing import Callable, Self, Union
8from shephex.experiment import Experiment, Options
9from shephex.study import Study
class ChainableExperimentIterator(Iterator):
    """
    An iterator that can be used to define sets of experiments with a chainable interface.

    Options are accumulated with :meth:`add`, :meth:`zip` and :meth:`permute`,
    each of which returns the iterator itself so calls can be chained.
    Iterating builds one ``Experiment`` per accumulated ``Options`` object,
    adds it to a ``Study`` rooted at ``directory`` and yields the experiments
    that the study returns for ``status='pending'``.
    """

    def __init__(self, function: Union[Callable, str, Path], directory: Union[Path, str]) -> None:
        """
        Parameters
        ----------
        function : Union[Callable, str, Path]
            The function to be executed
        directory : Union[Path, str]
            The directory where the experiments will be saved.
        """
        self.function = function
        self.directory = Path(directory)
        self.options = []
        # -1 marks that iteration has not been prepared yet (see __iter__/__next__).
        self.index = -1

    def add(self, *args, zipped: bool = False, permute: bool = True, **kwargs) -> Self:
        """
        Add one or more options to the iterator.

        Parameters
        ----------
        args : Iterable
            Positional arguments to be added.
        zipped : bool, optional
            If True, arguments are added in order - no permutation of arguments is done, by default False
            This is analogous to Python's zip function. Takes precedence over ``permute``.
        permute : bool, optional
            If True, arguments are permuted with other arguments and previous options to yield
            all possible combinations, by default True. Passing False selects the zipped behaviour.
        kwargs : Dict
            Key-word arguments to be added.

        Returns
        -------
        Self
            The iterator itself, to allow chaining.
        """
        if zipped or not permute:
            self._zipped_add(*args, **kwargs)
        else:
            self._permute_add(*args, **kwargs)

        return self

    def zip(self, *args, **kwargs) -> Self:
        """
        Add arguments in order.

        All positional arguments and key-word arguments must have the same number
        of elements. If some options are already configured, the number of elements
        must match the number of elements in the previously added options.

        Parameters
        ----------
        args : Iterable
            Positional arguments to be added.
        kwargs : Dict
            Key-word arguments to be added.

        Returns
        -------
        Self
            The iterator itself, to allow chaining.
        """
        return self.add(*args, zipped=True, permute=False, **kwargs)

    def permute(self, *args, **kwargs) -> Self:
        """
        Add arguments permutationally.

        Arguments are permuted with other arguments and previous options to yield
        all possible combinations.

        Parameters
        ----------
        args : Iterable
            Positional arguments to be added.
        kwargs : Dict
            Key-word arguments to be added.

        Returns
        -------
        Self
            The iterator itself, to allow chaining.
        """
        return self.add(*args, zipped=False, permute=True, **kwargs)

    def _zipped_add(self, *args, **kwargs) -> None:
        """
        Add options 'strictly' meaning that all arguments and keyword arguments
        must have the same number of elements in their iterables.

        Calling this without any arguments is a no-op.

        Raises
        ------
        ValueError
            If the passed iterables differ in length, or if their length does
            not match the number of previously added options.
        """
        # Materialise every input once so that any iterable (including
        # generators, which have no len()) can be measured and then consumed.
        arg_values = [list(values) for values in args]
        kwarg_values = {key: list(values) for key, values in kwargs.items()}

        lengths = [len(values) for values in arg_values]
        lengths.extend(len(values) for values in kwarg_values.values())

        if not lengths:
            # Nothing was passed, so there is nothing to add.
            return

        n_elements = lengths[0]

        if any(length != n_elements for length in lengths):
            raise ValueError("Number of elements in passed arguments and key-word arguments is not equal")

        if self.options and n_elements != len(self.options):
            raise ValueError("When adding strict options the number of values for added options must equal the number for previously added values.")

        if not self.options:
            self.options = [Options() for _ in range(n_elements)]

        # Inside a method body ``zip`` is the builtin; the ``zip`` method of
        # this class is only reachable through ``self``.
        for key, values in kwarg_values.items():
            for option, value in zip(self.options, values):
                option.add_kwarg(key, value)

        for values in arg_values:
            for option, value in zip(self.options, values):
                option.add_arg(value)

    def _permute_add(self, *args, **kwargs) -> None:
        """
        Add options permutationally.

        Every passed argument multiplies the current set of options by its
        number of values, so that the result holds all combinations of the
        previously added options and all newly passed arguments.

        Calling this without any arguments is a no-op.
        """
        if not args and not kwargs:
            return

        # Deal with the case where this is the first options added: seed the
        # option list from a single argument, then permute the rest onto it.
        if not self.options:
            if args:
                self._zipped_add(args[0])
                args = args[1:]
            else:
                key = next(iter(kwargs))
                self._zipped_add(**{key: kwargs.pop(key)})

        # Make permutations. ``self.options`` is replaced after *each*
        # argument so the next argument is combined with the already expanded
        # set; expanding every argument from the same base list would yield
        # options that each lack all but one of the new arguments.
        for key, values in kwargs.items():
            expanded = []
            for value in values:
                for option in self.options:
                    option_new = option.copy()
                    option_new.add_kwarg(key=key, value=value)
                    expanded.append(option_new)
            self.options = expanded

        for values in args:
            expanded = []
            for value in values:
                for option in self.options:
                    option_new = option.copy()
                    option_new.add_arg(value=value)
                    expanded.append(option_new)
            self.options = expanded

    def __iter__(self) -> Self:
        """
        Prepare iteration and return the iterator itself.

        Resets the position, creates a ``Study`` at ``directory``, builds one
        ``Experiment`` per configured option, adds each to the study and keeps
        the experiments the study returns for ``status='pending'``.
        """
        self.index = 0
        self.study = Study(path=self.directory, refresh=True)
        experiments = []
        for options in self.options:
            experiment = Experiment(*options.args, function=self.function, **options.kwargs, root_path=self.directory)
            self.study.add_experiment(experiment=experiment)
            experiments.append(experiment)
        # NOTE(review): presumably this filters out experiments that already
        # exist in the study with a non-pending status - confirm against Study.
        self._experiments = self.study.get_experiments(status='pending', loaded_experiments=experiments)

        return self

    def __next__(self) -> Experiment:
        """
        Return the next pending experiment.

        Raises
        ------
        StopIteration
            When all pending experiments have been returned.
        """
        if self.index < 0:
            # next() was called before iter(): prepare lazily instead of
            # failing with an AttributeError on the missing ``_experiments``.
            self.__iter__()

        if self.index >= len(self._experiments):
            raise StopIteration

        experiment = self._experiments[self.index]
        self.index += 1
        return experiment

    def __len__(self) -> int:
        """
        Return the number of configured options.

        This counts the option sets added so far, not the number of pending
        experiments that iteration yields.
        """
        return len(self.options)