muutils.dbg
This code is based on a Python implementation of Rust's builtin dbg! macro, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py
although it has since been significantly modified.
licensed under MIT:
Copyright (c) 2019 Tyler Wince
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""Rust-style `dbg!` for Python: print an expression, its value, and where it was called.

This code is based on a Python implementation of Rust's builtin `dbg!` macro,
originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py
although it has since been significantly modified.

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

"""

from __future__ import annotations

import inspect
import sys
import typing
from pathlib import Path
import functools

# type defs
_ExpType = typing.TypeVar("_ExpType")


# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed."""

    pass


_NoExpPassed = _NoExpPassedSentinel()

# global variables
# working directory captured once, at import time
_CWD: Path = Path.cwd().absolute()
# incremented by each argument-less `dbg()` call
_COUNTER: int = 0

# configuration
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
DEFAULT_VAL_JOINER: str = " = "


# path processing
def _process_path(path: Path) -> str:
    """Format `path` for display according to the module-level `PATH_MODE`.

    - `"absolute"`: always the absolute path
    - `"relative"`: the path relative to `_CWD` if it lies inside it,
      otherwise the absolute path

    Returns a POSIX-style string. Raises `ValueError` for any other `PATH_MODE`.
    """
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # compare absolute paths: a relative `path` located inside the cwd
            # must still be shortened (`relative_to` on a relative path
            # against the absolute `_CWD` would always fail)
            fname = path_abs.relative_to(_CWD)
        except ValueError:
            # not inside the cwd: fall back to the absolute path
            fname = path_abs
    else:
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
     - `exp`: the value to print. If omitted, only the location and an
       incrementing counter (`<dbg N>`) are printed.
     - `formatter`: converts the value to a string (default: `repr`)
     - `val_joiner`: placed between the expression text and its value

    # Returns:
     - `exp` itself, so `dbg(...)` can wrap any subexpression in place
       (the internal sentinel is returned when nothing was passed)
    """
    global _COUNTER

    def extract_call_arguments(source_line: str) -> str:
        """Return the text inside the parentheses of the `dbg...(` call in `source_line`.

        Scans from the first `(` after `dbg` to its matching `)`, ignoring
        brackets inside string literals, so `f(dbg(x))` and `dbg_tensor(x)`
        both give `x`. If the call is not closed on this line (multi-line
        call), the rest of the line is returned.
        """
        name_idx: int = source_line.find("dbg")
        open_idx: int = source_line.find("(", max(name_idx, 0))
        if open_idx == -1:
            return source_line.strip()
        depth: int = 0
        quote: typing.Optional[str] = None
        idx: int = open_idx
        while idx < len(source_line):
            char: str = source_line[idx]
            if quote is not None:
                if char == "\\":
                    idx += 1  # skip whatever character is escaped
                elif char == quote:
                    quote = None
            elif char in "'\"":
                quote = char
            elif char in "([{":
                depth += 1
            elif char in ")]}":
                depth -= 1
                if depth == 0:
                    return source_line[open_idx + 1 : idx].strip()
            idx += 1
        return source_line[open_idx + 1 :].strip()

    # get the context: the innermost caller whose source line mentions `dbg`
    line_exp: str = "unknown"
    current_file: str = "unknown"
    call_lineno: typing.Optional[int] = None
    fname: str = "unknown"
    stack: list[inspect.FrameInfo] = inspect.stack()
    try:
        # stack[0] is this function's own frame, so start from its caller
        callers: list[inspect.FrameInfo] = stack[1:]
        for frame in callers:
            if not frame.code_context:
                continue
            line: str = frame.code_context[0]
            if "dbg" in line:
                current_file = _process_path(Path(frame.filename))
                call_lineno = frame.lineno
                line_exp = extract_call_arguments(line)
                break

        if current_file.startswith("/tmp/ipykernel_"):
            # Jupyter cells live in temporary files, so the file name is not
            # informative; show the chain of functions from the same cell
            # instead. Walking outward from the call we meet, in order: the
            # functions we care about, `<module>`, then Jupyter internals.
            filtered_functions: list[str] = []
            for frame_info in callers:
                if _process_path(Path(frame_info.filename)) != current_file:
                    continue
                if frame_info.function == "<module>":
                    break
                if frame_info.function.startswith("dbg"):
                    continue
                filtered_functions.append(frame_info.function)
            if call_lineno is not None:
                filtered_functions.append(f"<ipykernel>:{call_lineno}")
            else:
                filtered_functions.append(current_file)
            filtered_functions.reverse()
            fname = " -> ".join(filtered_functions)
        elif call_lineno is not None:
            fname = f"{current_file}:{call_lineno}"
    finally:
        # stack[0] is this very frame, so the list forms a reference cycle;
        # drop it so the frames are released deterministically
        del stack

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(msg, file=sys.stderr)

    # return the expression itself
    return exp


# formatted `dbg_*` functions with their helpers

DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = (
    dict(
        fmt="unicode",
        precision=2,
        stats=True,
        shape=True,
        dtype=True,
        device=True,
        requires_grad=True,
        sparkline=True,
        sparkline_bins=7,
        sparkline_logy=False,
        colored=True,
        eq_char="=",
    )
)


DBG_TENSOR_VAL_JOINER: str = ": "


def tensor_info(tensor: typing.Any) -> str:
    """Summarize `tensor` with `muutils.tensor_info.array_summary`.

    `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS` is passed as keyword arguments.
    """
    # imported lazily; presumably so importing this module does not pull in
    # the array/tensor dependencies -- TODO confirm
    from muutils.tensor_info import array_summary

    return array_summary(tensor, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)


# like `dbg`, but formats the value with `tensor_info` and joins with ": "
dbg_tensor = functools.partial(
    dbg, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER
)
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
     - `exp`: the value to print. If omitted, only the location and an
       incrementing counter (`<dbg N>`) are printed.
     - `formatter`: converts the value to a string (default: `repr`)
     - `val_joiner`: placed between the expression text and its value

    # Returns:
     - `exp` itself, so `dbg(...)` can wrap any subexpression in place
       (the internal sentinel is returned when nothing was passed)
    """
    global _COUNTER

    def extract_call_arguments(source_line: str) -> str:
        """Return the text inside the parentheses of the `dbg...(` call in `source_line`.

        Scans from the first `(` after `dbg` to its matching `)`, ignoring
        brackets inside string literals, so `f(dbg(x))` and `dbg_tensor(x)`
        both give `x`. If the call is not closed on this line (multi-line
        call), the rest of the line is returned.
        """
        name_idx: int = source_line.find("dbg")
        open_idx: int = source_line.find("(", max(name_idx, 0))
        if open_idx == -1:
            return source_line.strip()
        depth: int = 0
        quote: typing.Optional[str] = None
        idx: int = open_idx
        while idx < len(source_line):
            char: str = source_line[idx]
            if quote is not None:
                if char == "\\":
                    idx += 1  # skip whatever character is escaped
                elif char == quote:
                    quote = None
            elif char in "'\"":
                quote = char
            elif char in "([{":
                depth += 1
            elif char in ")]}":
                depth -= 1
                if depth == 0:
                    return source_line[open_idx + 1 : idx].strip()
            idx += 1
        return source_line[open_idx + 1 :].strip()

    # get the context: the innermost caller whose source line mentions `dbg`
    line_exp: str = "unknown"
    current_file: str = "unknown"
    call_lineno: typing.Optional[int] = None
    fname: str = "unknown"
    stack: list[inspect.FrameInfo] = inspect.stack()
    try:
        # stack[0] is this function's own frame, so start from its caller
        callers: list[inspect.FrameInfo] = stack[1:]
        for frame in callers:
            if not frame.code_context:
                continue
            line: str = frame.code_context[0]
            if "dbg" in line:
                current_file = _process_path(Path(frame.filename))
                call_lineno = frame.lineno
                line_exp = extract_call_arguments(line)
                break

        if current_file.startswith("/tmp/ipykernel_"):
            # Jupyter cells live in temporary files, so the file name is not
            # informative; show the chain of functions from the same cell
            # instead. Walking outward from the call we meet, in order: the
            # functions we care about, `<module>`, then Jupyter internals.
            filtered_functions: list[str] = []
            for frame_info in callers:
                if _process_path(Path(frame_info.filename)) != current_file:
                    continue
                if frame_info.function == "<module>":
                    break
                if frame_info.function.startswith("dbg"):
                    continue
                filtered_functions.append(frame_info.function)
            if call_lineno is not None:
                filtered_functions.append(f"<ipykernel>:{call_lineno}")
            else:
                filtered_functions.append(current_file)
            filtered_functions.reverse()
            fname = " -> ".join(filtered_functions)
        elif call_lineno is not None:
            fname = f"{current_file}:{call_lineno}"
    finally:
        # stack[0] is this very frame, so the list forms a reference cycle;
        # drop it so the frames are released deterministically
        del stack

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(msg, file=sys.stderr)

    # return the expression itself
    return exp
Call dbg with any variable or expression.
Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to:
from muutils.dbg import dbg
a = 2
b = 5
dbg(a+b)
def square(x: int) -> int:
    return x * x
dbg(square(a))