muutils.dbg
an implementation of the Rust builtin dbg!
for Python, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py
licensed under MIT:
Copyright (c) 2019 Tyler Wince
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

an implementation of the Rust builtin `dbg!` for Python, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

"""

from __future__ import annotations

import os
import inspect
import sys
import typing
from pathlib import Path
import functools

# type defs
_ExpType = typing.TypeVar("_ExpType")


# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed."""

    pass


_NoExpPassed = _NoExpPassedSentinel()

# global variables
_CWD: Path = Path.cwd().absolute()
_COUNTER: int = 0

# configuration
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"


# path processing
def _process_path(path: Path) -> str:
    """Render `path` as a posix-style string according to `PATH_MODE`.

    In `"absolute"` mode the absolute path is returned. In `"relative"` mode
    the path is made relative to the common ancestor of the path and the
    working directory captured at import time (`_CWD`); if that fails with a
    `ValueError`, the absolute path is used instead.

    Raises:
        ValueError: if `PATH_MODE` is neither `"relative"` nor `"absolute"`.
    """
    path_abs: Path = path.absolute()
    fname: str
    if PATH_MODE == "absolute":
        fname = path_abs.as_posix()
    elif PATH_MODE == "relative":
        try:
            fname = path_abs.relative_to(
                Path(os.path.commonpath([path_abs, _CWD]))
            ).as_posix()
        except ValueError:
            # best effort: fall back to the absolute path
            fname = path_abs.as_posix()
    else:
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname


def _extract_call_argument(line: str, func_name: str = "dbg") -> str:
    """Return the text between the parentheses of the first `func_name...(` call in `line`.

    The opening parenthesis is the first one at or after the first occurrence
    of `func_name`, and the closing one is found by bracket-depth counting, so
    other calls on the same line do not leak into the result. Brackets inside
    string literals are not treated specially. If the call is not closed on
    this line, the remainder of the line is returned.
    """
    name_pos: int = max(line.find(func_name), 0)
    open_pos: int = line.find("(", name_pos)
    if open_pos == -1:
        # no call syntax on this line at all; show the line itself
        return line.strip()

    depth: int = 0
    for idx in range(open_pos, len(line)):
        char: str = line[idx]
        if char in "([{":
            depth += 1
        elif char in ")]}":
            depth -= 1
            if depth == 0:
                return line[open_pos + 1 : idx].strip()

    # unbalanced: the call continues on a following line
    return line[open_pos + 1 :].strip()


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType, formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

            from muutils.dbg import dbg

            a = 2
            b = 5

            dbg(a+b)

            def square(x: int) -> int:
                    return x * x

            dbg(square(a))

    Args:
        exp: the value to show. When omitted, only the location and a running
            counter are printed.
        formatter: optional callable turning `exp` into the displayed string;
            `repr` is used when it is `None`.

    Returns:
        `exp` itself, unchanged, so the call can be inlined in an expression.
    """
    global _COUNTER

    # get the context: the first stack frame whose source line mentions "dbg"
    fname: str = "unknown"
    line_exp: str = "unknown"
    for frame_info in inspect.stack():
        if frame_info.code_context is None:
            continue
        line: str = frame_info.code_context[0]
        if "dbg" in line:
            fname = f"{_process_path(Path(frame_info.filename))}:{frame_info.lineno}"
            line_exp = _extract_call_argument(line)
            break

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp} = {exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp


# formatted `dbg_*` functions with their helpers
def tensor_info_dict(tensor: typing.Any) -> typing.Dict[str, str]:
    """Collect `repr`-formatted metadata from a tensor-like object.

    Only attributes the object actually has are reported, in this order:
    `shape`, `sum` (only when the sum is NaN or infinite), `dtype`, `device`,
    `requires_grad`.
    """
    output: typing.Dict[str, str] = dict()
    # shape
    if hasattr(tensor, "shape"):
        output["shape"] = repr(tuple(tensor.shape))

    # print the sum if its a nan or inf
    if hasattr(tensor, "sum"):
        total: typing.Any = tensor.sum()
        # NaN is the only value that compares unequal to itself
        is_nan = total != total
        if is_nan:
            output["sum"] = repr(total)
        else:
            try:
                is_inf = bool(abs(total) == float("inf"))
            except (TypeError, ValueError):
                # sum is not something we can compare against infinity
                is_inf = False
            if is_inf:
                output["sum"] = repr(total)

    # more info
    if hasattr(tensor, "dtype"):
        output["dtype"] = repr(tensor.dtype)
    if hasattr(tensor, "device"):
        output["device"] = repr(tensor.device)
    if hasattr(tensor, "requires_grad"):
        output["requires_grad"] = repr(tensor.requires_grad)

    # return
    return output


def tensor_info(tensor: typing.Any) -> str:
    """Format the result of `tensor_info_dict` as a `key=value, ...` string."""
    info: typing.Dict[str, str] = tensor_info_dict(tensor)
    return ", ".join(f"{k}={v}" for k, v in info.items())


# `dbg` preconfigured to show tensor metadata instead of the full `repr`
dbg_tensor = functools.partial(dbg, formatter=tensor_info)
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

            from muutils.dbg import dbg

            a = 2
            b = 5

            dbg(a+b)

            def square(x: int) -> int:
                    return x * x

            dbg(square(a))

    Args:
        exp: the value to show. When omitted, only the location and a running
            counter are printed.
        formatter: optional callable turning `exp` into the displayed string;
            `repr` is used when it is `None`.

    Returns:
        `exp` itself, unchanged, so the call can be inlined in an expression.
    """
    global _COUNTER

    # get the context: the first stack frame whose source line mentions "dbg"
    fname: str = "unknown"
    line_exp: str = "unknown"
    for frame_info in inspect.stack():
        if frame_info.code_context is None:
            continue
        line: str = frame_info.code_context[0]
        if "dbg" in line:
            fname = f"{_process_path(Path(frame_info.filename))}:{frame_info.lineno}"
            line_exp = _extract_call_argument(line)
            break

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp} = {exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp


def _extract_call_argument(line: str, func_name: str = "dbg") -> str:
    """Return the text between the parentheses of the first `func_name...(` call in `line`.

    The opening parenthesis is the first one at or after the first occurrence
    of `func_name`, and the closing one is found by bracket-depth counting, so
    other calls on the same line do not leak into the result. Brackets inside
    string literals are not treated specially. If the call is not closed on
    this line, the remainder of the line is returned.
    """
    name_pos: int = max(line.find(func_name), 0)
    open_pos: int = line.find("(", name_pos)
    if open_pos == -1:
        # no call syntax on this line at all; show the line itself
        return line.strip()

    depth: int = 0
    for idx in range(open_pos, len(line)):
        char: str = line[idx]
        if char in "([{":
            depth += 1
        elif char in ")]}":
            depth -= 1
            if depth == 0:
                return line[open_pos + 1 : idx].strip()

    # unbalanced: the call continues on a following line
    return line[open_pos + 1 :].strip()
Call dbg with any variable or expression.
Calling dbg will print to stderr the current filename and lineno, as well as the passed expression and what the expression evaluates to:
from muutils.dbg import dbg
a = 2
b = 5
dbg(a+b)
def square(x: int) -> int:
return x * x
dbg(square(a))
def tensor_info_dict(tensor: typing.Any) -> typing.Dict[str, str]:
    """Collect `repr`-formatted metadata from a tensor-like object.

    Only attributes the object actually has are reported, in this order:
    `shape`, `sum` (only when the sum is NaN or infinite), `dtype`, `device`,
    `requires_grad`.

    Args:
        tensor: any object; each attribute is probed with `hasattr`.

    Returns:
        A dict mapping attribute name to the `repr` of its value.
    """
    output: typing.Dict[str, str] = dict()
    # shape
    if hasattr(tensor, "shape"):
        output["shape"] = repr(tuple(tensor.shape))

    # print the sum if its a nan or inf
    if hasattr(tensor, "sum"):
        total: typing.Any = tensor.sum()
        # NaN is the only value that compares unequal to itself
        is_nan = total != total
        if is_nan:
            output["sum"] = repr(total)
        else:
            try:
                is_inf = bool(abs(total) == float("inf"))
            except (TypeError, ValueError):
                # sum is not something we can compare against infinity
                is_inf = False
            if is_inf:
                output["sum"] = repr(total)

    # more info
    if hasattr(tensor, "dtype"):
        output["dtype"] = repr(tensor.dtype)
    if hasattr(tensor, "device"):
        output["device"] = repr(tensor.device)
    if hasattr(tensor, "requires_grad"):
        output["requires_grad"] = repr(tensor.requires_grad)

    # return
    return output