muutils.misc.func
from __future__ import annotations

import functools
import sys
from types import CodeType
import warnings
from typing import Any, Callable, Tuple, cast, TypeVar

try:
    if sys.version_info >= (3, 11):
        # 3.11+
        from typing import Unpack, TypeVarTuple, ParamSpec
    else:
        # 3.9+
        from typing_extensions import Unpack, TypeVarTuple, ParamSpec  # type: ignore[assignment]
except ImportError:
    warnings.warn(
        "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work"
    )
    # NOTE: the module still imports with these fallbacks because every
    # `cast(...)` below passes its type as a string, which is never evaluated
    ParamSpec = TypeVar  # type: ignore
    Unpack = Any  # type: ignore
    TypeVarTuple = TypeVar  # type: ignore


from muutils.errormode import ErrorMode

warnings.warn("muutils.misc.func is experimental, use with caution")

ReturnType = TypeVar("ReturnType")
T_kwarg = TypeVar("T_kwarg")
T_process_in = TypeVar("T_process_in")
T_process_out = TypeVar("T_process_out")

FuncParams = ParamSpec("FuncParams")
FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap")


def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Decorator that applies a processor to a keyword argument.

    The underlying function is expected to have a keyword argument
    (with name `kwarg_name`) of type `T_process_out`, but the caller provides
    a value of type `T_process_in` that is converted via `processor`.

    Only a value passed *by keyword* is processed; a value passed
    positionally (or not passed at all) reaches the function unchanged.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to process.
     - `processor : Callable[[T_process_in], T_process_out]`
        A callable that converts the input value (`T_process_in`) into the
        type expected by the function (`T_process_out`).

    # Returns:
     - A decorator that converts a function of type
       `Callable[FuncParamsPreWrap, ReturnType]` (expecting `kwarg_name` of type `T_process_out`)
       into one of type `Callable[FuncParams, ReturnType]` (accepting `kwarg_name` of type `T_process_in`).
    """

    def decorator(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            if kwarg_name in kwargs:
                # Convert the caller's value (of type T_process_in) to T_process_out
                kwargs[kwarg_name] = processor(kwargs[kwarg_name])
            return func(*args, **kwargs)  # type: ignore[arg-type]

        # string type: never evaluated at runtime, so this keeps working when
        # `ParamSpec` fell back to `TypeVar` (some Python versions, e.g. 3.9,
        # reject `Callable[<TypeVar>, R]` at runtime)
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator


@process_kwarg("action", ErrorMode.from_any)
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate.
     - `validator : Callable[[Any], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template if validation fails; it is formatted with the
        `kwarg_name` and `value` fields.
     - `action : ErrorMode | str`
        What to do when validation fails. Strings are converted with
        `ErrorMode.from_any`. `ErrorMode.WARN` (e.g. `"warn"`) emits a
        `UserWarning`; any other mode, including the default
        `ErrorMode.EXCEPT`, raises a `ValueError`.

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails in warn mode, emits a warning.
       Otherwise, raises a ValueError.
       Only a value passed *by keyword* is validated.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1
    ```

    # Raises:
     - `ValueError` if validation fails and `action` is not warn mode.
    """
    # `process_kwarg` converts `action` only when it is passed by keyword;
    # normalize a positional value the same way so both call styles agree
    action_mode: ErrorMode = (
        action if isinstance(action, ErrorMode) else ErrorMode.from_any(action)
    )

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    # compare against the enum member: an `ErrorMode` is not
                    # equal to the plain string "warn"
                    if action_mode == ErrorMode.WARN:
                        warnings.warn(msg, UserWarning)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        # string type: never evaluated at runtime (see `process_kwarg`)
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator


def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a specific keyword argument's value when `check` accepts it.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace.
     - `check : Callable[[T_kwarg], bool]`
        A callable that returns True if the keyword argument should be replaced.
     - `replacement_value : T_kwarg`
        The value to replace with.
     - `replace_if_missing : bool`
        If True, also inserts `replacement_value` when the caller did not pass
        `kwarg_name` as a keyword.

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Sets `kwargs[kwarg_name]` to `replacement_value` if `check` returns True
       for its current value, or (when `replace_if_missing` is True) if it was
       not passed as a keyword. Positional arguments are never inspected.

    # Usage:

    ```python
    @replace_kwarg("x", is_none, "default_string")
    def my_func(*, x: str | None = None) -> str:
        return x

    assert my_func(x=None) == "default_string"
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value
            elif replace_if_missing:
                kwargs[kwarg_name] = replacement_value
            return func(*args, **kwargs)

        # string type: never evaluated at runtime (see `process_kwarg`)
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator


def is_none(value: Any) -> bool:
    """Return True iff `value` is `None`."""
    return value is None


def always_true(value: Any) -> bool:
    """Ignore `value` and return True."""
    return True


def always_false(value: Any) -> bool:
    """Ignore `value` and return False."""
    return False


def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that formats a function's docstring with the provided keyword arguments."""

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        if func.__doc__ is not None:
            func.__doc__ = func.__doc__.format(**fmt_kwargs)
        return func

    return decorator


# TODO: no way to make the type system understand this afaik
LambdaArgs = TypeVarTuple("LambdaArgs")
LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...])


def typed_lambda(
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda function to wrap. Only its positional parameters
        (`fn.__code__.co_argcount`) are annotated.
     - `in_types : tuple[type, ...]`
        Tuple of input types, one per positional parameter, in order.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[[Unpack[LambdaArgs]], ReturnType]`
        A wrapper that forwards positional and keyword arguments to `fn`, whose
        `__annotations__` map each parameter name to its type and `"return"`
        to `out_type`.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's parameters.
    """
    code: CodeType = fn.__code__
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    # the first `co_argcount` entries of `co_varnames` are the positional parameter names
    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    annotations: dict[str, type] = {  # type: ignore[var-annotated]
        name: typ
        for name, typ in zip(param_names, in_types)  # type: ignore[arg-type]
    }
    annotations["return"] = out_type

    @functools.wraps(fn)
    def wrapped(*args: Unpack[LambdaArgs], **kwargs: Any) -> ReturnType:
        # forward keywords too, so calling by parameter name works like calling `fn`
        return fn(*args, **kwargs)  # type: ignore[call-arg]

    wrapped.__annotations__ = annotations
    return wrapped
FuncParams =
~FuncParams
FuncParamsPreWrap =
~FuncParamsPreWrap
def
process_kwarg( kwarg_name: str, processor: Callable[[~T_process_in], ~T_process_out]) -> Callable[[Callable[~FuncParamsPreWrap, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Decorator that applies a processor to a keyword argument.

    The underlying function is expected to have a keyword argument
    (with name `kwarg_name`) of type `T_process_out`, but the caller provides
    a value of type `T_process_in` that is converted via `processor`.

    Only a value passed *by keyword* is processed; a value passed
    positionally (or not passed at all) reaches the function unchanged.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to process.
     - `processor : Callable[[T_process_in], T_process_out]`
        A callable that converts the input value (`T_process_in`) into the
        type expected by the function (`T_process_out`).

    # Returns:
     - A decorator that converts a function of type
       `Callable[FuncParamsPreWrap, ReturnType]` (expecting `kwarg_name` of type `T_process_out`)
       into one of type `Callable[FuncParams, ReturnType]` (accepting `kwarg_name` of type `T_process_in`).
    """

    def decorator(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            if kwarg_name in kwargs:
                # Convert the caller's value (of type T_process_in) to T_process_out
                kwargs[kwarg_name] = processor(kwargs[kwarg_name])
            return func(*args, **kwargs)  # type: ignore[arg-type]

        # string type: never evaluated at runtime, so this keeps working when
        # `ParamSpec` fell back to `TypeVar` (some Python versions, e.g. 3.9,
        # reject `Callable[<TypeVar>, R]` at runtime)
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
Decorator that applies a processor to a keyword argument.
The underlying function is expected to have a keyword argument (with name `kwarg_name`) of type `T_process_out`, but the caller provides a value of type `T_process_in` that is converted via `processor`.
Parameters:
kwarg_name : str
The name of the keyword argument to process.
processor : Callable[[T_process_in], T_process_out]
A callable that converts the input value (of type `T_process_in`) into the type expected by the function (`T_process_out`).
Returns:
- A decorator that converts a function of type `Callable[FuncParamsPreWrap, ReturnType]` (expecting `kwarg_name` of type `T_process_out`) into one of type `Callable[FuncParams, ReturnType]` (accepting `kwarg_name` of type `T_process_in`).
@process_kwarg('action', ErrorMode.from_any)
def
validate_kwarg( kwarg_name: str, validator: Callable[[~T_kwarg], bool], description: str | None = None, action: muutils.errormode.ErrorMode = ErrorMode.Except) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
@process_kwarg("action", ErrorMode.from_any)
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate.
     - `validator : Callable[[Any], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template if validation fails; it is formatted with the
        `kwarg_name` and `value` fields.
     - `action : ErrorMode | str`
        What to do when validation fails. Strings are converted with
        `ErrorMode.from_any`. `ErrorMode.WARN` (e.g. `"warn"`) emits a
        `UserWarning`; any other mode, including the default
        `ErrorMode.EXCEPT`, raises a `ValueError`.

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails in warn mode, emits a warning.
       Otherwise, raises a ValueError.
       Only a value passed *by keyword* is validated.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1
    ```

    # Raises:
     - `ValueError` if validation fails and `action` is not warn mode.
    """
    # `process_kwarg` converts `action` only when it is passed by keyword;
    # normalize a positional value the same way so both call styles agree
    action_mode: ErrorMode = (
        action if isinstance(action, ErrorMode) else ErrorMode.from_any(action)
    )

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    # compare against the enum member: an `ErrorMode` is not
                    # equal to the plain string "warn"
                    if action_mode == ErrorMode.WARN:
                        warnings.warn(msg, UserWarning)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        # string type: never evaluated at runtime, so this keeps working when
        # `ParamSpec` fell back to `TypeVar`
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
Decorator that validates a specific keyword argument.
Parameters:
kwarg_name : str
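The name of the keyword argument to validate.
validator : Callable[[Any], bool]
A callable that returns True if the keyword argument is valid.
description : str | None
A message template if validation fails; it is formatted with the `kwarg_name` and `value` fields.
action : ErrorMode | str
How to handle a failed validation: `"warn"` emits a warning; the default, `ErrorMode.EXCEPT`, raises. A string passed by keyword is converted with `ErrorMode.from_any`.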
Returns:
Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]
A decorator that validates the keyword argument.
Modifies:
- If validation fails and `action` is `"warn"`, emits a warning. Otherwise, raises a `ValueError`.
Usage:
@validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
def my_func(x: int) -> int:
return x
assert my_func(x=1) == 1
Raises:
`ValueError` if validation fails and `action` is not `"warn"` (e.g. the default, `ErrorMode.EXCEPT`).
def
replace_kwarg( kwarg_name: str, check: Callable[[~T_kwarg], bool], replacement_value: ~T_kwarg, replace_if_missing: bool = False) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a specific keyword argument's value when `check` accepts it.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace.
     - `check : Callable[[T_kwarg], bool]`
        A callable that returns True if the keyword argument should be replaced.
     - `replacement_value : T_kwarg`
        The value to replace with.
     - `replace_if_missing : bool`
        If True, also inserts `replacement_value` when the caller did not pass
        `kwarg_name` as a keyword.

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Sets `kwargs[kwarg_name]` to `replacement_value` if `check` returns True
       for its current value, or (when `replace_if_missing` is True) if it was
       not passed as a keyword. Positional arguments are never inspected.

    # Usage:

    ```python
    @replace_kwarg("x", is_none, "default_string")
    def my_func(*, x: str | None = None) -> str:
        return x

    assert my_func(x=None) == "default_string"
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value
            elif replace_if_missing:
                kwargs[kwarg_name] = replacement_value
            return func(*args, **kwargs)

        # string type: never evaluated at runtime, so this keeps working when
        # `ParamSpec` fell back to `TypeVar`
        return cast("Callable[FuncParams, ReturnType]", wrapper)

    return decorator
Decorator that replaces a specific keyword argument's value when `check` returns True for it.
Parameters:
kwarg_name : str
The name of the keyword argument to replace.
check : Callable[[T_kwarg], bool]
A callable that returns True if the keyword argument should be replaced.
replacement_value : T_kwarg
The value to replace with.
replace_if_missing : bool
If True, also inserts `replacement_value` when the caller did not pass the keyword argument.
Returns:
Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]
A decorator that replaces the keyword argument value.
Modifies:
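- Sets `kwargs[kwarg_name]` to `replacement_value` if `check` returns True for its current value (or, when `replace_if_missing` is True, if it was not passed as a keyword).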
Usage:
@replace_kwarg("x", is_none, "default_string")
def my_func(*, x: str | None = None) -> str:
return x
assert my_func(x=None) == "default_string"
def
is_none(value: Any) -> bool:
def
always_true(value: Any) -> bool:
def
always_false(value: Any) -> bool:
def
format_docstring( **fmt_kwargs: Any) -> Callable[[Callable[~FuncParams, ~ReturnType]], Callable[~FuncParams, ~ReturnType]]:
def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that formats a function's docstring with the provided keyword arguments.

    The decorated function itself is returned (it is not wrapped); only its
    `__doc__` is rewritten with `str.format(**fmt_kwargs)`. A function without
    a docstring is returned untouched.
    """

    def _apply_format(
        target: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        template: str | None = target.__doc__
        if template is None:
            return target
        target.__doc__ = template.format(**fmt_kwargs)
        return target

    return _apply_format
Decorator that formats a function's docstring with the provided keyword arguments.
LambdaArgs =
LambdaArgs
def
typed_lambda( fn: Callable[[Unpack[LambdaArgs]], ~ReturnType], in_types: ~LambdaArgsTypes, out_type: type[~ReturnType]) -> Callable[[Unpack[LambdaArgs]], ~ReturnType]:
def typed_lambda(
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda function to wrap. Only its positional parameters
        (`fn.__code__.co_argcount`) are annotated.
     - `in_types : tuple[type, ...]`
        Tuple of input types, one per positional parameter, in order.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[[Unpack[LambdaArgs]], ReturnType]`
        A wrapper that forwards positional and keyword arguments to `fn`, whose
        `__annotations__` map each parameter name to its type and `"return"`
        to `out_type`.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's parameters.
    """
    code: CodeType = fn.__code__
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    # the first `co_argcount` entries of `co_varnames` are the positional parameter names
    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    annotations: dict[str, type] = {  # type: ignore[var-annotated]
        name: typ
        for name, typ in zip(param_names, in_types)  # type: ignore[arg-type]
    }
    annotations["return"] = out_type

    @functools.wraps(fn)
    def wrapped(*args: Unpack[LambdaArgs], **kwargs: Any) -> ReturnType:
        # forward keywords too, so calling by parameter name works like calling `fn`
        return fn(*args, **kwargs)  # type: ignore[call-arg]

    wrapped.__annotations__ = annotations
    return wrapped
Wraps a lambda function with type hints.
Parameters:
fn : Callable[[Unpack[LambdaArgs]], ReturnType]
The lambda function to wrap.
in_types : tuple[type, ...]
Tuple of input types.
out_type : type[ReturnType]
The output type.
Returns:
Callable[..., ReturnType]
A new function with annotations matching the given signature.
Usage:
add = typed_lambda(lambda x, y: x + y, (int, int), int)
assert add(1, 2) == 3
Raises:
ValueError
if the number of input types doesn't match the lambda's parameters.