Coverage for src/ttl_dict/__init__.py: 70%
132 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-22 18:56 -0700
1from __future__ import annotations
3from collections import UserDict
4from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView
5from datetime import UTC, datetime, timedelta
6from typing import Any, Self, cast, overload, override
9class TTLDict[_KT, _VT](UserDict[_KT, _VT]):
10 def __init__(
11 self,
12 ttl: timedelta,
13 other: Mapping[_KT, _VT] | Iterable[tuple[_KT, _VT]] | None = None,
14 /,
15 **kwargs: _VT,
16 ) -> None:
17 self.__ttl: timedelta = ttl
19 self.__expiries: dict[_KT, datetime] = {}
21 # Must be at the end of __init__ as it calls self.update which needs self.__ttl
22 super().__init__(other, **kwargs)
24 def cleanup(self) -> None:
25 now: datetime = datetime.now(UTC)
27 expired_keys: list[_KT] = []
28 for key, expiry in self.__expiries.items():
29 # As dict is iterated by insert order, the newer ones are iterated later
30 if now < expiry:
31 break
33 expired_keys.append(key)
35 for key in expired_keys:
36 del self.__expiries[key]
37 del self.data[key]
39 def cleanup_by_key(self, key: _KT) -> bool:
40 now: datetime = datetime.now(UTC)
42 if key not in self.__expiries:
43 return False
45 if self.__expiries[key] <= now:
46 del self.__expiries[key]
47 del self.data[key]
49 return False
51 return True
53 @override
54 def __len__(self) -> int:
55 self.cleanup()
56 return super().__len__()
58 @override
59 def __contains__(self, key: _KT) -> bool:
60 return self.cleanup_by_key(key)
62 @override
63 def get(self, key: _KT, default: Any = None) -> Any:
64 self.cleanup_by_key(key)
65 return super().get(key, default)
67 @override
68 def __getitem__(self, key: _KT) -> _VT:
69 self.cleanup_by_key(key)
70 return super().__getitem__(key)
72 def get_expiry(self, key: _KT) -> datetime | None:
73 self.cleanup_by_key(key)
74 return self.__expiries.get(key)
76 @override
77 def __iter__(self) -> Iterator[_KT]:
78 self.cleanup()
79 return super().__iter__()
81 @override
82 def clear(self) -> None:
83 self.__expiries.clear()
84 super().clear()
86 @override
87 def pop(self, key: _KT, default: Any = None) -> Any:
88 if not self.__expiries.pop(key):
89 return default
91 return super().pop(key, default)
93 @override
94 def popitem(self) -> tuple[_KT, _VT]:
95 self.cleanup()
97 key, value = super().popitem()
98 self.__expiries.pop(key)
100 return (key, value)
102 @override
103 def __delitem__(self, key: _KT) -> None:
104 del self.__expiries[key]
105 super().__delitem__(key)
107 @override
108 def __setitem__(self, key: _KT, value: _VT) -> None:
109 self.__expiries[key] = datetime.now(UTC) + self.__ttl
110 super().__setitem__(key, value)
112 @override
113 def setdefault(self, key: _KT, default: Any = None) -> Any:
114 self.__expiries[key] = datetime.now(UTC) + self.__ttl
115 return super().setdefault(key, default)
117 def renew_expiry(self, key: _KT) -> None:
118 del self.__expiries[key]
119 self.__expiries[key] = datetime.now(UTC) + self.__ttl
121 @overload
122 def update(
123 self,
124 other: Mapping[_KT, _VT],
125 /,
126 **kwargs: _VT,
127 ) -> None: ...
128 @overload
129 def update(
130 self,
131 other: Iterable[tuple[_KT, _VT]],
132 /,
133 **kwargs: _VT,
134 ) -> None: ...
135 @overload
136 def update(
137 self,
138 other: None = None,
139 /,
140 **kwargs: _VT,
141 ) -> None: ...
142 @override
143 def update(
144 self,
145 other=None,
146 /,
147 **kwargs,
148 ) -> None:
149 now: datetime = datetime.now(UTC)
150 expiry: datetime = now + self.__ttl
151 other_ttl_dict: bool = isinstance(other, TTLDict)
153 key: _KT
154 value: _VT
155 if isinstance(other, Mapping):
156 for key, value in other.items():
157 self.__expiries[key] = (
158 # In rare case, item may have been expired during iteration
159 # Thus, we set expiry to now (which means it is already expired)
160 (other.get_expiry(key) or now) if other_ttl_dict
161 else expiry
162 )
163 self.data[key] = value
164 elif isinstance(other, Iterable):
165 for key, value in other:
166 self.__expiries[key] = expiry
167 self.data[key] = value
169 for str_key, value in kwargs.items():
170 key = cast("_KT", str_key)
171 self.__expiries[key] = expiry
172 self.data[key] = value
174 @override
175 def copy(self) -> TTLDict:
176 return TTLDict(self.__ttl, self)
178 @override
179 def __or__(self, other: Mapping[_KT, _VT]) -> TTLDict:
180 d: TTLDict = self.copy()
181 d.update(other)
183 return d
185 @override
186 def __ior__(self, other: Mapping[_KT, _VT]) -> Self:
187 self.update(other)
188 return self
190 @override
191 def __repr__(self) -> str:
192 self.cleanup()
193 return super().__repr__()
195 @override
196 def keys(self) -> KeysView:
197 self.cleanup()
198 return super().keys()
200 @override
201 def values(self) -> ValuesView:
202 self.cleanup()
203 return super().values()
205 @override
206 def items(self) -> ItemsView:
207 self.cleanup()
208 return super().items()