Coverage for src/ttl_dict/__init__.py: 70%

132 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-22 18:56 -0700

1from __future__ import annotations 

2 

3from collections import UserDict 

4from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView 

5from datetime import UTC, datetime, timedelta 

6from typing import Any, Self, cast, overload, override 

7 

8 

9class TTLDict[_KT, _VT](UserDict[_KT, _VT]): 

10 def __init__( 

11 self, 

12 ttl: timedelta, 

13 other: Mapping[_KT, _VT] | Iterable[tuple[_KT, _VT]] | None = None, 

14 /, 

15 **kwargs: _VT, 

16 ) -> None: 

17 self.__ttl: timedelta = ttl 

18 

19 self.__expiries: dict[_KT, datetime] = {} 

20 

21 # Must be at the end of __init__ as it calls self.update which needs self.__ttl 

22 super().__init__(other, **kwargs) 

23 

24 def cleanup(self) -> None: 

25 now: datetime = datetime.now(UTC) 

26 

27 expired_keys: list[_KT] = [] 

28 for key, expiry in self.__expiries.items(): 

29 # As dict is iterated by insert order, the newer ones are iterated later 

30 if now < expiry: 

31 break 

32 

33 expired_keys.append(key) 

34 

35 for key in expired_keys: 

36 del self.__expiries[key] 

37 del self.data[key] 

38 

39 def cleanup_by_key(self, key: _KT) -> bool: 

40 now: datetime = datetime.now(UTC) 

41 

42 if key not in self.__expiries: 

43 return False 

44 

45 if self.__expiries[key] <= now: 

46 del self.__expiries[key] 

47 del self.data[key] 

48 

49 return False 

50 

51 return True 

52 

53 @override 

54 def __len__(self) -> int: 

55 self.cleanup() 

56 return super().__len__() 

57 

58 @override 

59 def __contains__(self, key: _KT) -> bool: 

60 return self.cleanup_by_key(key) 

61 

62 @override 

63 def get(self, key: _KT, default: Any = None) -> Any: 

64 self.cleanup_by_key(key) 

65 return super().get(key, default) 

66 

67 @override 

68 def __getitem__(self, key: _KT) -> _VT: 

69 self.cleanup_by_key(key) 

70 return super().__getitem__(key) 

71 

72 def get_expiry(self, key: _KT) -> datetime | None: 

73 self.cleanup_by_key(key) 

74 return self.__expiries.get(key) 

75 

76 @override 

77 def __iter__(self) -> Iterator[_KT]: 

78 self.cleanup() 

79 return super().__iter__() 

80 

81 @override 

82 def clear(self) -> None: 

83 self.__expiries.clear() 

84 super().clear() 

85 

86 @override 

87 def pop(self, key: _KT, default: Any = None) -> Any: 

88 if not self.__expiries.pop(key): 

89 return default 

90 

91 return super().pop(key, default) 

92 

93 @override 

94 def popitem(self) -> tuple[_KT, _VT]: 

95 self.cleanup() 

96 

97 key, value = super().popitem() 

98 self.__expiries.pop(key) 

99 

100 return (key, value) 

101 

102 @override 

103 def __delitem__(self, key: _KT) -> None: 

104 del self.__expiries[key] 

105 super().__delitem__(key) 

106 

107 @override 

108 def __setitem__(self, key: _KT, value: _VT) -> None: 

109 self.__expiries[key] = datetime.now(UTC) + self.__ttl 

110 super().__setitem__(key, value) 

111 

112 @override 

113 def setdefault(self, key: _KT, default: Any = None) -> Any: 

114 self.__expiries[key] = datetime.now(UTC) + self.__ttl 

115 return super().setdefault(key, default) 

116 

117 def renew_expiry(self, key: _KT) -> None: 

118 del self.__expiries[key] 

119 self.__expiries[key] = datetime.now(UTC) + self.__ttl 

120 

121 @overload 

122 def update( 

123 self, 

124 other: Mapping[_KT, _VT], 

125 /, 

126 **kwargs: _VT, 

127 ) -> None: ... 

128 @overload 

129 def update( 

130 self, 

131 other: Iterable[tuple[_KT, _VT]], 

132 /, 

133 **kwargs: _VT, 

134 ) -> None: ... 

135 @overload 

136 def update( 

137 self, 

138 other: None = None, 

139 /, 

140 **kwargs: _VT, 

141 ) -> None: ... 

142 @override 

143 def update( 

144 self, 

145 other=None, 

146 /, 

147 **kwargs, 

148 ) -> None: 

149 now: datetime = datetime.now(UTC) 

150 expiry: datetime = now + self.__ttl 

151 other_ttl_dict: bool = isinstance(other, TTLDict) 

152 

153 key: _KT 

154 value: _VT 

155 if isinstance(other, Mapping): 

156 for key, value in other.items(): 

157 self.__expiries[key] = ( 

158 # In rare case, item may have been expired during iteration 

159 # Thus, we set expiry to now (which means it is already expired) 

160 (other.get_expiry(key) or now) if other_ttl_dict 

161 else expiry 

162 ) 

163 self.data[key] = value 

164 elif isinstance(other, Iterable): 

165 for key, value in other: 

166 self.__expiries[key] = expiry 

167 self.data[key] = value 

168 

169 for str_key, value in kwargs.items(): 

170 key = cast("_KT", str_key) 

171 self.__expiries[key] = expiry 

172 self.data[key] = value 

173 

174 @override 

175 def copy(self) -> TTLDict: 

176 return TTLDict(self.__ttl, self) 

177 

178 @override 

179 def __or__(self, other: Mapping[_KT, _VT]) -> TTLDict: 

180 d: TTLDict = self.copy() 

181 d.update(other) 

182 

183 return d 

184 

185 @override 

186 def __ior__(self, other: Mapping[_KT, _VT]) -> Self: 

187 self.update(other) 

188 return self 

189 

190 @override 

191 def __repr__(self) -> str: 

192 self.cleanup() 

193 return super().__repr__() 

194 

195 @override 

196 def keys(self) -> KeysView: 

197 self.cleanup() 

198 return super().keys() 

199 

200 @override 

201 def values(self) -> ValuesView: 

202 self.cleanup() 

203 return super().values() 

204 

205 @override 

206 def items(self) -> ItemsView: 

207 self.cleanup() 

208 return super().items()