Coverage for src/ttl_dict/__init__.py: 75%

113 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-05-23 07:58 -0700

1from __future__ import annotations 

2 

3from collections import UserDict 

4from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView 

5from datetime import UTC, datetime, timedelta 

6from typing import Self, cast, overload, override 

7 

8 

9class TTLDict[_KT, _VT](UserDict[_KT, _VT]): 

10 def __init__( 

11 self, 

12 ttl: timedelta, 

13 other: Mapping[_KT, _VT] | Iterable[tuple[_KT, _VT]] | None = None, 

14 /, 

15 **kwargs: _VT, 

16 ) -> None: 

17 self.__ttl: timedelta = ttl 

18 

19 self.__expiries: dict[_KT, datetime] = {} 

20 

21 # Must be at the end of __init__ as it calls self.update which needs self.__ttl 

22 super().__init__(other, **kwargs) 

23 

24 def cleanup(self) -> None: 

25 now: datetime = datetime.now(UTC) 

26 

27 expired_keys: list[_KT] = [] 

28 for key, expiry in self.__expiries.items(): 

29 # As dict is iterated by insert order, the newer ones are iterated later 

30 if now < expiry: 

31 break 

32 

33 expired_keys.append(key) 

34 

35 for key in expired_keys: 

36 del self.__expiries[key] 

37 del self.data[key] 

38 

39 def cleanup_by_key(self, key: _KT) -> bool: 

40 now: datetime = datetime.now(UTC) 

41 

42 if key not in self.__expiries: 

43 return False 

44 

45 if self.__expiries[key] <= now: 

46 del self.__expiries[key] 

47 del self.data[key] 

48 

49 return False 

50 

51 return True 

52 

53 @override 

54 def __len__(self) -> int: 

55 self.cleanup() 

56 return super().__len__() 

57 

58 @override 

59 def __contains__(self, key: _KT) -> bool: 

60 return self.cleanup_by_key(key) 

61 

62 @override 

63 def __getitem__(self, key: _KT) -> _VT: 

64 self.cleanup_by_key(key) 

65 return super().__getitem__(key) 

66 

67 def get_expiry(self, key: _KT) -> datetime | None: 

68 self.cleanup_by_key(key) 

69 return self.__expiries.get(key) 

70 

71 @override 

72 def __iter__(self) -> Iterator[_KT]: 

73 self.cleanup() 

74 return super().__iter__() 

75 

76 @override 

77 def clear(self) -> None: 

78 self.__expiries.clear() 

79 self.data.clear() 

80 

81 @override 

82 def __delitem__(self, key: _KT) -> None: 

83 del self.__expiries[key] 

84 super().__delitem__(key) 

85 

86 @override 

87 def __setitem__(self, key: _KT, value: _VT) -> None: 

88 self.__expiries[key] = datetime.now(UTC) + self.__ttl 

89 super().__setitem__(key, value) 

90 

91 def renew_expiry(self, key: _KT) -> None: 

92 del self.__expiries[key] 

93 self.__expiries[key] = datetime.now(UTC) + self.__ttl 

94 

95 @overload 

96 def update( 

97 self, 

98 other: Mapping[_KT, _VT], 

99 /, 

100 **kwargs: _VT, 

101 ) -> None: ... 

102 @overload 

103 def update( 

104 self, 

105 other: Iterable[tuple[_KT, _VT]], 

106 /, 

107 **kwargs: _VT, 

108 ) -> None: ... 

109 @overload 

110 def update( 

111 self, 

112 other: None = None, 

113 /, 

114 **kwargs: _VT, 

115 ) -> None: ... 

116 @override 

117 def update( 

118 self, 

119 other=None, 

120 /, 

121 **kwargs, 

122 ) -> None: 

123 now: datetime = datetime.now(UTC) 

124 expiry: datetime = now + self.__ttl 

125 other_ttl_dict: bool = isinstance(other, TTLDict) 

126 

127 key: _KT 

128 value: _VT 

129 if isinstance(other, Mapping): 

130 for key, value in other.items(): 

131 self.__expiries[key] = ( 

132 # In rare case, item may have been expired during iteration 

133 # Thus, we set expiry to now (which means it is already expired) 

134 (other.get_expiry(key) or now) if other_ttl_dict 

135 else expiry 

136 ) 

137 self.data[key] = value 

138 elif isinstance(other, Iterable): 

139 for key, value in other: 

140 self.__expiries[key] = expiry 

141 self.data[key] = value 

142 

143 for str_key, value in kwargs.items(): 

144 key = cast("_KT", str_key) 

145 self.__expiries[key] = expiry 

146 self.data[key] = value 

147 

148 @override 

149 def copy(self) -> TTLDict: 

150 return TTLDict(self.__ttl, self) 

151 

152 @override 

153 def __or__(self, other: Mapping[_KT, _VT]) -> TTLDict: 

154 d: TTLDict = self.copy() 

155 d.update(other) 

156 

157 return d 

158 

159 @override 

160 def __ior__(self, other: Mapping[_KT, _VT]) -> Self: 

161 self.update(other) 

162 return self 

163 

164 @override 

165 def __repr__(self) -> str: 

166 self.cleanup() 

167 return super().__repr__() 

168 

169 @override 

170 def keys(self) -> KeysView: 

171 self.cleanup() 

172 return super().keys() 

173 

174 @override 

175 def values(self) -> ValuesView: 

176 self.cleanup() 

177 return super().values() 

178 

179 @override 

180 def items(self) -> ItemsView: 

181 self.cleanup() 

182 return super().items()