Coverage for auttcomp/composable.py: 100%
98 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-24 12:00 -0600
1from typing import Callable, Concatenate, Optional, ParamSpec, TypeVar, Generic
2import inspect
4_INV_R_TYPE_PACK = {type((1,)), type(None)}
6#composable
7P = ParamSpec('P')
8R = TypeVar('R')
10#partial app
11P2 = ParamSpec('P2')
12R2 = TypeVar('R2')
13A = TypeVar('A')
15#invocation
16IT = TypeVar('IT')
17IR = TypeVar('IR')
19class Composable(Generic[P, R]):
21 def __init__(self, func:Callable[P, R]):
22 self.f:Callable[P, R] = func
23 self.g = None
24 self.__chained = False
26 #composition operator
27 def __or__(self, other):
28 if not isinstance(other, Composable):
29 other = Composable(other)
31 new_comp = Composable(self)
32 self.__chained = True
33 new_comp.__chained = False
34 other_comp = Composable(other.f)
35 other_comp.__chained = True
36 new_comp.g = other_comp
38 return new_comp
40 def __get_bound_args(sig, args, kwargs):
41 bound = sig.bind_partial(*args, **kwargs)
42 bound.apply_defaults()
43 return bound.args
45 @staticmethod
46 def __get_sig_recurse(func):
47 if isinstance(func, Composable):
48 return Composable.__get_sig_recurse(func.f)
49 else:
50 if inspect.isclass(func):
51 return inspect.signature(func.__call__)
52 return inspect.signature(func)
54 __sig = None
55 def __get_singleton_sig_f(self):
56 return self.__sig if self.__sig is not None else Composable.__get_sig_recurse(self.f)
58 def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
60 if len(kwargs.keys()) > 0:
61 sig = self.__get_singleton_sig_f()
62 args = Composable.__get_bound_args(sig, args, kwargs)
64 result = Composable.__internal_call(self.f, self.g, args)
65 is_single_tuple = type(result) == tuple and len(result) == 1
66 is_terminating = not self.__chained and Composable.__is_terminating(self.f, self.g)
67 should_unpack_result = is_terminating and is_single_tuple
69 if should_unpack_result:
70 result = result[0]
72 return result
74 @staticmethod
75 def __is_terminating(f, g):
76 g_chain_state = Composable.__is_chained(g)
78 if g_chain_state:
79 return True
81 return Composable.__is_chained(f) is None and g_chain_state is None #is unchained
83 @staticmethod
84 def __internal_call(f, g, args):
85 invoke_f = Composable.__invoke_compose if isinstance(f, Composable) else Composable.__invoke_native
86 result = invoke_f(f, args)
88 if g is not None:
89 invoke_g = Composable.__invoke_compose if isinstance(g, Composable) else Composable.__invoke_native
90 result = invoke_g(g, result)
92 return result
94 @staticmethod
95 def __invoke_compose(func, args):
96 return func(*args) if args is not None else func()
98 @staticmethod
99 def __invoke_native(func, args):
100 result = func(*args)
102 if type(result) not in _INV_R_TYPE_PACK:
103 result = (result,)
105 return result
107 @staticmethod
108 def __is_chained(target) -> Optional[bool]:
109 if target is None:
110 return None
112 if not isinstance(target, Composable):
113 return None
115 return target.__chained
117 #partial application operator
118 def __and__(self:Callable[Concatenate[A, P2], R2], param:A) -> Callable[P2, R2]:
119 arg_count = len(self.__get_singleton_sig_f().parameters)
120 return Composable._PartialApp._bind(self, param, arg_count)
122 class _PartialApp:
124 @staticmethod
125 def _bind(func, param, arg_count):
126 match arg_count:
127 case 1: return Composable(lambda: func(param))()
128 case 2: return Composable(lambda x: func(param, x))
129 case 3: return Composable(lambda x1, x2: func(param, x1, x2))
130 case 4: return Composable(lambda x1, x2, x3: func(param, x1, x2, x3))
131 case 5: return Composable(lambda x1, x2, x3, x4: func(param, x1, x2, x3, x4))
132 case 6: return Composable(lambda x1, x2, x3, x4, x5: func(param, x1, x2, x3, x4, x5))
133 case 7: return Composable(lambda x1, x2, x3, x4, x5, x6: func(param, x1, x2, x3, x4, x5, x6))
134 case 8: return Composable(lambda x1, x2, x3, x4, x5, x6, x7: func(param, x1, x2, x3, x4, x5, x6, x7))
135 case _: raise TypeError(f"unsupported argument count {arg_count}")
137 #invocation operator
138 def __lt__(next_func:Callable[[IT], IR], id_func:Callable[[], IT]):
139 return next_func(id_func())