Coverage for muutils\nbutils\convert_ipynb_to_script.py: 67%

124 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-15 21:53 -0600

1"""fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting.""" 

2 

3from __future__ import annotations 

4 

5import argparse 

6import json 

7import os 

8import sys 

9import typing 

10import warnings 

11 

12from muutils.spinner import SpinnerContext 

13 

# Snippets injected into converted scripts by `disable_plots_in_script` (used when
# `convert_ipynb(..., disable_plots=True)`), keyed by the library whose plotting
# they neutralize. Each value is a list of script chunks that is concatenated into
# the list of converted cells.
DISABLE_PLOTS: dict[str, list[str]] = {
    # inserted right after the last chunk that imports matplotlib (or mentions
    # `configure_notebook`); replaces `plt.show` with a no-op
    "matplotlib": [
        """
# ------------------------------------------------------------
# Disable matplotlib plots, done during processing by `convert_ipynb_to_script.py`
import matplotlib.pyplot as plt
plt.show = lambda: None
# ------------------------------------------------------------
"""
    ],
    # inserted right after the last chunk that imports circuitsvis; rebinds the
    # script-level name `render` to a wrapper that returns a plain `str`.
    # NOTE(review): this only rebinds `render` in the generated script's own
    # namespace -- whether circuitsvis functions pick it up depends on how they
    # import `render`; confirm. `convert_props` is imported but unused here.
    "circuitsvis": [
        """
# ------------------------------------------------------------
# Disable circuitsvis plots, done during processing by `convert_ipynb_to_script.py`
from circuitsvis.utils.convert_props import PythonProperty, convert_props
from circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local

def new_render(
    react_element_name: str,
    **kwargs: PythonProperty
) -> RenderedHTML:
    "return a visualization as raw HTML"
    local_src = render_local(react_element_name, **kwargs)
    cdn_src = render_cdn(react_element_name, **kwargs)
    # return as string instead of RenderedHTML for CI
    return str(RenderedHTML(local_src, cdn_src))

render = new_render
# ------------------------------------------------------------
"""
    ],
    # prepended before every other chunk; sets the conversion plot-mode
    # override of `muutils.nbutils.configure_notebook` to "ignore"
    "muutils": [
        """import muutils.nbutils.configure_notebook as nb_conf
nb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"
"""
    ],
}

51 

# Header chunk prepended to every converted script when `disable_plots=True`
# (see `convert_ipynb`); `.lstrip()` drops the literal's leading newline so the
# generated file starts directly with the comment banner.
# NOTE(review): "temp dict" in the text presumably means "temp dir" -- left as-is
# because this text is emitted verbatim into generated scripts.
DISABLE_PLOTS_WARNING: list[str] = [
    """
# ------------------------------------------------------------
# WARNING: this script is auto-generated by `convert_ipynb_to_script.py`
# showing plots has been disabled, so this is presumably in a temp dict for CI or something
# so don't modify this code, it will be overwritten!
# ------------------------------------------------------------
""".lstrip()
]

61 

62 

def _mentions_import(line: str) -> bool:
    """Heuristic: does this chunk contain `import` or `from` anywhere (as a substring)?"""
    return ("import" in line) or ("from" in line)


def _last_index_where(
    lines: list[str], predicate: typing.Callable[[str], bool]
) -> int:
    """Return the index of the last element of `lines` satisfying `predicate`, or -1 if none."""
    return max(
        (i for i, line in enumerate(lines) if predicate(line)),
        default=-1,
    )


def _check_import_before_usage(
    lib_name: str, last_import_index: int, last_usage_index: int
) -> None:
    """Validate the import/usage chunk positions found for `lib_name`.

    Raises `AssertionError` when no import was found, or when the usage does not
    come strictly after the import. Emits a `UserWarning` when no usage was found.
    """
    # explicit raises (not `assert`) so the checks survive `python -O`; the
    # exception type stays `AssertionError` because callers catch exactly that
    if last_import_index == -1:
        raise AssertionError(
            f"{lib_name} imports not found! see line {last_import_index}"
        )
    if last_usage_index == -1:
        warnings.warn(
            f"could not find where {lib_name} is used, plot disabling might not work!"
        )
    elif last_usage_index <= last_import_index:
        raise AssertionError(
            f"{lib_name} plots created before import! see lines {last_usage_index}, {last_import_index}"
        )


def disable_plots_in_script(script_lines: list[str]) -> list[str]:
    """Disable plots in a script by adding cursed things after the import statements

    # Arguments
     - `script_lines: list[str]`: chunks of the script. When called from
       `convert_ipynb`, each element is a whole cell (header comment + source),
       not a single line, so all position checks below are per-chunk.

    # Returns
     - `list[str]`: the chunks with the relevant `DISABLE_PLOTS` snippets inserted.
       If none of `"muutils"`, `"matplotlib"`, `"circuitsvis"` occur in the input,
       the input list object itself is returned.

    # Raises
     - `AssertionError`: matplotlib is mentioned but not imported as `plt`, or the
       imports of matplotlib/circuitsvis cannot be located, or the detected usage
       does not come after the last import.
    """
    # `text` is what the membership checks see: the original chunks, refreshed
    # only after the matplotlib snippet has been inserted
    text: str = "\n\n".join(script_lines)
    lines: list[str] = script_lines

    if "muutils" in text:
        lines = DISABLE_PLOTS["muutils"] + lines

    if "matplotlib" in text:
        if "import matplotlib.pyplot as plt" not in text:
            raise AssertionError("matplotlib.pyplot must be imported as plt")

        # a chunk mentioning `configure_notebook` also counts as an "import" point
        mpl_last_import_index: int = _last_index_where(
            lines,
            lambda line: ("matplotlib" in line and _mentions_import(line))
            or "configure_notebook" in line,
        )
        # NOTE: this is the LAST chunk using `plt.`, not the first, so the check
        # only guarantees that some usage follows the last import
        mpl_last_usage_index: int = _last_index_where(
            lines, lambda line: "plt." in line
        )
        _check_import_before_usage(
            "matplotlib", mpl_last_import_index, mpl_last_usage_index
        )

        lines = (
            lines[: mpl_last_import_index + 1]
            + DISABLE_PLOTS["matplotlib"]
            + lines[mpl_last_import_index + 1 :]
        )
        text = "\n\n".join(lines)

    if "circuitsvis" in text:
        cirv_last_import_index: int = _last_index_where(
            lines, lambda line: "circuitsvis" in line and _mentions_import(line)
        )
        # any non-import mention of circuitsvis, or any chunk containing "render",
        # counts as a usage (again: the last such chunk)
        cirv_last_usage_index: int = _last_index_where(
            lines,
            lambda line: ("circuitsvis" in line and not _mentions_import(line))
            or "render" in line,
        )
        _check_import_before_usage(
            "circuitsvis", cirv_last_import_index, cirv_last_usage_index
        )

        lines = (
            lines[: cirv_last_import_index + 1]
            + DISABLE_PLOTS["circuitsvis"]
            + lines[cirv_last_import_index + 1 :]
        )

    return lines

148 

149 

def convert_ipynb(
    notebook: dict,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = (
        "%",
        "!",
    ),  # ignore notebook magic commands and shell commands
) -> str:
    """Convert Jupyter Notebook to a script, doing some basic filtering and formatting.

    # Arguments
    - `notebook: dict`: Jupyter Notebook loaded as json. Each cell's `"source"` may be
       a list of lines or a single string (both are allowed by the notebook format).
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
       Kept markdown cells are emitted as `\"\"\"...\"\"\"` string statements, with any
       triple double-quotes inside them escaped so the script stays valid Python.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
        if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Returns
    - `str`: Converted script. Cells other than markdown/code (e.g. raw cells) are dropped.

    # Raises
    - `AssertionError`: propagated from `disable_plots_in_script` when `disable_plots=True`.
    """
    if isinstance(filter_out_lines, str):
        filter_out_lines = tuple(filter_out_lines)
    # `str.startswith` accepts a tuple and checks every prefix in one call
    filter_prefixes: tuple[str, ...] = tuple(filter_out_lines)

    result: list[str] = []

    for cell in notebook["cells"]:
        cell_type: str = cell["cell_type"]
        raw_source = cell["source"]
        # normalize to a list of lines: iterating a plain string would otherwise
        # walk characters and comment out every "%" / "!" character
        source: list[str] = (
            raw_source.splitlines(keepends=True)
            if isinstance(raw_source, str)
            else list(raw_source)
        )

        if cell_type == "markdown":
            if strip_md_cells:
                continue
            # a literal `"""` in the markdown would terminate the string early
            md_text: str = "".join(source).replace('"""', '\\"\\"\\"')
            result.append(f'{header_comment}\n"""\n{md_text}\n"""')
        elif cell_type == "code":
            if filter_prefixes:
                source = [
                    f"#{line}" if line.startswith(filter_prefixes) else line
                    for line in source
                ]
            code_text: str = "".join(source)
            result.append(f"{header_comment}\n{code_text}")

    if disable_plots:
        result = disable_plots_in_script(result)
        result = DISABLE_PLOTS_WARNING + result

    return "\n\n".join(result)

208 

209 

def process_file(
    in_file: str,
    out_file: str | None = None,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert a single Jupyter Notebook to a script.

    # Arguments
    - `in_file: str`: Path to the input `.ipynb` file.
    - `out_file: str | None = None`: Output script path. If falsy, the script is printed to stdout.
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).

    # Raises
    - `AssertionError`: if `in_file` does not exist, is not a file, does not end in
       `.ipynb`, or the conversion itself fails (the error is also printed to stderr).
    """
    print(f"\tProcessing {in_file}...", file=sys.stderr)
    assert os.path.exists(in_file), f"File {in_file} does not exist."
    assert os.path.isfile(in_file), f"Path {in_file} is not a file."
    assert in_file.endswith(".ipynb"), f"File {in_file} is not a Jupyter Notebook."

    # explicit UTF-8 (as process_dir does) instead of the platform default
    # encoding, which e.g. on Windows fails on non-ASCII notebooks
    with open(in_file, "r", encoding="utf-8") as file:
        notebook: dict = json.load(file)

    try:
        converted_script: str = convert_ipynb(
            notebook=notebook,
            strip_md_cells=strip_md_cells,
            header_comment=header_comment,
            disable_plots=disable_plots,
            filter_out_lines=filter_out_lines,
        )
    except AssertionError as e:
        print(f"Error converting {in_file}: {e}", file=sys.stderr)
        raise

    if out_file:
        with open(out_file, "w", encoding="utf-8") as file:
            file.write(converted_script)
    else:
        print(converted_script)

243 

244 

def process_dir(
    input_dir: str,
    output_dir: str,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert all Jupyter Notebooks in a directory to scripts.

    Only files directly inside `input_dir` whose names end in `.ipynb` are
    converted (no recursion), in sorted filename order. Each `name.ipynb` is
    written to `output_dir/name.py`; `output_dir` is created if needed.

    # Arguments
    - `input_dir: str`: Input directory.
    - `output_dir: str`: Output directory.
    - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
    - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
    - `disable_plots: bool = False`: Disable plots in the output script.
    - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
        if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Raises
    - `AssertionError`: if `input_dir` is missing, not a directory, or has no `.ipynb` files.
    - `Exception`: wrapping the `AssertionError` raised when converting a notebook fails.
    """
    assert os.path.exists(input_dir), f"Directory {input_dir} does not exist."
    assert os.path.isdir(input_dir), f"Path {input_dir} is not a directory."

    os.makedirs(output_dir, exist_ok=True)

    # sorted so conversion order (and spinner output) does not depend on the
    # arbitrary order returned by `os.listdir`
    filenames: list[str] = sorted(
        fname for fname in os.listdir(input_dir) if fname.endswith(".ipynb")
    )

    assert filenames, f"Directory {input_dir} does not contain any Jupyter Notebooks."
    n_files: int = len(filenames)
    print(f"Converting {n_files} notebooks:", file=sys.stderr)

    with SpinnerContext(
        spinner_chars="braille",
        update_interval=0.01,
        format_string_when_updated=True,
        output_stream=sys.stderr,
    ) as spinner:
        for idx, fname in enumerate(filenames):
            spinner.update_value(f"\tConverting {idx+1}/{n_files}: {fname}")
            in_file: str = os.path.join(input_dir, fname)
            # replace only the trailing extension; `str.replace` would also
            # rewrite any ".ipynb" occurring earlier in the name
            out_file: str = os.path.join(
                output_dir, fname[: -len(".ipynb")] + ".py"
            )

            with open(in_file, "r", encoding="utf-8") as file_in:
                notebook: dict = json.load(file_in)

            try:
                converted_script: str = convert_ipynb(
                    notebook=notebook,
                    strip_md_cells=strip_md_cells,
                    header_comment=header_comment,
                    disable_plots=disable_plots,
                    filter_out_lines=filter_out_lines,
                )
            except AssertionError as e:
                # stop the spinner first so the traceback is not interleaved with it
                spinner.stop()
                raise Exception(f"Error converting {in_file}") from e

            with open(out_file, "w", encoding="utf-8") as file_out:
                file_out.write(converted_script)

307 

308 

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert Jupyter Notebook to a script with cell separators."
    )
    parser.add_argument(
        "in_path",
        type=str,
        help="Input Jupyter Notebook file (.ipynb) or directory of files.",
    )
    parser.add_argument(
        "--out_file",
        type=str,
        help="Output script file. If not specified, the result will be printed to stdout.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Output directory for converted script files."
    )
    parser.add_argument(
        "--strip_md_cells",
        action="store_true",
        help="Remove markdown cells from the output script.",
    )
    parser.add_argument(
        "--header_comment",
        type=str,
        default=r"#%%",
        help="Comment string to separate cells in the output script.",
    )
    parser.add_argument(
        "--disable_plots",
        action="store_true",
        help="Disable plots in the output script. Useful for testing in CI.",
    )
    # default matches the library functions' ("%", "!"): notebook magics AND shell
    # commands, since an uncommented `!cmd` line is a syntax error in a script.
    # (help text avoids a literal percent sign: argparse %-formats help strings)
    parser.add_argument(
        "--filter_out_lines",
        type=str,
        default="%!",
        help="Comment out lines starting with these characters "
        "(default: notebook magics and shell commands).",
    )

    args = parser.parse_args()

    if args.output_dir:
        if args.out_file:
            # argparse-style usage error (exit code 2) instead of an `assert`,
            # which would be stripped under `python -O`
            parser.error("Cannot specify both --out_file and --output_dir.")
        process_dir(
            input_dir=args.in_path,
            output_dir=args.output_dir,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )

    else:
        process_file(
            in_file=args.in_path,
            out_file=args.out_file,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )

371 

372 

# diagnostic goes to stderr: when the CLI prints the converted script to stdout,
# a stdout message here would be appended to (and corrupt) that script
print("convert_ipynb_to_script.py loaded.", file=sys.stderr)