Coverage for src/shephex/executor/slurm/slurm_executor.py: 74%

122 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-03-29 18:47 +0100

1import inspect 

2import json 

3from pathlib import Path 

4from typing import List, Literal, Optional, Union 

5 

6from shephex.executor.executor import Executor 

7from shephex.executor.slurm import ( 

8 SlurmBody, 

9 SlurmHeader, 

10 SlurmProfileManager, 

11 SlurmScript, 

12) 

13from shephex.experiment import FutureResult 

14from shephex.experiment.experiment import Experiment 

15 

16 

class SlurmSafetyError(Exception):
    """Raised when a SlurmExecutor is created from outside the ``__main__`` module."""

19 

class SlurmExecutor(Executor):
    """
    Shephex SLURM executor for executing experiments on a SLURM cluster.

    All experiments handed to the executor are submitted together as one
    SLURM array job (see ``_sequence_execute``).
    """

    def __init__(
        self,
        directory: Optional[Union[str, Path]] = None,
        scratch: bool = False,
        ulimit: Union[int, Literal['default']] = 8000,
        move_output_file: bool = True,
        safety_check: bool = True,
        **kwargs,
    ) -> None:
        """
        shephex SLURM executor.

        Parameters
        ----------
        directory : Union[str, Path], optional
            Directory where the SLURM script and config files will be stored,
            defaults to the relative directory ``slurm``.
        scratch : bool, optional
            If True, the experiments are executed in ``/scratch/$SLURM_JOB_ID``
            and the contents of that directory are copied back to the
            experiment directory once the job is finished. Defaults to False.
        ulimit : Union[int, Literal['default']], optional
            Value used for the ``ulimit -Su <value>`` line added to the
            script. Pass ``'default'`` to not add that line. Defaults to 8000.
        move_output_file : bool, optional
            If True, a command moving the ``slurm-<job>_<task>.out`` file into
            the experiment directory is added to the script. Defaults to True.
        safety_check : bool, optional
            If True, verify that the executor is being created from the
            ``__main__`` module. Defaults to True.
        **kwargs
            Additional keyword arguments to be passed to the SlurmHeader,
            these are the SLURM parameters for the job. Supports all the
            arguments for sbatch, see https://slurm.schedmd.com/sbatch.html.

        Raises
        ------
        SlurmSafetyError
            If ``safety_check`` is True and the caller is not ``__main__``.
        """
        if safety_check:
            # Frame 0 is safety_check, frame 1 is this __init__, frame 2 is
            # the code that instantiated the executor.
            self.safety_check(frame_index=2)

        self.header = SlurmHeader()
        for key, value in kwargs.items():
            self.header.add(key, value)

        self.directory = Path('slurm' if directory is None else directory)

        # Containers for commands to be executed before and after the main execution
        self._commands_pre_execution = []
        self._commands_post_execution = []

        # Special options
        self.ulimit = ulimit
        self.move_output_file = move_output_file
        self.scratch = scratch

        # To keep track of the special options for saving the config
        self.special_options = {
            'scratch': scratch,
            'ulimit': ulimit,
            'move_output_file': move_output_file,
        }

    @classmethod
    def from_config(cls, path: Path, safety_check: bool = True, **kwargs) -> 'SlurmExecutor':
        """
        Create a new SlurmExecutor from a JSON config file.

        Parameters
        ----------
        path : Path
            Path to a ``.json`` file, e.g. one written by ``to_config``.
        safety_check : bool, optional
            If True, verify that this method is being called from the
            ``__main__`` module. Defaults to True.
        **kwargs
            Overrides for entries of the config file. The merged entries are
            passed to the constructor.

        Returns
        -------
        SlurmExecutor
            The configured executor.

        Raises
        ------
        SlurmSafetyError
            If ``safety_check`` is True and the caller is not ``__main__``.
        AssertionError
            If the file does not exist or does not have a ``.json`` suffix.
        """
        if safety_check:
            # Frame 0 is safety_check, frame 1 is from_config itself, frame 2
            # is the code that called from_config. Inspecting frame 1 would
            # test this library module and therefore always fail.
            cls.safety_check(frame_index=2)

        path = Path(path)

        assert path.exists(), f'File {path} does not exist'
        assert path.suffix == '.json', f'File {path} is not a json file'

        with open(path, encoding='utf-8') as f:
            config = json.load(f)
        config.update(kwargs)

        # The command lists are not constructor arguments; restore them on
        # the instance after construction.
        pre_commands = config.pop('commands_pre_execution', [])
        post_commands = config.pop('commands_post_execution', [])

        # Any requested safety check has already been performed above.
        instance = cls(**config, safety_check=False)
        instance._commands_pre_execution = pre_commands
        instance._commands_post_execution = post_commands
        return instance

    def to_config(self, path: Path | str) -> None:
        """
        Write the executor configuration to a JSON file.

        The file contains the header options, the special options given to
        the constructor and the lists of pre/post execution commands.

        Parameters
        ----------
        path : Path | str
            Destination of the JSON file.
        """
        path = Path(path)

        config = self.header.to_dict()
        config.update(self.special_options)
        config['commands_pre_execution'] = self._commands_pre_execution
        config['commands_post_execution'] = self._commands_post_execution

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)

    @classmethod
    def from_profile(cls, name: str, safety_check: bool = True, **kwargs) -> 'SlurmExecutor':
        """
        Create a new SlurmExecutor from a profile.

        Parameters
        ----------
        name : str
            Name of the profile.
        safety_check : bool, optional
            If True, a safety check will be performed to ensure that the executor
            is not instantiated on a script that is not the main script. Defaults to True.
        **kwargs
            Additional keyword arguments to be passed to the SlurmExecutor.

        Returns
        -------
        SlurmExecutor
            The executor configured from the profile.

        Raises
        ------
        SlurmSafetyError
            If ``safety_check`` is True and the caller is not ``__main__``.
        """
        if safety_check:
            # Frame 0 is safety_check, frame 1 is from_profile, frame 2 is
            # the code that called from_profile.
            cls.safety_check(frame_index=2)

        # The check was either performed here or explicitly disabled by the
        # caller; in both cases from_config must not run it again.
        kwargs['safety_check'] = False

        spm = SlurmProfileManager()
        profile = spm.get_profile_path(name)
        return cls.from_config(profile, **kwargs)

    def _single_execute(self) -> None:
        """Not supported; experiments are always executed as a sequence."""
        raise NotImplementedError('Single execution is not supported for SLURM Executor, everything is executed with _sequence execute.')

    def _sequence_execute(
        self,
        experiments: List[Experiment],
        dry: bool = False,
        execution_directory: Optional[Union[Path, str]] = None,
    ) -> List[FutureResult]:
        """
        Execute a sequence of experiments as an array job.

        Parameters
        ----------
        experiments : List[Experiment]
            List of experiments to be executed.
        dry : bool, optional
            If True, the script will be printed instead of executed.
        execution_directory : Union[Path, str], optional
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        List[FutureResult]
            List of FutureResult objects, one per experiment. When the job
            was submitted each result carries the job id in its info.
        """
        if not experiments:
            return []

        n_experiments = len(experiments)

        # Dump config, numbered by how many config files already exist.
        self.directory.mkdir(parents=True, exist_ok=True)
        index = len(list(self.directory.glob('config*.json')))
        self.to_config(self.directory / f'config_{index}.json')

        # Work on a copy so the array specification does not stick to the
        # executor's own header.
        header = self.header.copy()
        header.add('array', f'0-{n_experiments - 1}%{n_experiments}')

        body = self._make_slurm_body(experiments)

        count = len(list(self.directory.glob('submit*.sh')))
        script = SlurmScript(header, body, directory=self.directory, name=f'submit_{count}.sh')

        if dry:
            print(script)
            return [FutureResult() for _ in experiments]

        script.write()
        job_id = script.submit()

        for experiment in experiments:
            experiment.update_status('submitted')

        return [FutureResult(info={'job_id': job_id}) for _ in experiments]

    def _bash_array_str(self, strings: List[str]) -> str:
        """
        Convert a list of strings into a nicely formatted bash array of string.

        Parameters
        ----------
        strings : List[str]
            List of strings to be converted.

        Returns
        -------
        str
            A python string representing a bash array of strings, one
            element per line.
        """
        bash_str = ' \n\t'.join(strings)
        return f'(\n\t{bash_str}\n)'

    def _body_add(self, command: str, when: Optional[Literal['pre', 'post']] = None) -> None:
        """
        Add a command to the body of the SLURM script.

        Parameters
        ----------
        command : str
            Command to be added to the body.
        when : Optional[Literal['pre', 'post']], optional
            Whether the command runs before ('pre') or after ('post') the
            main execution command. ``None`` is treated as 'pre'.

        Raises
        ------
        ValueError
            If ``when`` is neither 'pre', 'post' nor None.
        """
        if when is None or when == 'pre':
            self._commands_pre_execution.append(command)
        elif when == 'post':
            self._commands_post_execution.append(command)
        else:
            # Previously the command was silently discarded.
            raise ValueError(f"when must be 'pre', 'post' or None, got {when!r}")

    def _make_slurm_body(self, experiments: List[Experiment]) -> SlurmBody:
        """
        Make a new SlurmBody object.

        The body declares bash arrays with the directory and identifier of
        every experiment, followed by the pre-execution commands, the
        ``hex slurm add-info`` and ``hex execute`` commands and finally the
        post-execution commands. Each array task selects its own experiment
        through ``$SLURM_ARRAY_TASK_ID``.

        Parameters
        ----------
        experiments : List[Experiment]
            Experiments the script should execute.

        Returns
        -------
        SlurmBody
            A new SlurmBody object.
        """
        identifiers = [str(experiment.identifier) for experiment in experiments]
        directories = [str(experiment.directory.resolve()) for experiment in experiments]

        body = SlurmBody()
        body.add(f'directories={self._bash_array_str(directories)}')
        body.add(f'identifiers={self._bash_array_str(identifiers)}')

        # Assemble the pre-execution commands in a local list. Appending the
        # generated commands to self._commands_pre_execution would duplicate
        # them on every call and leak them into later to_config dumps.
        pre_commands = list(self._commands_pre_execution)
        if self.move_output_file:
            pre_commands.append(r"mv slurm-${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}.out ${directories[$SLURM_ARRAY_TASK_ID]}")
        if self.ulimit != 'default':
            pre_commands.append(f'ulimit -Su {self.ulimit}')

        for command in pre_commands:
            body.add(command)

        # Slurm info command:
        body.add(r'hex slurm add-info -d ${directories[$SLURM_ARRAY_TASK_ID]} -j "${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}"')

        # Execution command
        execute_command = r'hex execute ${directories[$SLURM_ARRAY_TASK_ID]}'
        if self.scratch:
            execute_command += ' -e /scratch/$SLURM_JOB_ID'
        body.add(execute_command)

        for command in self._commands_post_execution:
            body.add(command)

        if self.scratch:
            # Copy the results from scratch back to the experiment directory.
            body.add(
                r'cp -r /scratch/$SLURM_JOB_ID/* ${directories[$SLURM_ARRAY_TASK_ID]}'
            )

        return body

    @staticmethod
    def safety_check(frame_index: int = 2) -> None:
        """
        Check if the executor is being called from the main script.

        Parameters
        ----------
        frame_index : int, optional
            Index into the call stack of the frame to be checked. Index 0 is
            this function, index 1 is the function that called it and index 2
            is the caller of that function. Defaults to 2.

        Raises
        ------
        SlurmSafetyError
            If the executor is not being called from the main script.

        ``__init__``, ``from_config`` and ``from_profile`` all call this
        function directly, so each of them uses ``frame_index=2`` to inspect
        the code that invoked them.
        """
        # context=0: only the frames are needed, not their source lines.
        caller_frame = inspect.stack(context=0)[frame_index]
        caller_module = inspect.getmodule(caller_frame[0])

        if caller_module is None or caller_module.__name__ != "__main__":
            raise SlurmSafetyError(
                "SlurmExecutor should only be called from the main script.\n"
                "If the you really want, you can disable this check. "
                "This error may be caused by not having\n"
                "a 'if __name__ == \"__main__\":' block in the main script."
            )

304 

305