Coverage for muutils\nbutils\convert_ipynb_to_script.py: 67%
124 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-12 20:43 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-12 20:43 -0700
1"""fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting."""
3from __future__ import annotations
5import argparse
6import json
7import os
8import sys
9import typing
10import warnings
12from muutils.spinner import SpinnerContext
# Snippets injected into converted scripts to suppress plotting (e.g. when running
# notebooks as scripts in CI). Keys are the library whose mention in the script
# triggers insertion; values are lists of script chunks to insert.
DISABLE_PLOTS: dict[str, list[str]] = {
    # replace `plt.show` with a no-op
    "matplotlib": [
        """
# ------------------------------------------------------------
# Disable matplotlib plots, done during processing by `convert_ipynb_to_script.py`
import matplotlib.pyplot as plt
plt.show = lambda: None
# ------------------------------------------------------------
"""
    ],
    # rebind `render` to a version that returns a plain string
    "circuitsvis": [
        """
# ------------------------------------------------------------
# Disable circuitsvis plots, done during processing by `convert_ipynb_to_script.py`
from circuitsvis.utils.convert_props import PythonProperty, convert_props
from circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local

def new_render(
    react_element_name: str,
    **kwargs: PythonProperty
) -> RenderedHTML:
    "return a visualization as raw HTML"
    local_src = render_local(react_element_name, **kwargs)
    cdn_src = render_cdn(react_element_name, **kwargs)
    # return as string instead of RenderedHTML for CI
    return str(RenderedHTML(local_src, cdn_src))

render = new_render
# ------------------------------------------------------------
"""
    ],
    # tell muutils' notebook configuration to ignore plot output
    "muutils": [
        """import muutils.nbutils.configure_notebook as nb_conf
nb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"
"""
    ],
}

# Header prepended to scripts converted with plots disabled.
DISABLE_PLOTS_WARNING: list[str] = [
    """
# ------------------------------------------------------------
# WARNING: this script is auto-generated by `convert_ipynb_to_script.py`
# showing plots has been disabled, so this is presumably in a temp dir for CI or something
# so don't modify this code, it will be overwritten!
# ------------------------------------------------------------
""".lstrip()
]
def disable_plots_in_script(script_lines: list[str]) -> list[str]:
    """Disable plots in a script by inserting override snippets after the relevant imports.

    Each element of `script_lines` is treated as one unit when searching and inserting.
    When called from `convert_ipynb`, each element is a whole cell, so the indices
    below are element (cell) indices, not line numbers.

    - if `"muutils"` appears anywhere, `DISABLE_PLOTS["muutils"]` is prepended
    - if `"matplotlib"` appears, `DISABLE_PLOTS["matplotlib"]` is inserted right after
      the last element that imports matplotlib or mentions `configure_notebook`
    - if `"circuitsvis"` appears, `DISABLE_PLOTS["circuitsvis"]` is inserted right after
      the last element that imports circuitsvis

    # Arguments
     - `script_lines: list[str]`: the script, split into chunks. Not modified.

    # Returns
     - `list[str]`: a new list with the override snippets inserted.

    # Raises
     - `AssertionError`: if matplotlib is mentioned without `import matplotlib.pyplot as plt`,
       if a mentioned library is never imported, or if the last element using a library
       does not come after its last import.
    """
    combined: str = "\n\n".join(script_lines)
    # copy so the returned list never aliases the caller's list
    script_lines_new: list[str] = list(script_lines)

    if "muutils" in combined:
        script_lines_new = DISABLE_PLOTS["muutils"] + script_lines_new

    if "matplotlib" in combined:
        assert (
            "import matplotlib.pyplot as plt" in combined
        ), "matplotlib.pyplot must be imported as plt"

        # every match overwrites, so both indices end up at the LAST matching element.
        # `configure_notebook` is also treated as an import point (presumably because it
        # sets up matplotlib), so the override lands after it.
        mpl_last_import_index: int = -1
        mpl_last_usage_index: int = -1
        for i, line in enumerate(script_lines_new):
            if "matplotlib" in line and (("import" in line) or ("from" in line)):
                mpl_last_import_index = i
            if "configure_notebook" in line:
                mpl_last_import_index = i
            if "plt." in line:
                mpl_last_usage_index = i

        assert mpl_last_import_index != -1, "matplotlib imports not found!"
        if mpl_last_usage_index != -1:
            assert (
                mpl_last_usage_index > mpl_last_import_index
            ), f"matplotlib plots created before import! see lines {mpl_last_usage_index}, {mpl_last_import_index}"
        else:
            warnings.warn(
                "could not find where matplotlib is used, plot disabling might not work!"
            )

        # insert the override right after the last import point
        insert_at: int = mpl_last_import_index + 1
        script_lines_new = (
            script_lines_new[:insert_at]
            + DISABLE_PLOTS["matplotlib"]
            + script_lines_new[insert_at:]
        )
        combined = "\n\n".join(script_lines_new)

    if "circuitsvis" in combined:
        # again, both indices track the LAST matching element
        cirv_last_import_index: int = -1
        cirv_last_usage_index: int = -1
        for i, line in enumerate(script_lines_new):
            if "circuitsvis" in line:
                if ("import" in line) or ("from" in line):
                    cirv_last_import_index = i
                else:
                    cirv_last_usage_index = i
            # any mention of `render` counts as a usage, even on an import line
            if "render" in line:
                cirv_last_usage_index = i

        assert cirv_last_import_index != -1, "circuitsvis imports not found!"
        if cirv_last_usage_index != -1:
            assert (
                cirv_last_usage_index > cirv_last_import_index
            ), f"circuitsvis plots created before import! see lines {cirv_last_usage_index}, {cirv_last_import_index}"
        else:
            warnings.warn(
                "could not find where circuitsvis is used, plot disabling might not work!"
            )

        # insert the override right after the last circuitsvis import
        insert_at = cirv_last_import_index + 1
        script_lines_new = (
            script_lines_new[:insert_at]
            + DISABLE_PLOTS["circuitsvis"]
            + script_lines_new[insert_at:]
        )

    return script_lines_new
def convert_ipynb(
    notebook: dict,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = (
        "%",
        "!",
    ),  # ignore notebook magic commands and shell commands
) -> str:
    """Convert Jupyter Notebook to a script, doing some basic filtering and formatting.

    # Arguments
     - `notebook: dict`: Jupyter Notebook loaded as json.
     - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
     - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
     - `disable_plots: bool = False`: Disable plots in the output script.
     - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
        if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Returns
     - `str`: Converted script. Cells are separated by a blank line. Markdown cells become
       triple-quoted strings, and cells of any other type (e.g. raw) are dropped.
    """
    # a plain string is a collection of single-character prefixes
    if isinstance(filter_out_lines, str):
        filter_out_lines = tuple(filter_out_lines)
    # `str.startswith` accepts a tuple and checks every prefix in one call
    filter_prefixes: tuple[str, ...] = tuple(filter_out_lines)

    result: list[str] = []

    for cell in notebook["cells"]:
        cell_type: str = cell["cell_type"]
        source: str | list[str] = cell["source"]
        # nbformat allows `source` to be a list of lines OR a single multiline string.
        # Normalize to lines so the prefix filter applies per line, not per character.
        if isinstance(source, str):
            source = source.splitlines(keepends=True)

        if not strip_md_cells and cell_type == "markdown":
            result.append(f'{header_comment}\n"""\n{"".join(source)}\n"""')
        elif cell_type == "code":
            if filter_prefixes:
                source = [
                    f"#{line}" if line.startswith(filter_prefixes) else line
                    for line in source
                ]
            result.append(f'{header_comment}\n{"".join(source)}')

    if disable_plots:
        result = disable_plots_in_script(result)
        result = DISABLE_PLOTS_WARNING + result

    return "\n\n".join(result)
def process_file(
    in_file: str,
    out_file: str | None = None,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert a single Jupyter Notebook to a script.

    # Arguments
     - `in_file: str`: Path to the input `.ipynb` file.
     - `out_file: str | None = None`: Output script path. If falsy, the script is printed to stdout.
     - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
     - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
     - `disable_plots: bool = False`: Disable plots in the output script.
     - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).

    # Raises
     - `AssertionError`: if `in_file` is missing, is not a file, or does not end in `.ipynb`,
       or if the conversion itself fails an assertion (reported on stderr, then re-raised).
    """
    print(f"\tProcessing {in_file}...", file=sys.stderr)
    assert os.path.exists(in_file), f"File {in_file} does not exist."
    assert os.path.isfile(in_file), f"Path {in_file} is not a file."
    assert in_file.endswith(".ipynb"), f"File {in_file} is not a Jupyter Notebook."

    # notebooks are JSON, which is utf-8. Don't rely on the platform default encoding.
    with open(in_file, "r", encoding="utf-8") as file:
        notebook: dict = json.load(file)

    try:
        converted_script: str = convert_ipynb(
            notebook=notebook,
            strip_md_cells=strip_md_cells,
            header_comment=header_comment,
            disable_plots=disable_plots,
            filter_out_lines=filter_out_lines,
        )
    except AssertionError as e:
        print(f"Error converting {in_file}: {e}", file=sys.stderr)
        raise

    if out_file:
        with open(out_file, "w", encoding="utf-8") as file:
            file.write(converted_script)
    else:
        print(converted_script)
def process_dir(
    input_dir: str,
    output_dir: str,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert all Jupyter Notebooks in a directory to scripts.

    Only the top level of `input_dir` is scanned (no recursion). Each `name.ipynb` is
    written to `output_dir/name.py`. Files are processed in sorted order.

    # Arguments
     - `input_dir: str`: Input directory.
     - `output_dir: str`: Output directory (created if missing).
     - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
     - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
     - `disable_plots: bool = False`: Disable plots in the output script.
     - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
        if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Raises
     - `AssertionError`: if `input_dir` is missing, is not a directory, or contains no notebooks.
     - `Exception`: wrapping the `AssertionError` raised when a notebook fails to convert.
    """
    assert os.path.exists(input_dir), f"Directory {input_dir} does not exist."
    assert os.path.isdir(input_dir), f"Path {input_dir} is not a directory."

    os.makedirs(output_dir, exist_ok=True)

    # sorted so the processing order is deterministic across platforms
    filenames: list[str] = sorted(
        fname for fname in os.listdir(input_dir) if fname.endswith(".ipynb")
    )

    assert filenames, f"Directory {input_dir} does not contain any Jupyter Notebooks."
    n_files: int = len(filenames)
    print(f"Converting {n_files} notebooks:", file=sys.stderr)

    with SpinnerContext(
        spinner_chars="braille",
        update_interval=0.01,
        format_string_when_updated=True,
        output_stream=sys.stderr,
    ) as spinner:
        for idx, fname in enumerate(filenames):
            spinner.update_value(f"\tConverting {idx+1}/{n_files}: {fname}")
            in_file: str = os.path.join(input_dir, fname)
            # swap only the trailing extension; `str.replace` would also rewrite
            # any ".ipynb" appearing earlier in the name
            out_file: str = os.path.join(
                output_dir, os.path.splitext(fname)[0] + ".py"
            )

            with open(in_file, "r", encoding="utf-8") as file_in:
                notebook: dict = json.load(file_in)

            try:
                converted_script: str = convert_ipynb(
                    notebook=notebook,
                    strip_md_cells=strip_md_cells,
                    header_comment=header_comment,
                    disable_plots=disable_plots,
                    filter_out_lines=filter_out_lines,
                )
            except AssertionError as e:
                spinner.stop()
                raise Exception(f"Error converting {in_file}") from e

            with open(out_file, "w", encoding="utf-8") as file_out:
                file_out.write(converted_script)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert Jupyter Notebook to a script with cell separators."
    )
    parser.add_argument(
        "in_path",
        type=str,
        help="Input Jupyter Notebook file (.ipynb) or directory of files.",
    )
    parser.add_argument(
        "--out_file",
        type=str,
        help="Output script file. If not specified, the result will be printed to stdout.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Output directory for converted script files."
    )
    parser.add_argument(
        "--strip_md_cells",
        action="store_true",
        help="Remove markdown cells from the output script.",
    )
    parser.add_argument(
        "--header_comment",
        type=str,
        default=r"#%%",
        help="Comment string to separate cells in the output script.",
    )
    parser.add_argument(
        "--disable_plots",
        action="store_true",
        help="Disable plots in the output script. Useful for testing in CI.",
    )
    parser.add_argument(
        "--filter_out_lines",
        type=str,
        # each character is a separate prefix. "%!" matches the ("%", "!") default of
        # the library functions, so shell commands (`!pip ...`) are commented out too
        default="%!",
        help="Comment out lines starting with these characters (default: %(default)s).",
    )

    args = parser.parse_args()

    if args.output_dir:
        if args.out_file:
            # argparse reports this as a usage error instead of a bare AssertionError
            parser.error("Cannot specify both --out_file and --output_dir.")
        process_dir(
            input_dir=args.in_path,
            output_dir=args.output_dir,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )
    else:
        process_file(
            in_file=args.in_path,
            out_file=args.out_file,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )
# diagnostic message goes to stderr: this line also runs after `__main__`, and on stdout it
# would be appended to a converted script printed there (when no --out_file is given)
print("convert_ipynb_to_script.py loaded.", file=sys.stderr)