muutils.nbutils.configure_notebook
shared utilities for setting up a notebook
1"""shared utilities for setting up a notebook""" 2 3from __future__ import annotations 4 5import os 6import typing 7import warnings 8 9import matplotlib.pyplot as plt # type: ignore[import] 10 11 12class PlotlyNotInstalledWarning(UserWarning): 13 pass 14 15 16# handle plotly importing 17PLOTLY_IMPORTED: bool 18try: 19 import plotly.io as pio # type: ignore[import] 20except ImportError: 21 warnings.warn( 22 "Plotly not installed. Plotly plots will not be available.", 23 PlotlyNotInstalledWarning, 24 ) 25 PLOTLY_IMPORTED = False 26else: 27 PLOTLY_IMPORTED = True 28 29# figure out if we're in a jupyter notebook 30try: 31 from IPython import get_ipython # type: ignore[import-not-found] 32 33 IN_JUPYTER = get_ipython() is not None 34except ImportError: 35 IN_JUPYTER = False 36 37# muutils imports 38from muutils.mlutils import get_device, set_reproducibility # noqa: E402 39 40# handling figures 41PlottingMode = typing.Literal["ignore", "inline", "widget", "save"] 42PLOT_MODE: PlottingMode = "inline" 43CONVERSION_PLOTMODE_OVERRIDE: PlottingMode | None = None 44FIG_COUNTER: int = 0 45FIG_OUTPUT_FMT: str | None = None 46FIG_NUMBERED_FNAME: str = "figure-{num}" 47FIG_CONFIG: dict | None = None 48FIG_BASEPATH: str | None = None 49CLOSE_AFTER_PLOTSHOW: bool = False 50 51MATPLOTLIB_FORMATS = ["pdf", "png", "jpg", "jpeg", "svg", "eps", "ps", "tif", "tiff"] 52TIKZPLOTLIB_FORMATS = ["tex", "tikz"] 53 54 55class UnknownFigureFormatWarning(UserWarning): 56 pass 57 58 59def universal_savefig(fname: str, fmt: str | None = None) -> None: 60 # try to infer format from fname 61 if fmt is None: 62 fmt = fname.split(".")[-1] 63 64 if not (fmt in MATPLOTLIB_FORMATS or fmt in TIKZPLOTLIB_FORMATS): 65 warnings.warn( 66 f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'", 67 UnknownFigureFormatWarning, 68 ) 69 fmt = FIG_OUTPUT_FMT 70 71 # not sure why linting is throwing an error here 72 if not fname.endswith(fmt): # type: ignore[arg-type] 73 fname += f".{fmt}" 74 75 if fmt in 
MATPLOTLIB_FORMATS: 76 plt.savefig(fname, format=fmt, bbox_inches="tight") 77 elif fmt in TIKZPLOTLIB_FORMATS: 78 import tikzplotlib # type: ignore[import] 79 80 tikzplotlib.save(fname) 81 else: 82 warnings.warn(f"Unknown format '{fmt}', going with matplotlib default") 83 plt.savefig(fname, bbox_inches="tight") 84 85 86def setup_plots( 87 plot_mode: PlottingMode = "inline", 88 fig_output_fmt: str | None = "pdf", 89 fig_numbered_fname: str = "figure-{num}", 90 fig_config: dict | None = None, 91 fig_basepath: str | None = None, 92 close_after_plotshow: bool = False, 93) -> None: 94 """Set up plot saving/rendering options""" 95 global \ 96 PLOT_MODE, \ 97 CONVERSION_PLOTMODE_OVERRIDE, \ 98 FIG_COUNTER, \ 99 FIG_OUTPUT_FMT, \ 100 FIG_NUMBERED_FNAME, \ 101 FIG_CONFIG, \ 102 FIG_BASEPATH, \ 103 CLOSE_AFTER_PLOTSHOW 104 105 # set plot mode, handling override 106 if CONVERSION_PLOTMODE_OVERRIDE is not None: 107 # override if set 108 PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE 109 else: 110 # otherwise use the given plot mode 111 PLOT_MODE = plot_mode 112 113 FIG_COUNTER = 0 114 CLOSE_AFTER_PLOTSHOW = close_after_plotshow 115 116 if PLOT_MODE == "inline": 117 if IN_JUPYTER: 118 ipython = get_ipython() 119 ipython.magic("matplotlib inline") 120 else: 121 raise RuntimeError( 122 f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }" 123 ) 124 return 125 elif PLOT_MODE == "widget": 126 if IN_JUPYTER: 127 ipython = get_ipython() 128 ipython.magic("matplotlib widget") 129 else: 130 # matplotlib outside of jupyter will bring up a new window by default 131 pass 132 return 133 elif PLOT_MODE == "ignore": 134 # disable plotting 135 plt.show = lambda: None # type: ignore[misc] 136 return 137 138 # everything except saving handled up to this point 139 assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}" 140 141 FIG_OUTPUT_FMT = fig_output_fmt 142 FIG_NUMBERED_FNAME = fig_numbered_fname 143 FIG_CONFIG = fig_config 144 145 # set default 
figure format in rcParams savefig.format 146 plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT 147 if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS: 148 try: 149 import tikzplotlib # type: ignore[import] # noqa: F401 150 except ImportError: 151 warnings.warn( 152 f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break." 153 ) 154 else: 155 if FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS: 156 warnings.warn( 157 f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }' 158 ) 159 160 # if base path not given, make one 161 if fig_basepath is None: 162 if fig_config is None: 163 # if no config, use the current time 164 from datetime import datetime 165 166 fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" 167 else: 168 # if config given, convert to string 169 from muutils.misc import dict_to_filename 170 171 fig_basepath = f"figures/{dict_to_filename(fig_config)}" 172 173 FIG_BASEPATH = fig_basepath 174 os.makedirs(fig_basepath, exist_ok=True) 175 176 # if config given, serialize and save that config 177 if fig_config is not None: 178 import json 179 180 from muutils.json_serialize import json_serialize 181 182 with open(f"{fig_basepath}/config.json", "w") as f: 183 json.dump( 184 json_serialize(fig_config), 185 f, 186 indent="\t", 187 ) 188 189 print(f"Figures will be saved to: '{fig_basepath}'") 190 191 192def configure_notebook( 193 *args, 194 seed: int = 42, 195 device: typing.Any = None, # this can be a string, torch.device, or None 196 dark_mode: bool = True, 197 plot_mode: PlottingMode = "inline", 198 fig_output_fmt: str | None = "pdf", 199 fig_numbered_fname: str = "figure-{num}", 200 fig_config: dict | None = None, 201 fig_basepath: str | None = None, 202 close_after_plotshow: bool = False, 203) -> "torch.device|None": # type: ignore[name-defined] # noqa: F821 204 """Shared Jupyter notebook setup steps 205 206 - Set random seeds and library reproducibility settings 207 - Set device 
based on availability 208 - Set module reloading before code execution 209 - Set plot formatting 210 - Set plot saving/rendering options 211 212 # Parameters: 213 - `seed : int` 214 random seed across libraries including torch, numpy, and random (defaults to `42`) 215 (defaults to `42`) 216 - `device : typing.Any` 217 pytorch device to use 218 (defaults to `None`) 219 - `dark_mode : bool` 220 figures in dark mode 221 (defaults to `True`) 222 - `plot_mode : PlottingMode` 223 how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]` 224 (defaults to `"inline"`) 225 - `fig_output_fmt : str | None` 226 format for saving figures 227 (defaults to `"pdf"`) 228 - `fig_numbered_fname : str` 229 format for saving figures with numbers (if they aren't named) 230 (defaults to `"figure-{num}"`) 231 - `fig_config : dict | None` 232 metadata to save with the figures 233 (defaults to `None`) 234 - `fig_basepath : str | None` 235 base path for saving figures 236 (defaults to `None`) 237 - `close_after_plotshow : bool` 238 close figures after showing them 239 (defaults to `False`) 240 241 # Returns: 242 - `torch.device|None` 243 the device set, if torch is installed 244 """ 245 246 # set some globals related to plotting 247 setup_plots( 248 plot_mode=plot_mode, 249 fig_output_fmt=fig_output_fmt, 250 fig_numbered_fname=fig_numbered_fname, 251 fig_config=fig_config, 252 fig_basepath=fig_basepath, 253 close_after_plotshow=close_after_plotshow, 254 ) 255 256 global PLOT_MODE, FIG_OUTPUT_FMT, FIG_BASEPATH 257 258 print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }") 259 260 # Set seeds and other reproducibility-related library options 261 set_reproducibility(seed) 262 263 # Reload modules before executing user code 264 if IN_JUPYTER: 265 ipython = get_ipython() 266 if "IPython.extensions.autoreload" not in ipython.extension_manager.loaded: 267 ipython.magic("load_ext autoreload") 268 ipython.magic("autoreload 2") 269 270 # Specify 
plotly renderer for vscode 271 if PLOTLY_IMPORTED: 272 pio.renderers.default = "notebook_connected" 273 274 if dark_mode: 275 pio.templates.default = "plotly_dark" 276 plt.style.use("dark_background") 277 278 try: 279 # Set device 280 device = get_device(device) 281 return device 282 except ImportError: 283 warnings.warn("Torch not installed. Cannot get/set device.") 284 return None 285 286 287def plotshow( 288 fname: str | None = None, 289 plot_mode: PlottingMode | None = None, 290 fmt: str | None = None, 291): 292 """Show the active plot, depending on global configs""" 293 global FIG_COUNTER, CLOSE_AFTER_PLOTSHOW, PLOT_MODE 294 FIG_COUNTER += 1 295 296 if plot_mode is None: 297 plot_mode = PLOT_MODE 298 299 if plot_mode == "save": 300 # get numbered figure name if not given 301 if fname is None: 302 fname = FIG_NUMBERED_FNAME.format(num=FIG_COUNTER) 303 304 # save figure 305 assert FIG_BASEPATH is not None 306 universal_savefig(os.path.join(FIG_BASEPATH, fname), fmt=fmt) 307 elif plot_mode == "ignore": 308 # do nothing 309 pass 310 elif plot_mode == "inline": 311 # show figure 312 plt.show() 313 elif plot_mode == "widget": 314 # show figure 315 plt.show() 316 else: 317 warnings.warn(f"Invalid plot mode: {plot_mode}") 318 319 if CLOSE_AFTER_PLOTSHOW: 320 plt.close()
class
PlotlyNotInstalledWarning(builtins.UserWarning):
Warning emitted when this module is imported and plotly is not installed; plotly plots will not be available.
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
PLOTLY_IMPORTED: bool =
True
PlottingMode =
typing.Literal['ignore', 'inline', 'widget', 'save']
PLOT_MODE: Literal['ignore', 'inline', 'widget', 'save'] =
'inline'
CONVERSION_PLOTMODE_OVERRIDE: Optional[Literal['ignore', 'inline', 'widget', 'save']] =
None
FIG_COUNTER: int =
0
FIG_OUTPUT_FMT: str | None =
None
FIG_NUMBERED_FNAME: str =
'figure-{num}'
FIG_CONFIG: dict | None =
None
FIG_BASEPATH: str | None =
None
CLOSE_AFTER_PLOTSHOW: bool =
False
MATPLOTLIB_FORMATS =
['pdf', 'png', 'jpg', 'jpeg', 'svg', 'eps', 'ps', 'tif', 'tiff']
TIKZPLOTLIB_FORMATS =
['tex', 'tikz']
class
UnknownFigureFormatWarning(builtins.UserWarning):
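Warning emitted by `universal_savefig` when a figure format is not one of the known matplotlib (`MATPLOTLIB_FORMATS`) or tikzplotlib (`TIKZPLOTLIB_FORMATS`) formats.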
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
def
universal_savefig(fname: str, fmt: str | None = None) -> None:
def universal_savefig(fname: str, fmt: str | None = None) -> None:
    """Save the current figure to `fname`, choosing matplotlib or tikzplotlib by format

    # Parameters:
     - `fname : str`
        output path; `.{fmt}` is appended if the name does not already end with it
     - `fmt : str | None`
        output format. If `None`, it is inferred from the extension of `fname`;
        if `fname` has no extension, the configured default `FIG_OUTPUT_FMT` is used.
        Formats in `MATPLOTLIB_FORMATS` are saved with `plt.savefig`, formats in
        `TIKZPLOTLIB_FORMATS` with `tikzplotlib.save` (imported lazily). Unknown
        formats emit an `UnknownFigureFormatWarning` and fall back to `FIG_OUTPUT_FMT`.
        (defaults to `None`)
    """
    if fmt is None:
        # infer from the extension of the last path component only: dots elsewhere
        # in the path (e.g. a base directory named after a config with float values)
        # must not be mistaken for a format. No extension -> use the default silently.
        _, extension = os.path.splitext(fname)
        fmt = extension[1:] if extension else FIG_OUTPUT_FMT

    if fmt is not None and not (fmt in MATPLOTLIB_FORMATS or fmt in TIKZPLOTLIB_FORMATS):
        warnings.warn(
            f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'",
            UnknownFigureFormatWarning,
        )
        fmt = FIG_OUTPUT_FMT

    # compare against ".{fmt}" (not bare "{fmt}") so e.g. "mypdf" still becomes "mypdf.pdf";
    # skip entirely when no format is known, instead of calling `endswith(None)`
    if fmt is not None and not fname.endswith(f".{fmt}"):
        fname += f".{fmt}"

    if fmt in MATPLOTLIB_FORMATS:
        plt.savefig(fname, format=fmt, bbox_inches="tight")
    elif fmt in TIKZPLOTLIB_FORMATS:
        import tikzplotlib  # type: ignore[import]

        tikzplotlib.save(fname)
    else:
        # reached when even the fallback format is unknown or None
        warnings.warn(f"Unknown format '{fmt}', going with matplotlib default")
        plt.savefig(fname, bbox_inches="tight")
def
setup_plots( plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline', fig_output_fmt: str | None = 'pdf', fig_numbered_fname: str = 'figure-{num}', fig_config: dict | None = None, fig_basepath: str | None = None, close_after_plotshow: bool = False) -> None:
def setup_plots(
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> None:
    """Set up plot saving/rendering options

    Stores the configuration in module-level globals that `plotshow` reads, and
    resets the figure counter.

    # Parameters:
     - `plot_mode : PlottingMode`
        how to display plots; replaced by `CONVERSION_PLOTMODE_OVERRIDE` if that is set
        (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
        default format for saved figures (`"save"` mode only)
        (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
        format string with a `{num}` field, used for unnamed figures (`"save"` mode only)
        (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
        metadata written to `config.json` in the figure directory and, if
        `fig_basepath` is `None`, used to name that directory (`"save"` mode only)
        (defaults to `None`)
     - `fig_basepath : str | None`
        directory for saved figures; if `None`, one is made under `figures/` from
        `fig_config` or the current time, and created if missing (`"save"` mode only)
        (defaults to `None`)
     - `close_after_plotshow : bool`
        close the current figure after each `plotshow`
        (defaults to `False`)

    # Raises:
     - `RuntimeError` : if `"inline"` mode is requested outside of Jupyter
    """
    global PLOT_MODE, FIG_COUNTER, CLOSE_AFTER_PLOTSHOW
    global FIG_OUTPUT_FMT, FIG_NUMBERED_FNAME, FIG_CONFIG, FIG_BASEPATH

    # set plot mode, handling override
    if CONVERSION_PLOTMODE_OVERRIDE is not None:
        PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE
    else:
        PLOT_MODE = plot_mode

    FIG_COUNTER = 0
    CLOSE_AFTER_PLOTSHOW = close_after_plotshow

    # undo the no-op `plt.show` installed by an earlier call in "ignore" mode,
    # so that switching to another mode shows figures again
    original_show = getattr(plt.show, "_muutils_original_show", None)
    if original_show is not None:
        plt.show = original_show

    if PLOT_MODE == "inline":
        if IN_JUPYTER:
            # `run_line_magic` replaces the deprecated `InteractiveShell.magic`
            get_ipython().run_line_magic("matplotlib", "inline")
        else:
            raise RuntimeError(
                f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }"
            )
        return
    elif PLOT_MODE == "widget":
        if IN_JUPYTER:
            get_ipython().run_line_magic("matplotlib", "widget")
        # otherwise: matplotlib outside of jupyter will bring up a new window by default
        return
    elif PLOT_MODE == "ignore":
        # disable plotting; accept any arguments (e.g. `block=`) like the real `plt.show`,
        # and remember the real one so a later call can restore it
        def _ignored_show(*args: typing.Any, **kwargs: typing.Any) -> None:
            return None

        _ignored_show._muutils_original_show = plt.show  # type: ignore[attr-defined]
        plt.show = _ignored_show  # type: ignore[misc]
        return

    # everything except saving handled up to this point
    assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}"

    FIG_OUTPUT_FMT = fig_output_fmt
    FIG_NUMBERED_FNAME = fig_numbered_fname
    FIG_CONFIG = fig_config

    # set default figure format in rcParams savefig.format; `None` means "no default
    # format", so leave matplotlib's setting untouched in that case
    if FIG_OUTPUT_FMT is not None:
        plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT

    if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS:
        try:
            import tikzplotlib  # type: ignore[import] # noqa: F401
        except ImportError:
            warnings.warn(
                f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break."
            )
    elif FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS:
        warnings.warn(
            f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }'
        )

    # if base path not given, make one
    if fig_basepath is None:
        if fig_config is None:
            # if no config, use the current time
            from datetime import datetime

            fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        else:
            # if config given, convert to string
            from muutils.misc import dict_to_filename

            fig_basepath = f"figures/{dict_to_filename(fig_config)}"

    FIG_BASEPATH = fig_basepath
    os.makedirs(fig_basepath, exist_ok=True)

    # if config given, serialize and save that config next to the figures
    if fig_config is not None:
        import json

        from muutils.json_serialize import json_serialize

        config_path = os.path.join(fig_basepath, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(json_serialize(fig_config), f, indent="\t")

    print(f"Figures will be saved to: '{fig_basepath}'")
Set up plot saving/rendering options
def
configure_notebook( *args, seed: int = 42, device: Any = None, dark_mode: bool = True, plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline', fig_output_fmt: str | None = 'pdf', fig_numbered_fname: str = 'figure-{num}', fig_config: dict | None = None, fig_basepath: str | None = None, close_after_plotshow: bool = False) -> torch.device | None:
def configure_notebook(
    *args,
    seed: int = 42,
    device: typing.Any = None,  # this can be a string, torch.device, or None
    dark_mode: bool = True,
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> "torch.device|None":  # type: ignore[name-defined] # noqa: F821
    """Shared Jupyter notebook setup steps

    - Set random seeds and library reproducibility settings
    - Set device based on availability
    - Set module reloading before code execution
    - Set plot formatting
    - Set plot saving/rendering options

    All options are keyword-only; positional arguments are not used and trigger a warning.

    # Parameters:
     - `seed : int`
        random seed across libraries including torch, numpy, and random
        (defaults to `42`)
     - `device : typing.Any`
        pytorch device to use, passed to `get_device`: a string, `torch.device`, or `None`
        (defaults to `None`)
     - `dark_mode : bool`
        dark styling for matplotlib figures and, if plotly is installed, plotly figures
        (defaults to `True`)
     - `plot_mode : PlottingMode`
        how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]`
        (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
        format for saving figures
        (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
        format for saving figures with numbers (if they aren't named)
        (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
        metadata to save with the figures
        (defaults to `None`)
     - `fig_basepath : str | None`
        base path for saving figures
        (defaults to `None`)
     - `close_after_plotshow : bool`
        close figures after showing them
        (defaults to `False`)

    # Returns:
     - `torch.device|None`
        the device set, or `None` if torch is not installed

    # Raises:
     - `RuntimeError` : from `setup_plots`, if `"inline"` mode is used outside of Jupyter
    """
    if args:
        # positional arguments are not used; warn rather than dropping them silently
        warnings.warn(
            f"configure_notebook ignores positional arguments, got {args!r}; pass options as keywords"
        )

    # set some globals related to plotting
    setup_plots(
        plot_mode=plot_mode,
        fig_output_fmt=fig_output_fmt,
        fig_numbered_fname=fig_numbered_fname,
        fig_config=fig_config,
        fig_basepath=fig_basepath,
        close_after_plotshow=close_after_plotshow,
    )

    print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }")

    # Set seeds and other reproducibility-related library options
    set_reproducibility(seed)

    # Reload modules before executing user code
    if IN_JUPYTER:
        ipython = get_ipython()
        loaded = ipython.extension_manager.loaded
        # the extension may be registered under its short or its fully-qualified name
        if not any(
            name in loaded for name in ("autoreload", "IPython.extensions.autoreload")
        ):
            # `run_line_magic` replaces the deprecated `InteractiveShell.magic`
            ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

    # Specify plotly renderer for vscode
    if PLOTLY_IMPORTED:
        pio.renderers.default = "notebook_connected"
        if dark_mode:
            pio.templates.default = "plotly_dark"

    # matplotlib styling must not depend on plotly being installed
    if dark_mode:
        plt.style.use("dark_background")

    try:
        # Set device
        return get_device(device)
    except ImportError:
        warnings.warn("Torch not installed. Cannot get/set device.")
        return None
Shared Jupyter notebook setup steps
- Set random seeds and library reproducibility settings
- Set device based on availability
- Set module reloading before code execution
- Set plot formatting
- Set plot saving/rendering options
Parameters:
seed : int
random seed across libraries including torch, numpy, and random (defaults to `42`)
device : typing.Any
pytorch device to use (defaults to `None`)
dark_mode : bool
figures in dark mode (defaults to `True`)
plot_mode : PlottingMode
how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]` (defaults to `"inline"`)
fig_output_fmt : str | None
format for saving figures (defaults to `"pdf"`)
fig_numbered_fname : str
format for saving figures with numbers (if they aren't named) (defaults to `"figure-{num}"`)
fig_config : dict | None
metadata to save with the figures (defaults to `None`)
fig_basepath : str | None
base path for saving figures (defaults to `None`)
close_after_plotshow : bool
close figures after showing them (defaults to `False`)
Returns:
torch.device|None
the device set, if torch is installed
def
plotshow( fname: str | None = None, plot_mode: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None, fmt: str | None = None):
def plotshow(
    fname: str | None = None,
    plot_mode: PlottingMode | None = None,
    fmt: str | None = None,
):
    """Show the active plot, depending on global configs

    Increments the global figure counter, then dispatches on `plot_mode`
    (the global `PLOT_MODE` when `None`):

     - `"save"`: save the current figure under `FIG_BASEPATH` via `universal_savefig`,
       named from `FIG_NUMBERED_FNAME` if `fname` is not given
     - `"inline"` / `"widget"`: call `plt.show()`
     - `"ignore"`: do nothing
     - anything else: warn

    If `CLOSE_AFTER_PLOTSHOW` is set, `plt.close()` is called afterwards.
    """
    # only the counter is assigned here; the other globals are just read
    global FIG_COUNTER
    FIG_COUNTER += 1

    mode = PLOT_MODE if plot_mode is None else plot_mode

    if mode == "save":
        target = FIG_NUMBERED_FNAME.format(num=FIG_COUNTER) if fname is None else fname
        assert FIG_BASEPATH is not None
        universal_savefig(os.path.join(FIG_BASEPATH, target), fmt=fmt)
    elif mode in ("inline", "widget"):
        plt.show()
    elif mode != "ignore":
        warnings.warn(f"Invalid plot mode: {mode}")

    if CLOSE_AFTER_PLOTSHOW:
        plt.close()
Show the active plot, depending on global configs