Coverage for src/ttl_dict/__init__.py: 69%

118 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-22 01:36 -0700

1from __future__ import annotations 

2 

3from collections import UserDict 

4from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView 

5from datetime import UTC, datetime, timedelta 

6from typing import Any, Self, override 

7 

8 

9class TTLDict[_KT, _VT](UserDict[_KT, _VT]): 

10 def __init__( 

11 self, 

12 ttl: timedelta, 

13 other: Mapping[_KT, _VT] | Iterable[tuple[_KT, _VT]] | None = None, 

14 /, 

15 **kwargs: _VT, 

16 ) -> None: 

17 self.__ttl: timedelta = ttl 

18 

19 self.expiries: dict[_KT, datetime] = {} 

20 

21 # Must be at the end of __init__ as it calls self.update which needs self.__ttl 

22 super().__init__(other, **kwargs) 

23 

24 def cleanup(self) -> None: 

25 now: datetime = datetime.now(UTC) 

26 

27 expired_keys: list[_KT] = [] 

28 for key, expiry in self.expiries.items(): 

29 # As dict is iterated by insert order, the newer ones are iterated later 

30 if now < expiry: 

31 break 

32 

33 expired_keys.append(key) 

34 

35 for key in expired_keys: 

36 del self.expiries[key] 

37 del self.data[key] 

38 

39 def cleanup_by_key(self, key: _KT) -> bool: 

40 now: datetime = datetime.now(UTC) 

41 

42 if key not in self.expiries: 

43 return False 

44 

45 if self.expiries[key] <= now: 

46 del self.expiries[key] 

47 del self.data[key] 

48 

49 return False 

50 

51 return True 

52 

53 @override 

54 def __len__(self) -> int: 

55 self.cleanup() 

56 return super().__len__() 

57 

58 @override 

59 def __contains__(self, key: _KT) -> bool: 

60 return self.cleanup_by_key(key) 

61 

62 @override 

63 def get(self, key: _KT, default: Any = None) -> Any: 

64 self.cleanup_by_key(key) 

65 return super().get(key, default) 

66 

67 @override 

68 def __getitem__(self, key: _KT) -> _VT: 

69 self.cleanup_by_key(key) 

70 return super().__getitem__(key) 

71 

72 @override 

73 def __iter__(self) -> Iterator[_KT]: 

74 self.cleanup() 

75 return super().__iter__() 

76 

77 @override 

78 def clear(self) -> None: 

79 self.expiries.clear() 

80 super().clear() 

81 

82 @override 

83 def pop(self, key: _KT, default: Any = None) -> Any: 

84 if not self.expiries.pop(key): 

85 return default 

86 

87 return super().pop(key, default) 

88 

89 @override 

90 def popitem(self) -> tuple[_KT, _VT]: 

91 self.cleanup() 

92 

93 key, value = super().popitem() 

94 self.expiries.pop(key) 

95 

96 return (key, value) 

97 

98 @override 

99 def __delitem__(self, key: _KT) -> None: 

100 del self.expiries[key] 

101 super().__delitem__(key) 

102 

103 @override 

104 def __setitem__(self, key: _KT, value: _VT) -> None: 

105 self.expiries[key] = datetime.now(UTC) + self.__ttl 

106 super().__setitem__(key, value) 

107 

108 @override 

109 def setdefault(self, key: _KT, default: Any = None) -> Any: 

110 self.expiries[key] = datetime.now(UTC) + self.__ttl 

111 return super().setdefault(key, default) 

112 

113 @override 

114 def update( 

115 self, 

116 other: Mapping[_KT, _VT] | Iterable[tuple[_KT, _VT]] | None, 

117 /, 

118 **kwargs: _VT, 

119 ) -> None: 

120 expiry: datetime = datetime.now(UTC) + self.__ttl 

121 other_ttl_dict: bool = isinstance(other, TTLDict) 

122 

123 key: _KT 

124 value: _VT 

125 if isinstance(other, Mapping): 

126 for key, value in other.items(): 

127 self.expiries[key] = other.expiries[key] if other_ttl_dict else expiry 

128 self.data[key] = value 

129 elif isinstance(other, Iterable): 

130 for key, value in other: 

131 self.expiries[key] = expiry 

132 self.data[key] = value 

133 

134 for key, value in kwargs.items(): 

135 self.expiries[key] = expiry 

136 self.data[key] = value 

137 

138 @override 

139 def copy(self) -> TTLDict: 

140 return TTLDict(self.__ttl, self) 

141 

142 @override 

143 def __or__(self, other: Mapping[_KT, _VT]) -> TTLDict: 

144 d: TTLDict = self.copy() 

145 d.update(other) 

146 

147 return d 

148 

149 @override 

150 def __ior__(self, other: Mapping[_KT, _VT]) -> Self: 

151 self.update(other) 

152 return self 

153 

154 @override 

155 def __repr__(self) -> str: 

156 self.cleanup() 

157 return super().__repr__() 

158 

159 @override 

160 def keys(self) -> KeysView: 

161 self.cleanup() 

162 return super().keys() 

163 

164 @override 

165 def values(self) -> ValuesView: 

166 self.cleanup() 

167 return super().values() 

168 

169 @override 

170 def items(self) -> ItemsView: 

171 self.cleanup() 

172 return super().items()