docs for muutils v0.8.6
View Source on GitHub

muutils.dbg

This code is based on a Python implementation of Rust's built-in `dbg!` macro, originally from https://github.com/tylerwince/pydbg/blob/master/pydbg.py, although it has been significantly modified.

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


  1"""
  2
  3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
  4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
  5although it has been significantly modified
  6
  7licensed under MIT:
  8
  9Copyright (c) 2019 Tyler Wince
 10
 11Permission is hereby granted, free of charge, to any person obtaining a copy
 12of this software and associated documentation files (the "Software"), to deal
 13in the Software without restriction, including without limitation the rights
 14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 15copies of the Software, and to permit persons to whom the Software is
 16furnished to do so, subject to the following conditions:
 17
 18The above copyright notice and this permission notice shall be included in
 19all copies or substantial portions of the Software.
 20
 21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 27THE SOFTWARE.
 28
 29"""
 30
 31from __future__ import annotations
 32
 33import inspect
 34import sys
 35import typing
 36from pathlib import Path
 37import functools
 38
 39# type defs
 40_ExpType = typing.TypeVar("_ExpType")
 41
 42
 43# Sentinel type for no expression passed
 44class _NoExpPassedSentinel:
 45    """Unique sentinel type used to indicate that no expression was passed."""
 46
 47    pass
 48
 49
 50_NoExpPassed = _NoExpPassedSentinel()
 51
 52# global variables
 53_CWD: Path = Path.cwd().absolute()
 54_COUNTER: int = 0
 55
 56# configuration
 57PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
 58DEFAULT_VAL_JOINER: str = " = "
 59
 60
 61# path processing
 62def _process_path(path: Path) -> str:
 63    path_abs: Path = path.absolute()
 64    fname: Path
 65    if PATH_MODE == "absolute":
 66        fname = path_abs
 67    elif PATH_MODE == "relative":
 68        try:
 69            # if it's inside the cwd, print the relative path
 70            fname = path.relative_to(_CWD)
 71        except ValueError:
 72            # if its not in the subpath, use the absolute path
 73            fname = path_abs
 74    else:
 75        raise ValueError("PATH_MODE must be either 'relative' or 'absolute")
 76
 77    return fname.as_posix()
 78
 79
 80# actual dbg function
 81@typing.overload
 82def dbg() -> _NoExpPassedSentinel: ...
 83@typing.overload
 84def dbg(
 85    exp: _NoExpPassedSentinel,
 86    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 87    val_joiner: str = DEFAULT_VAL_JOINER,
 88) -> _NoExpPassedSentinel: ...
 89@typing.overload
 90def dbg(
 91    exp: _ExpType,
 92    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 93    val_joiner: str = DEFAULT_VAL_JOINER,
 94) -> _ExpType: ...
 95def dbg(
 96    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
 97    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 98    val_joiner: str = DEFAULT_VAL_JOINER,
 99) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
100    """Call dbg with any variable or expression.
101
102    Calling dbg will print to stderr the current filename and lineno,
103    as well as the passed expression and what the expression evaluates to:
104
105            from muutils.dbg import dbg
106
107            a = 2
108            b = 5
109
110            dbg(a+b)
111
112            def square(x: int) -> int:
113                    return x * x
114
115            dbg(square(a))
116
117    """
118    global _COUNTER
119
120    # get the context
121    fname: str = "unknown"
122    line_exp: str = "unknown"
123    for frame in inspect.stack():
124        if frame.code_context is None:
125            continue
126        line: str = frame.code_context[0]
127        if "dbg" in line:
128            start: int = line.find("(") + 1
129            end: int = line.rfind(")")
130            if end == -1:
131                end = len(line)
132
133            fname = f"{_process_path(Path(frame.filename))}:{frame.lineno}"
134            # special case for jupyter notebooks
135            if fname.startswith("/tmp/ipykernel_"):
136                fname = f"<ipykernel>:{frame.lineno}"
137
138            line_exp = line[start:end]
139
140            break
141
142    # assemble the message
143    msg: str
144    if exp is _NoExpPassed:
145        # if no expression is passed, just show location and counter value
146        msg = f"[ {fname} ] <dbg {_COUNTER}>"
147        _COUNTER += 1
148    else:
149        # if expression passed, format its value and show location, expr, and value
150        exp_val: str = formatter(exp) if formatter else repr(exp)
151        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"
152
153    # print the message
154    print(
155        msg,
156        file=sys.stderr,
157    )
158
159    # return the expression itself
160    return exp
161
162
# formatted `dbg_*` functions with their helpers

# keyword arguments that `tensor_info` forwards to `muutils.tensor_info.array_summary`
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": False,
    "colored": True,
    "eq_char": "=",
}


# separator placed between the expression text and its summary by `dbg_tensor`
DBG_TENSOR_VAL_JOINER: str = ": "
184
185
def tensor_info(tensor: typing.Any) -> str:
    """Summarize `tensor` as a string, for use as the formatter of `dbg_tensor`.

    Delegates to `muutils.tensor_info.array_summary`, passing the keyword arguments
    stored in `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`. The import happens at call time,
    presumably to keep importing this module cheap — TODO confirm.
    """
    from muutils.tensor_info import array_summary as summarize

    summary_kwargs = DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS
    return summarize(tensor, **summary_kwargs)


# `dbg` preconfigured for tensors/arrays: the value is rendered by `tensor_info`
# and joined to the expression text with `DBG_TENSOR_VAL_JOINER`
dbg_tensor = functools.partial(
    dbg,
    formatter=tensor_info,
    val_joiner=DBG_TENSOR_VAL_JOINER,
)

PATH_MODE: Literal['relative', 'absolute'] = 'relative'
DEFAULT_VAL_JOINER: str = ' = '
def dbg( exp: Union[~_ExpType, muutils.dbg._NoExpPassedSentinel] = <muutils.dbg._NoExpPassedSentinel object>, formatter: Optional[Callable[[Any], str]] = None, val_joiner: str = ' = ') -> Union[~_ExpType, muutils.dbg._NoExpPassedSentinel]:
 96def dbg(
 97    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
 98    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 99    val_joiner: str = DEFAULT_VAL_JOINER,
100) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
101    """Call dbg with any variable or expression.
102
103    Calling dbg will print to stderr the current filename and lineno,
104    as well as the passed expression and what the expression evaluates to:
105
106            from muutils.dbg import dbg
107
108            a = 2
109            b = 5
110
111            dbg(a+b)
112
113            def square(x: int) -> int:
114                    return x * x
115
116            dbg(square(a))
117
118    """
119    global _COUNTER
120
121    # get the context
122    fname: str = "unknown"
123    line_exp: str = "unknown"
124    for frame in inspect.stack():
125        if frame.code_context is None:
126            continue
127        line: str = frame.code_context[0]
128        if "dbg" in line:
129            start: int = line.find("(") + 1
130            end: int = line.rfind(")")
131            if end == -1:
132                end = len(line)
133
134            fname = f"{_process_path(Path(frame.filename))}:{frame.lineno}"
135            # special case for jupyter notebooks
136            if fname.startswith("/tmp/ipykernel_"):
137                fname = f"<ipykernel>:{frame.lineno}"
138
139            line_exp = line[start:end]
140
141            break
142
143    # assemble the message
144    msg: str
145    if exp is _NoExpPassed:
146        # if no expression is passed, just show location and counter value
147        msg = f"[ {fname} ] <dbg {_COUNTER}>"
148        _COUNTER += 1
149    else:
150        # if expression passed, format its value and show location, expr, and value
151        exp_val: str = formatter(exp) if formatter else repr(exp)
152        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"
153
154    # print the message
155    print(
156        msg,
157        file=sys.stderr,
158    )
159
160    # return the expression itself
161    return exp

Call dbg with any variable or expression.

Calling `dbg` prints the current filename and line number to stderr, along with the passed expression and the value it evaluates to:

    from muutils.dbg import dbg

    a = 2
    b = 5

    dbg(a+b)

    def square(x: int) -> int:
            return x * x

    dbg(square(a))
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: Dict[str, Union[str, int, bool]] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': True, 'sparkline_bins': 7, 'sparkline_logy': False, 'colored': True, 'eq_char': '='}
DBG_TENSOR_VAL_JOINER: str = ': '
def tensor_info(tensor: Any) -> str:
def tensor_info(tensor: typing.Any) -> str:
    """Summarize `tensor` as a string, for use as the formatter of `dbg_tensor`.

    Delegates to `muutils.tensor_info.array_summary`, passing the keyword arguments
    stored in `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`. The import happens at call time,
    presumably to keep importing this module cheap — TODO confirm.
    """
    from muutils.tensor_info import array_summary as summarize

    summary_kwargs = DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS
    return summarize(tensor, **summary_kwargs)
dbg_tensor = functools.partial(<function dbg>, formatter=<function tensor_info>, val_joiner=': ')