Coverage for src/shephex/executor/slurm/slurm_executor.py: 74%
122 statements
coverage.py v7.6.1, created at 2025-03-29 18:47 +0100
1import inspect
2import json
3from pathlib import Path
4from typing import List, Literal, Optional, Union
6from shephex.executor.executor import Executor
7from shephex.executor.slurm import (
8 SlurmBody,
9 SlurmHeader,
10 SlurmProfileManager,
11 SlurmScript,
12)
13from shephex.experiment import FutureResult
14from shephex.experiment.experiment import Experiment
class SlurmSafetyError(Exception):
    """Raised by the SlurmExecutor safety check when an executor is created
    from a module other than ``__main__``."""
class SlurmExecutor(Executor):
    """
    Shephex SLURM executor for executing experiments on a SLURM cluster.

    All experiments handed to one execution are submitted together as a
    single SLURM array job (see ``_sequence_execute``).
    """

    def __init__(
        self,
        directory: Optional[Union[str, Path]] = None,
        scratch: bool = False,
        ulimit: Union[int, Literal['default']] = 8000,
        move_output_file: bool = True,
        safety_check: bool = True,
        **kwargs,
    ) -> None:
        """
        shephex SLURM executor.

        Parameters
        ----------
        directory : Union[str, Path], optional
            Directory where the SLURM script and output files will be stored,
            defaults to ``slurm`` (relative to the current working directory).
        scratch : bool, optional
            If True, the executor will use the /scratch directory for the
            execution of the experiments. Defaults to False. When true
            files will automatically be copied back to the original directory
            once the job is finished.
        ulimit : Union[int, Literal['default']], optional
            Value passed to ``ulimit -Su`` (soft limit on the number of user
            processes) at the start of every array task. Use ``'default'`` to
            leave the limit untouched. Defaults to 8000.
        move_output_file : bool, optional
            If True, each array task moves its SLURM output file into the
            directory of the experiment it runs. Defaults to True.
        safety_check : bool, optional
            If True, raise SlurmSafetyError unless the executor is created
            from the ``__main__`` module. Defaults to True.
        **kwargs
            Additional keyword arguments to be passed to the SlurmHeader,
            these are the SLURM parameters for the job. Supports all the
            arguments for sbatch, see https://slurm.schedmd.com/sbatch.html.

        Raises
        ------
        SlurmSafetyError
            If ``safety_check`` is True and the caller is not ``__main__``.
        """
        if safety_check:
            # Frame 0: safety_check, frame 1: __init__, frame 2: the caller.
            self.safety_check(frame_index=2)

        self.header = SlurmHeader()
        for key, value in kwargs.items():
            self.header.add(key, value)

        if directory is None:
            directory = 'slurm'
        self.directory = Path(directory)

        # User-supplied commands executed before and after the main execution.
        # Only user commands live here; generated commands (mv/ulimit) are
        # added when the script body is built, so they are never persisted.
        self._commands_pre_execution = []
        self._commands_post_execution = []

        # Special options
        self.ulimit = ulimit
        self.move_output_file = move_output_file
        self.scratch = scratch

        # To keep track of the special options for saving the config
        self.special_options = {
            'scratch': scratch,
            'ulimit': ulimit,
            'move_output_file': move_output_file,
        }

    @classmethod
    def from_config(
        cls, path: Union[str, Path], safety_check: bool = True, **kwargs
    ) -> 'SlurmExecutor':
        """
        Create a new SlurmExecutor from a JSON config written by ``to_config``.

        Parameters
        ----------
        path : Union[str, Path]
            Path to the ``.json`` config file.
        safety_check : bool, optional
            If True, raise SlurmSafetyError unless called from ``__main__``.
            Defaults to True.
        **kwargs
            Values that override entries of the loaded config.

        Returns
        -------
        SlurmExecutor
            The reconstructed executor, including its pre/post commands.
        """
        if safety_check:
            # Frame 0: safety_check, frame 1: from_config, frame 2: the caller.
            # (Index 1 would inspect this module, which is never __main__.)
            cls.safety_check(frame_index=2)

        path = Path(path)
        assert path.exists(), f'File {path} does not exist'
        assert path.suffix == '.json', f'File {path} is not a json file'

        with open(path, encoding='utf-8') as f:
            config = json.load(f)
        config.update(kwargs)

        pre_commands = config.pop('commands_pre_execution', [])
        post_commands = config.pop('commands_post_execution', [])

        # Any requested check already ran above; don't let __init__ repeat it
        # from a frame that belongs to this module.
        instance = cls(**config, safety_check=False)
        instance._commands_pre_execution = pre_commands
        instance._commands_post_execution = post_commands
        return instance

    def to_config(self, path: Union[Path, str]) -> None:
        """
        Write the executor configuration to a JSON file.

        The file holds the SLURM header options, the special options
        (``scratch``, ``ulimit``, ``move_output_file``) and the user-added
        pre/post execution commands; ``from_config`` reads it back.

        Parameters
        ----------
        path : Union[Path, str]
            Destination file path.
        """
        config = self.header.to_dict()
        config.update(self.special_options)

        config['commands_pre_execution'] = self._commands_pre_execution
        config['commands_post_execution'] = self._commands_post_execution

        with open(Path(path), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)

    @classmethod
    def from_profile(cls, name: str, safety_check: bool = True, **kwargs) -> 'SlurmExecutor':
        """
        Create a new SlurmExecutor from a profile.

        Parameters
        ----------
        name : str
            Name of the profile.
        safety_check : bool, optional
            If True, a safety check will be performed to ensure that the executor
            is not instantiated on a script that is not the main script. Defaults to True.
        **kwargs
            Additional keyword arguments to be passed to the SlurmExecutor.
        """
        if safety_check:
            # Frame 0: safety_check, frame 1: from_profile, frame 2: the caller.
            cls.safety_check(frame_index=2)
        # The check has either run above or was explicitly disabled; either way
        # from_config must not run it again (its caller would be this module).
        kwargs['safety_check'] = False

        spm = SlurmProfileManager()
        profile = spm.get_profile_path(name)
        return cls.from_config(profile, **kwargs)

    def _single_execute(self) -> None:
        raise NotImplementedError('Single execution is not supported for SLURM Executor, everything is executed with _sequence execute.')

    def _sequence_execute(
        self,
        experiments: List[Experiment],
        dry: bool = False,
        execution_directory: Union[Path, str] = None,
    ) -> List[FutureResult]:
        """
        Execute a sequence of experiments as an array job.

        Parameters
        ----------
        experiments : List[Experiment]
            List of experiments to be executed.
        dry : bool, optional
            If True, the script will be printed instead of executed.
        execution_directory : Union[Path, str], optional
            Directory where the experiments will be executed. Not used by
            this executor.

        Returns
        -------
        List[FutureResult]
            List of FutureResult objects, one per experiment.
        """
        if len(experiments) == 0:
            return []

        # Dump config (this also happens for dry runs).
        self.directory.mkdir(parents=True, exist_ok=True)
        index = len(list(self.directory.glob('config*.json')))
        path = self.directory / f'config_{index}.json'
        self.to_config(path)

        # One array task per experiment; the '%N' suffix caps concurrent tasks
        # at N, here the number of experiments.
        header = self.header.copy()
        header.add('array', f'0-{len(experiments)-1}%{len(experiments)}')

        body = self._make_slurm_body(experiments)

        count = len(list(self.directory.glob('submit*.sh')))
        script = SlurmScript(header, body, directory=self.directory, name=f'submit_{count}.sh')
        if dry:
            print(script)
            return [FutureResult() for _ in experiments]

        script.write()

        job_id = script.submit()
        for experiment in experiments:
            experiment.update_status('submitted')

        return [FutureResult(info={'job_id': job_id}) for _ in experiments]

    def _bash_array_str(self, strings: List[str]) -> str:
        """
        Convert a list of strings into a nicely formatted bash array of string.

        Parameters
        ----------
        strings : List[str]
            List of strings to be converted. Elements are not quoted, so
            values containing whitespace would be split by bash.

        Returns
        -------
        str
            A python string representing a bash array of strings.
        """
        bash_str = ' \n\t'.join(strings)
        return f'(\n\t{bash_str}\n)'

    def _body_add(self, command: str, when: Optional[Literal['pre', 'post']] = None) -> None:
        """
        Add a command to the body of the SLURM script.

        Parameters
        ----------
        command : str
            Command to be added to the body.
        when : {'pre', 'post'}, optional
            Whether to run the command before or after the main execution.
            Defaults to 'pre'.

        Raises
        ------
        ValueError
            If ``when`` is not 'pre', 'post' or None.
        """
        if when is None:
            when = 'pre'

        if when == 'pre':
            self._commands_pre_execution.append(command)
        elif when == 'post':
            self._commands_post_execution.append(command)
        else:
            # Previously an unknown value silently dropped the command.
            raise ValueError(f"when must be 'pre', 'post' or None, got {when!r}")

    def _make_slurm_body(self, experiments: List[Experiment]) -> SlurmBody:
        """
        Make a new SlurmBody object.

        The body defines bash arrays of experiment directories/identifiers
        indexed by ``$SLURM_ARRAY_TASK_ID``, then runs the pre-commands, the
        ``hex`` info and execute commands, the post-commands and, with
        scratch enabled, copies results back.

        Parameters
        ----------
        experiments : List[Experiment]
            Experiments to run, one per array task.

        Returns
        -------
        SlurmBody
            A new SlurmBody object.
        """
        identifiers = [str(experiment.identifier) for experiment in experiments]
        directories = [str(experiment.directory.resolve()) for experiment in experiments]

        body = SlurmBody()

        body.add(f'directories={self._bash_array_str(directories)}')
        body.add(f'identifiers={self._bash_array_str(identifiers)}')

        # Generated commands go after the user's pre-commands (same order as
        # before) but into a local list: appending them to
        # self._commands_pre_execution duplicated them on every call and
        # leaked them into configs written by to_config.
        pre_commands = list(self._commands_pre_execution)
        if self.move_output_file:
            pre_commands.append(r"mv slurm-${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}.out ${directories[$SLURM_ARRAY_TASK_ID]}")
        if self.ulimit != 'default':
            pre_commands.append(f'ulimit -Su {self.ulimit}')

        for command in pre_commands:
            body.add(command)

        # Slurm info command:
        command = r'hex slurm add-info -d ${directories[$SLURM_ARRAY_TASK_ID]} -j "${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}"'
        body.add(command)

        # Execution command
        command = r'hex execute ${directories[$SLURM_ARRAY_TASK_ID]}'

        if self.scratch:
            command += ' -e /scratch/$SLURM_JOB_ID'

        body.add(command)

        for command in self._commands_post_execution:
            body.add(command)

        if self.scratch:
            body.add(
                r'cp -r /scratch/$SLURM_JOB_ID/* ${directories[$SLURM_ARRAY_TASK_ID]}'
            )

        return body

    @staticmethod
    def safety_check(frame_index: int = 2) -> None:
        """
        Check if the executor is being called from the main script.

        Parameters
        ----------
        frame_index : int, optional
            Index of the frame to be checked. Defaults to 2.

        Raises
        ------
        SlurmSafetyError
            If the executor is not being called from the main script.

        Notes
        -----
        Frames are counted from this function: 0 is ``safety_check`` itself,
        1 is the method that called it and 2 is that method's caller.
        ``__init__``, ``from_config`` and ``from_profile`` all call this
        directly, so each of them uses 2 to inspect the user's code.
        """
        # context=0 skips reading source lines for every frame on the stack.
        caller_frame = inspect.stack(context=0)[frame_index]
        caller_module = inspect.getmodule(caller_frame.frame)

        if caller_module is None or caller_module.__name__ != '__main__':
            raise SlurmSafetyError(
                'SlurmExecutor should only be called from the main script.\n'
                'If you really want, you can disable this check. This error may be caused by not having\n'
                'a \'if __name__ == "__main__":\' block in the main script.'
            )