Coverage for auttcomp/shape_eval.py: 84%
195 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-24 12:00 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-24 12:00 -0600
1from typing import Union, Self, Any
2import sys
3import pprint
4import io
6class ShapeNode:
7 def __init__(self, container_type: Union[list | dict | str | tuple | None]=None, value:str=None, parent=None):
8 self.container_type : Union[list|dict|str|tuple|None] = container_type
9 self.value : str = value
10 self.parent:ShapeNode = parent
11 self.children:list[ShapeNode] = []
12 self.tuple_index = None
13 self.is_null_val = False
15 def get_nullable_container_name(self):
16 if self.is_null_val: return f"{self.container_type}?"
17 return self.container_type
19 def add_child(self, node) -> Self:
20 node.parent = self
21 self.children.append(node)
22 return node
24 def has_child_with_container(self, raw_type, out_param:list):
25 if self.children is None: return False
26 for c in self.children:
27 if c.container_type == raw_type:
28 out_param.append(c)
29 return True
31 def has_child_with_value(self, value, tuple_index=None):
32 for c in self.children:
33 if c.value == value and c.tuple_index == tuple_index:
34 return True
35 return False
class NodeWriter:
    """Build a merged ShapeNode graph while object_crawler walks an object.

    ``h`` is the root ("head") node and ``current_node`` is the cursor that new
    children are attached to. Containers are entered with the ``push_*`` methods and
    left with :meth:`pop`; leaves are recorded with :meth:`write_name`. Equal children
    are merged, so a shape that repeats (e.g. across list elements) appears once.
    """

    def __init__(self):
        self.h: Union[ShapeNode | None] = None
        self.current_node: Union[ShapeNode | None] = None

    def pop(self):
        """Move the cursor up to the parent of the current node."""
        self.current_node = self.current_node.parent

    def _find_container_child(self, raw_type, tuple_index):
        """Return the current node's child with this container_type and tuple_index, else None."""
        for child in self.current_node.children:
            if child.container_type == raw_type and child.tuple_index == tuple_index:
                return child
        return None

    def push_container(self, raw_type, tuple_index=None, is_null_val=False):
        """Enter a container or dictionary-key node, reusing an equal existing child.

        Args:
            raw_type: ``[]``/``{}``/``(1,)`` marker for a container, or the key for a dict key.
            tuple_index: position inside an enclosing tuple, or None.
            is_null_val: True when the dict value for this key was None; once set on a
                reused node it stays set.
        """
        new_node = ShapeNode(container_type=raw_type)
        new_node.tuple_index = tuple_index
        new_node.is_null_val = is_null_val

        if self.h is None:
            self.current_node = new_node
            self.h = self.current_node
            return

        # Match on tuple_index as well, the way write_name does for leaves. Without
        # it, containers of the same kind at different tuple positions, e.g.
        # ([1], ["a"]), were all merged into the first position.
        existing = self._find_container_child(raw_type, tuple_index)
        if existing is not None:
            self.current_node = existing
            if new_node.is_null_val:
                self.current_node.is_null_val = True
            return

        self.current_node = self.current_node.add_child(new_node)

    def push_list(self, tuple_index=None): self.push_container([], tuple_index)
    def push_dict(self, tuple_index=None): self.push_container({}, tuple_index)
    def push_tuple(self, tuple_index=None): self.push_container((1,), tuple_index)
    def push_dict_key(self, key, is_null_val=False): self.push_container(key, tuple_index=None, is_null_val=is_null_val)

    def write_name(self, value, tuple_index=None):
        """Record a leaf for *value* as its type name (``"None"`` for None).

        A leaf with the same name and tuple_index under the current node is not duplicated.
        """
        name = type(value).__name__ if value is not None else "None"
        node = ShapeNode(value=name)
        node.tuple_index = tuple_index
        if self.h is None:
            self.h = node
        else:
            if not self.current_node.has_child_with_value(name, tuple_index):
                self.current_node.add_child(node)
def get_path_to_node_recurse(node):
    """Yield ``container_type`` of *node*, then of each ancestor, ending at the root.

    The original called itself recursively without consuming the nested generator,
    so only *node*'s own container_type was ever produced. Walking the parent chain
    iteratively fixes that and avoids recursion-depth limits on deep graphs.
    """
    current = node
    while current is not None:
        yield current.container_type
        current = current.parent
def _path_segment(node):
    """Render one shape node as a single path segment (see get_path_to_node)."""
    if node.value is not None:
        return str(node.value)
    raw = node.container_type
    parent = node.parent
    if parent is not None and isinstance(parent.container_type, dict):
        # children of a dict node are its keys: show the key itself
        return str(raw)
    if isinstance(raw, list):
        return "[]"
    if isinstance(raw, dict):
        return "{}"
    if isinstance(raw, tuple):
        return "()"
    return str(raw)


def get_path_to_node(node):
    """Return a readable root-to-*node* path, e.g. ``"{}->users->[]->{}->name->str"``.

    Lists, dicts and tuples render as ``[]``, ``{}`` and ``()``; dictionary keys render
    as ``str(key)``; leaves render as their type name. The original applied
    ``reversed()`` to a generator (TypeError) and joined non-str container markers.
    """
    segments = []
    current = node
    while current is not None:
        segments.append(_path_segment(current))
        current = current.parent
    return "->".join(reversed(segments))
def node_graph_to_obj_dict_key_eval(parent_node: ShapeNode, set_any_type=False) -> Any:
    """Evaluate the shape stored under one dictionary key.

    *parent_node* is a key node (a child of a dict node). Its children are the
    distinct shapes observed for that key's value.

    Returns:
        - the evaluated child, when there is exactly one shape;
        - for a nullable key (rendered ``key?``) with exactly one non-``None`` shape,
          that shape evaluated;
        - for primitives only, a ``|``-joined union of type names such as
          ``"int|str"``; ``"None"`` is omitted for nullable keys, and the union
          collapses to ``"Any"`` when *set_any_type* is true;
        - for mixed primitives and containers, or several container kinds, the same
          kind of union with containers named ``list``/``dict``/``tuple``, after
          writing a WARNING/ERROR line to stderr.

    Raises:
        Exception: "unexpected path" if the key node has no children.
    """
    is_nullable_container = parent_node.is_null_val
    nodes = parent_node.children
    if len(nodes) == 1:
        return node_graph_to_obj(nodes[0], set_any_type)

    if is_nullable_container:
        # Drop only the "None" leaf, because nullability is already shown by the
        # "key?" name. (The previous filter dropped every leaf.)
        non_none_nodes = [n for n in nodes if n.value != "None"]
        if len(non_none_nodes) == 1:
            return node_graph_to_obj(non_none_nodes[0], set_any_type)

    primitive_names = [n.value for n in nodes if n.value is not None]
    if is_nullable_container:
        # when the container is "nullable?", we won't bother specifying None in the property
        primitive_names = [name for name in primitive_names if name != "None"]
    container_names = [
        type(n.container_type).__name__
        for n in nodes
        if n.value is None and n.container_type is not None
    ]

    # Test emptiness, not any(): the [] and {} container markers are falsy.
    has_primitives = len(primitive_names) > 0
    has_containers = len(container_names) > 0
    if set_any_type and has_primitives:
        # mirror node_graph_to_obj, which reports every leaf as 'Any'
        primitive_names = ["Any"]

    if has_primitives and not has_containers:
        return "|".join(primitive_names)
    if not has_containers:
        raise Exception("unexpected path")

    # Every joined item is a str now; before, None and raw container objects made join raise.
    union = "|".join(primitive_names + container_names)
    key = parent_node.container_type
    path = get_path_to_node(parent_node.parent) if parent_node.parent is not None else ""
    if has_primitives:
        # values under this key are sometimes primitives and sometimes containers
        sys.stderr.write(f"WARNING: {path} dictionary key {key} contains both primitives and containers: {union}\n")
    else:
        sys.stderr.write(f"ERROR: {path} dictionary key {key} contains both array and dictionary accessors: {union}\n")
    return union
# NOTE: mutually recursive with node_graph_to_obj_dict_key_eval
def node_graph_to_obj(node: ShapeNode, set_any_type=False) -> Any:
    """Convert a shape-node subtree back into a plain Python shape.

    Leaves become their type name, or ``'Any'`` when *set_any_type* is true. Dict
    nodes become ``{key name: shape}``, with nullable keys suffixed ``?``. List nodes
    become a list of child shapes. Tuple nodes become a tuple ordered by
    ``tuple_index``.

    Raises:
        Exception: "unexpected path" for a node that is none of the above.
    """
    if node.value is not None:
        if not set_any_type:
            return node.value
        return 'Any'

    marker = node.container_type
    if isinstance(marker, dict):
        shape = {}
        for key_node in node.children:
            key_name = key_node.get_nullable_container_name()
            shape[key_name] = node_graph_to_obj_dict_key_eval(key_node, set_any_type)
        return shape
    if isinstance(marker, list):
        return list(map(lambda child: node_graph_to_obj(child, set_any_type), node.children))
    if isinstance(marker, tuple):
        ordered_children = sorted(node.children, key=lambda child: child.tuple_index)
        return tuple(node_graph_to_obj(child, set_any_type) for child in ordered_children)

    raise Exception("unexpected path")
def dict_kv(obj):
    """Yield ``(key, value)`` pairs from a dict, or from ``vars(obj)`` for other objects."""
    mapping = obj if isinstance(obj, dict) else vars(obj)
    for key in mapping:
        yield key, mapping[key]
def normalize_type(obj):
    """Return ``obj.__dict__`` when *obj* has one; otherwise return *obj* unchanged."""
    return obj.__dict__ if hasattr(obj, "__dict__") else obj
def object_crawler(obj, node_writer, tuple_index=None):
    """Walk *obj* depth-first, recording its structure on *node_writer*.

    Objects exposing ``__dict__`` are treated as dicts of their attributes. Lists,
    dicts and tuples are pushed as containers and popped afterwards; dict entries are
    pushed as key nodes (flagged nullable when the value is None). Tuple elements
    carry their position as ``tuple_index``. Any other value is recorded as a leaf
    via ``write_name``.
    """
    obj = normalize_type(obj)

    if isinstance(obj, list):
        node_writer.push_list(tuple_index)
        for element in obj:
            object_crawler(element, node_writer)
        node_writer.pop()
        return

    if isinstance(obj, dict):
        node_writer.push_dict(tuple_index)
        for key, value in dict_kv(obj):
            node_writer.push_dict_key(key, is_null_val=value is None)
            object_crawler(value, node_writer)
            node_writer.pop()
        node_writer.pop()
        return

    if isinstance(obj, tuple):
        node_writer.push_tuple(tuple_index)
        for position in range(len(obj)):
            element = obj[position]
            object_crawler(element, node_writer, tuple_index=position)
        node_writer.pop()
        return

    node_writer.write_name(obj, tuple_index)
class BaseShape:
    """Wrapper around an evaluated shape object, stored on ``self.obj``.

    Equality compares the wrapped object with the other operand. ``repr`` is the
    pretty-printed object (``indent=2``) preceded by a newline.
    """

    def __init__(self, obj):
        self.obj = obj

    def __eq__(self, other):
        return self.obj == other

    def __repr__(self):
        # pprint.pprint writes pformat(...) plus a trailing newline; the newline is
        # exactly what the old StringIO truncation removed, so pformat is equivalent.
        return "\n" + pprint.pformat(self.obj, indent=2)

    @staticmethod
    def factory(obj):
        """Wrap *obj* in the shape class for its type; return a NoneShape for anything else."""
        dispatch = (
            (dict, DictShape),
            (list, ListShape),
            (tuple, TupleShape),
            (str, StrShape),
        )
        for kind, shape_cls in dispatch:
            if isinstance(obj, kind):
                return shape_cls(obj)
        return NoneShape()
class NoneShape(BaseShape):
    """Placeholder shape wrapping ``None``; BaseShape.factory falls back to it for unknown types."""

    def __init__(self):
        BaseShape.__init__(self, None)
class DictShape(dict, BaseShape):
    """Dict shape that also supports attribute-style access to its keys.

    ``shape.key`` returns the value under ``key`` wrapped by BaseShape.factory. Other
    non-dunder names return ``NoneShape()``.
    """

    def __init__(self, obj):
        super().__init__(obj)
        self.obj = obj

    def __repr__(self): return BaseShape.__repr__(self)

    def __getattr__(self, item):
        # Read obj through the instance dict. On a half-built instance (copy/unpickle
        # before obj is set), `self.obj` would re-enter __getattr__ and recurse forever.
        data = vars(self).get("obj")
        if data is not None and item in data:
            return BaseShape.factory(data[item])
        # Protocol probes (__deepcopy__, __setstate__, ...) must raise AttributeError.
        # Returning a NoneShape made copy.deepcopy try to call it.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        return NoneShape()
class ListShape(list, BaseShape):
    """List shape whose indexing re-wraps elements (and slices) with BaseShape.factory."""

    def __init__(self, obj):
        super().__init__(obj)
        self.obj = obj

    def __repr__(self): return BaseShape.__repr__(self)

    def __getattr__(self, item):
        # Only reached for names that neither list nor BaseShape defines. The former
        # `hasattr(self.obj, item)` branch could never be true, since list attributes
        # already resolve through the `list` base. Had it run, it would have indexed
        # the list with a str. Dunder probes (e.g. copy.deepcopy's __deepcopy__)
        # must get AttributeError rather than a non-callable NoneShape.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        return NoneShape()

    def __getitem__(self, item):
        return BaseShape.factory(self.obj[item])
class StrShape(str, BaseShape):
    """String shape (e.g. a type name) that also keeps the wrapped string on ``obj``."""

    def __init__(self, obj):
        # str defines no __init__ of its own, so this is the initializer super() reached
        BaseShape.__init__(self, obj)

    __repr__ = BaseShape.__repr__
class TupleShape(BaseShape):
    """Shape wrapper for a tuple.

    Unlike DictShape and ListShape it does not subclass the builtin. It only provides
    BaseShape equality and the pretty-printed repr of ``obj``.
    """

    def __init__(self, obj):
        BaseShape.__init__(self, obj)

    __repr__ = BaseShape.__repr__
def eval_shape(obj: Any, set_any_type=False) -> Any:
    """Infer the structural shape of *obj* and wrap it in a BaseShape subclass.

    *obj* is crawled into a merged node graph, converted back to plain
    dict/list/tuple/str form by node_graph_to_obj (with *set_any_type* passed
    through), and wrapped by BaseShape.factory.
    """
    writer = NodeWriter()
    object_crawler(obj, writer)
    return BaseShape.factory(node_graph_to_obj(writer.h, set_any_type))