muutils.dbg
This code is based on an implementation of Rust's built-in `dbg!` macro
for Python, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py
It has since been significantly modified.
licensed under MIT:
Copyright (c) 2019 Tyler Wince
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py
although it has been significantly modified

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

"""

from __future__ import annotations

import inspect
import sys
import typing
from pathlib import Path
import functools

# type defs
_ExpType = typing.TypeVar("_ExpType")


# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed."""

    pass


_NoExpPassed = _NoExpPassedSentinel()

# global variables
# working directory captured once, at import time; used by `_process_path`
_CWD: Path = Path.cwd().absolute()
# incremented by every `dbg()` call made without an expression
_COUNTER: int = 0

# configuration
# read on every call to `_process_path`, so it may be changed at runtime
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
# bound as `dbg`'s default when `dbg` is defined; reassigning it afterwards
# does not change that default
DEFAULT_VAL_JOINER: str = " = "


# path processing
def _process_path(path: Path) -> str:
    """Render `path` for display according to the module-level `PATH_MODE`.

    # Parameters:
     - `path`: a file path; a relative path is resolved against the current
       working directory

    # Returns:
     - a posix-style path string. With `PATH_MODE == "absolute"` it is the
       absolute path. With `PATH_MODE == "relative"` it is relative to
       `_CWD` when the file lies inside it, and absolute otherwise.

    # Raises:
     - `ValueError`: if `PATH_MODE` is neither `"relative"` nor `"absolute"`
    """
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # compare the *absolute* path with the absolute `_CWD`: a relative
            # `path` is never a subpath of it and would always fall through
            fname = path_abs.relative_to(_CWD)
        except ValueError:
            # not inside the cwd, so use the absolute path
            fname = path_abs
    else:
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
     - `exp`: value to print and return. When omitted, only the call site
       and a running counter (`<dbg N>`) are printed.
     - `formatter`: callable rendering `exp` as a string; `repr` is used
       when it is `None`.
     - `val_joiner`: separator printed between the expression's source text
       and its rendered value. The default is read from `DEFAULT_VAL_JOINER`
       once, when this function is defined.

    # Returns:
     - `exp` unchanged (the no-expression sentinel if nothing was passed),
       so `dbg(...)` can wrap any expression in place.

    # Raises:
     - `ValueError`: from `_process_path`, if `PATH_MODE` is invalid
    """
    global _COUNTER

    # locate the call site: walk outward from our caller one frame at a time
    # (rather than building the whole `inspect.stack()`) and stop at the
    # first frame whose current source line contains "dbg"
    call_site: typing.Optional[inspect.Traceback] = None
    call_line: str = ""
    frame = inspect.currentframe()
    frame = frame.f_back if frame is not None else None
    try:
        while frame is not None:
            info = inspect.getframeinfo(frame, context=1)
            if info.code_context and "dbg" in info.code_context[0]:
                call_site, call_line = info, info.code_context[0]
                break
            frame = frame.f_back
    finally:
        # frames reference their locals; drop ours so no reference cycle
        # keeps the stack alive (as recommended by the `inspect` docs)
        del frame

    fname: str = "unknown"
    line_exp: str = "unknown"
    if call_site is not None:
        fname = f"{_process_path(Path(call_site.filename))}:{call_site.lineno}"
        # special case for jupyter notebooks; tested on the raw filename so
        # the result does not depend on `PATH_MODE` or the working directory
        if call_site.filename.startswith("/tmp/ipykernel_"):
            fname = f"<ipykernel>:{call_site.lineno}"

        # The expression is the text between the "(" that follows the first
        # "dbg" and its matching ")". This way `print(dbg(x))` and
        # `f(dbg(x), y)` both show `x`. Brackets inside quotes are ignored,
        # and an unquoted "#" ends the scan. With no match on this line
        # (e.g. a call split over several lines), the rest of the line, up
        # to any comment, is used.
        open_idx: int = call_line.find("(", call_line.find("dbg"))
        if open_idx == -1:
            line_exp = call_line.strip()
        else:
            depth: int = 0
            quote: typing.Optional[str] = None
            close_idx: int = -1
            idx: int = open_idx
            while idx < len(call_line):
                char: str = call_line[idx]
                if quote is not None:
                    if char == "\\":
                        idx += 1  # skip the escaped character
                    elif char == quote:
                        quote = None
                elif char in "'\"":
                    quote = char
                elif char == "#":
                    break
                elif char in "([{":
                    depth += 1
                elif char in ")]}":
                    depth -= 1
                    if depth == 0:
                        close_idx = idx
                        break
                idx += 1
            if close_idx == -1:
                line_exp = call_line[open_idx + 1 : idx].rstrip()
            else:
                line_exp = call_line[open_idx + 1 : close_idx]

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(msg, file=sys.stderr)

    # return the expression itself
    return exp


# formatted `dbg_*` functions with their helpers

# keyword arguments forwarded to `muutils.tensor_info.array_summary` by
# `tensor_info` below
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = (
    dict(
        fmt="unicode",
        precision=2,
        stats=True,
        shape=True,
        dtype=True,
        device=True,
        requires_grad=True,
        sparkline=True,
        sparkline_bins=7,
        sparkline_logy=False,
        colored=True,
        eq_char="=",
    )
)


# separator used by `dbg_tensor` between the expression text and its summary
DBG_TENSOR_VAL_JOINER: str = ": "


def tensor_info(tensor: typing.Any) -> str:
    """Summarize `tensor` via `muutils.tensor_info.array_summary`.

    The summary uses `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS` as keyword
    arguments. The import is done inside the function, presumably so that
    importing this module does not require `muutils.tensor_info`.
    """
    from muutils.tensor_info import array_summary

    return array_summary(tensor, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)


# `dbg` preconfigured to format values with `tensor_info` and join them with
# `DBG_TENSOR_VAL_JOINER`
dbg_tensor = functools.partial(
    dbg, formatter=tensor_info, val_joiner=DBG_TENSOR_VAL_JOINER
)
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
     - `exp`: value to print and return. When omitted, only the call site
       and a running counter (`<dbg N>`) are printed.
     - `formatter`: callable rendering `exp` as a string; `repr` is used
       when it is `None`.
     - `val_joiner`: separator printed between the expression's source text
       and its rendered value. The default is read from `DEFAULT_VAL_JOINER`
       once, when this function is defined.

    # Returns:
     - `exp` unchanged (the no-expression sentinel if nothing was passed),
       so `dbg(...)` can wrap any expression in place.

    # Raises:
     - `ValueError`: from `_process_path`, if `PATH_MODE` is invalid
    """
    global _COUNTER

    # locate the call site: walk outward from our caller one frame at a time
    # (rather than building the whole `inspect.stack()`) and stop at the
    # first frame whose current source line contains "dbg"
    call_site: typing.Optional[inspect.Traceback] = None
    call_line: str = ""
    frame = inspect.currentframe()
    frame = frame.f_back if frame is not None else None
    try:
        while frame is not None:
            info = inspect.getframeinfo(frame, context=1)
            if info.code_context and "dbg" in info.code_context[0]:
                call_site, call_line = info, info.code_context[0]
                break
            frame = frame.f_back
    finally:
        # frames reference their locals; drop ours so no reference cycle
        # keeps the stack alive (as recommended by the `inspect` docs)
        del frame

    fname: str = "unknown"
    line_exp: str = "unknown"
    if call_site is not None:
        fname = f"{_process_path(Path(call_site.filename))}:{call_site.lineno}"
        # special case for jupyter notebooks; tested on the raw filename so
        # the result does not depend on `PATH_MODE` or the working directory
        if call_site.filename.startswith("/tmp/ipykernel_"):
            fname = f"<ipykernel>:{call_site.lineno}"

        # The expression is the text between the "(" that follows the first
        # "dbg" and its matching ")". This way `print(dbg(x))` and
        # `f(dbg(x), y)` both show `x`. Brackets inside quotes are ignored,
        # and an unquoted "#" ends the scan. With no match on this line
        # (e.g. a call split over several lines), the rest of the line, up
        # to any comment, is used.
        open_idx: int = call_line.find("(", call_line.find("dbg"))
        if open_idx == -1:
            line_exp = call_line.strip()
        else:
            depth: int = 0
            quote: typing.Optional[str] = None
            close_idx: int = -1
            idx: int = open_idx
            while idx < len(call_line):
                char: str = call_line[idx]
                if quote is not None:
                    if char == "\\":
                        idx += 1  # skip the escaped character
                    elif char == quote:
                        quote = None
                elif char in "'\"":
                    quote = char
                elif char == "#":
                    break
                elif char in "([{":
                    depth += 1
                elif char in ")]}":
                    depth -= 1
                    if depth == 0:
                        close_idx = idx
                        break
                idx += 1
            if close_idx == -1:
                line_exp = call_line[open_idx + 1 : idx].rstrip()
            else:
                line_exp = call_line[open_idx + 1 : close_idx]

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(msg, file=sys.stderr)

    # return the expression itself
    return exp
Call dbg with any variable or expression.
Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to:
from muutils.dbg import dbg
a = 2
b = 5
dbg(a+b)
def square(x: int) -> int:
    return x * x
dbg(square(a))