Coverage for src/extratools_core/trie.py: 62%
93 statements
« prev ^ index » next coverage.py v7.8.1, created at 2025-05-27 20:51 -0700
« prev ^ index » next coverage.py v7.8.1, created at 2025-05-27 20:51 -0700
1from __future__ import annotations
3from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping
4from typing import Any
7class TrieDict[VT: Any](MutableMapping[str, VT]):
8 def __init__(
9 self,
10 initial_data: Mapping[str, VT] | Iterable[tuple[str, VT]] | None = None,
11 ) -> None:
12 self.root: dict[str, Any] = {}
14 self.__len: int = 0
16 if initial_data:
17 for key, value in (
18 initial_data.items() if isinstance(initial_data, Mapping)
19 else initial_data
20 ):
21 self.__setitem__(key, value)
23 def __len__(self) -> int:
24 return self.__len
26 def __find(self, s: str, func: Callable[[dict[str, Any], str], Any]) -> Any:
27 node: dict[str, Any] = self.root
29 while True:
30 c: str = s[0] if s else ""
31 rest: str = s[1:] if s else ""
33 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c)
34 if next_node is None:
35 raise KeyError
37 if isinstance(next_node, dict):
38 node = next_node
39 s = rest
40 continue
42 if rest == next_node[0]:
43 return func(node, c)
45 raise KeyError
47 def __delitem__(self, s: str) -> None:
48 def delitem(node: dict[str, Any], c: str) -> None:
49 del node[c]
50 self.__len -= 1
52 return self.__find(s, delitem)
54 def __getitem__(self, s: str) -> VT:
55 def getitem(node: dict[str, Any], c: str) -> VT:
56 return node[c][1]
58 return self.__find(s, getitem)
60 def __setitem__(self, s: str, v: VT) -> None:
61 self.__set(s, v, self.root, is_new=True)
63 def __set(self, s: str, v: VT, node: dict[str, Any], *, is_new: bool) -> None:
64 if not s:
65 is_new = is_new and "" not in node
66 node[""] = ("", v)
67 if is_new:
68 self.__len += 1
70 return
72 c: str = s[0]
73 rest: str = s[1:]
75 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c)
76 if next_node is None:
77 node[c] = (rest, v)
78 if is_new:
79 self.__len += 1
80 elif isinstance(next_node, dict):
81 self.__set(rest, v, next_node, is_new=is_new)
82 else:
83 other_rest: str
84 other_value: VT
85 other_rest, other_value = next_node
87 if rest == other_rest:
88 node[c] = (rest, v)
89 return
91 next_node = node[c] = {}
93 self.__set(other_rest, other_value, next_node, is_new=False)
94 self.__set(rest, v, next_node, is_new=is_new)
96 def __iter__(self) -> Iterator[str]:
97 for _, value in self.__prefixes("", self.root):
98 yield value
100 def prefixes(self) -> Iterator[tuple[str, str]]:
101 yield from self.__prefixes("", self.root)
103 def __prefixes(self, prefix: str, node: dict[str, Any]) -> Iterator[tuple[str, str]]:
104 for key, next_node in node.items():
105 new_prefix = prefix + key
106 if isinstance(next_node, dict):
107 yield from self.__prefixes(new_prefix, next_node)
108 else:
109 yield (new_prefix, new_prefix + next_node[0])
111 def match(self, prefix: str) -> Iterator[str]:
112 node: dict[str, Any] = self.root
113 s: str = prefix
115 matched: str = ""
117 while s:
118 c: str = s[0]
119 rest: str = s[1:]
120 matched += c
122 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c)
123 if next_node is None:
124 return
126 if isinstance(next_node, dict):
127 node = next_node
128 s = rest
129 continue
131 other_rest: str = next_node[0]
132 if other_rest.startswith(rest):
133 yield matched + other_rest
135 return
137 for _, value in self.__prefixes(prefix, node):
138 yield value