Coverage for src/threadful/core.py: 100%

84 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-12-04 19:19 +0100

1""" 

2Very simple threading abstraction. 

3""" 

4 

5import contextlib 

6import functools 

7import threading 

8import typing 

9from copy import copy 

10 

11from result import Err, Ok, Result 

12from typing_extensions import Self 

13 

14P = typing.ParamSpec("P") 

15R = typing.TypeVar("R") 

16 

17 

18class ThreadWithReturn(typing.Generic[R], threading.Thread): 

19 """ 

20 Should not be used directly. 

21 

22 Rather use the @thread decorator, 

23 which changes the return type of function() -> T into function() -> ThreadWithReturn[T] 

24 """ 

25 

26 _target: typing.Callable[P, R] 

27 _args: P.args 

28 _kwargs: P.kwargs 

29 _return: R | Exception 

30 _callbacks: list[typing.Callable[[R], R]] 

31 _catch: list[typing.Callable[[Exception | R], Exception | R]] 

32 

33 def __init__(self, target: typing.Callable[P, R], *a: typing.Any, **kw: typing.Any) -> None: 

34 """ 

35 Setup callbacks, otherwise same logic as super. 

36 

37 'target' is explicitly mentioned outside of kw for type hinting. 

38 """ 

39 kw["target"] = target 

40 super().__init__(*a, **kw) 

41 self._callbacks = [] 

42 self._catch = [] 

43 

44 def start(self) -> Self: # type: ignore 

45 """ 

46 Normally, starting multiple times will lead to an error. 

47 

48 This version ignores duplicate starts. 

49 """ 

50 with contextlib.suppress(RuntimeError): 

51 super().start() 

52 return self 

53 

54 def run(self) -> None: 

55 """ 

56 Called in a new thread and handles the calling logic. 

57 """ 

58 if self._target is None: # pragma: no cover 

59 return 

60 

61 try: 

62 result = self._target(*self._args, **self._kwargs) 

63 for callback in self._callbacks: 

64 result = callback(result) 

65 self._return = result 

66 except Exception as _e: 

67 e: Exception | R = _e # make mypy happy 

68 for err_callback in self._catch: 

69 e = err_callback(e) 

70 self._return = e 

71 finally: 

72 # Avoid a refcycle if the thread is running a function with 

73 # an argument that has a member that points to the thread. 

74 self._callbacks.clear() 

75 self._catch.clear() 

76 del self._target, self._args, self._kwargs 

77 # keep self._return for .result() 

78 

79 def result(self, wait: bool = False) -> "Result[R, Exception | None]": 

80 """ 

81 Get the result value (Ok or Err) from the threaded function. 

82 

83 By default, if the thread is not ready, Err(None) is returned. 

84 If `wait` is used, this functions like a join() but with a Result. 

85 

86 """ 

87 self.start() 

88 if wait: 

89 super().join() 

90 

91 if self.is_alive(): 

92 # still busy 

93 return Err(None) 

94 else: 

95 result = self._return 

96 if isinstance(result, Exception): 

97 return Err(result) 

98 else: 

99 return Ok(result) 

100 

101 def is_done(self) -> bool: 

102 """ 

103 Returns whether the thread has finished (result or error). 

104 """ 

105 self.start() 

106 return not self.is_alive() 

107 

108 def then(self, callback: typing.Callable[[R], R]) -> Self: 

109 """ 

110 Attach a callback (which runs in the thread as well) on success. 

111 

112 Returns 'self' so you can do .then().then().then(). 

113 """ 

114 new = copy(self) 

115 new._callbacks.append(callback) 

116 return new # builder pattern 

117 

118 def catch(self, callback: typing.Callable[[Exception | R], Exception | R]) -> Self: 

119 """ 

120 Attach a callback (which runs in the thread as well) on error. 

121 

122 You can either return a new Exception or a fallback value. 

123 Returns 'self' so you can do .then().catch().catch(). 

124 """ 

125 new = copy(self) 

126 new._catch.append(callback) 

127 return new 

128 

129 def join(self, timeout: int | float | None = None) -> R: # type: ignore 

130 """ 

131 Enhanced version of thread.join that also returns the value or raises the exception. 

132 """ 

133 self.start() 

134 super().join(timeout) 

135 

136 match self.result(): 

137 case Ok(value): 

138 return value 

139 case Err(exc): 

140 raise exc or Exception("Something went wrong.") 

141 

142 # thread must be ready so Err(None) can't happen 

143 

144 

@typing.overload
def thread(my_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:  # pragma: no cover
    """
    Typing overload for the bare form: `@thread` applied directly to a function.

    Code in this function is never executed, just shown for reference of the complex return type.
    """

    def threaded(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
        """Reference only, mirrors the real implementation."""
        pending: ThreadWithReturn[R] = ThreadWithReturn(
            target=my_function, args=args, kwargs=kwargs
        )  # code copied for mypy/ruff
        return pending

    return threaded

156 

157 

@typing.overload
def thread(
    my_function: None = None,
) -> typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]:  # pragma: no cover
    """
    Typing overload for the called form: `@thread()` returns the actual decorator.

    Code in this function is never executed, just shown for reference of the complex return type.
    """

    def decorator(inner_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:
        """Reference only, mirrors the real implementation."""

        def threaded(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
            """Reference only, mirrors the real implementation."""
            pending: ThreadWithReturn[R] = ThreadWithReturn(
                target=inner_function, args=args, kwargs=kwargs
            )  # code copied for mypy/ruff
            return pending

        return threaded

    return decorator

176 

177 

def thread(
    my_function: typing.Callable[P, R] | None = None,
) -> (
    typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]
    | typing.Callable[P, ThreadWithReturn[R]]
):
    """
    Decorator that turns a regular function into one that returns a thread object.

    Both the bare and the called form are supported:

        @thread
        def myfunc():
            ...

        @thread()
        def otherfunc():
            ...

    Calling myfunc() or otherfunc() then returns a ThreadWithReturn,
    from which you can get the result value or exception with .result().
    This uses a Result (Ok or Err) type from rustedpy/result (based on the Rust Result type.)
    If the thread is not done yet, .result() will return Err(None).
    You can also call .join(), which waits (blocking) until the thread is done
    and then returns the return value or raises an exception (if raised in the thread).
    """
    if my_function is None:
        # Used as `@thread()`: hand back the decorator itself, which then receives the function.
        return thread

    def deferred(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
        # The thread is deliberately NOT started here: if it was, it could already be finished
        # before the caller attaches callbacks, and those would never run.
        # start() happens once a result() is requested or the thread is waited for via join().
        return ThreadWithReturn(target=my_function, args=args, kwargs=kwargs)

    # copy name, docstring etc. of the original function onto the wrapper
    return functools.update_wrapper(deferred, my_function)

214 

215 

def join_all_results(*threads: ThreadWithReturn[R]) -> tuple[Result[R, Exception], ...]:
    """
    Wait for all threads to complete and retrieve their results as `Result` objects.

    All threads are started up front so they run concurrently; the results are
    then collected in the order in which the threads were passed.

    Args:
        *threads: A variable number of `ThreadWithReturn` instances to join.

    Returns:
        tuple[Result[R, Exception], ...]: A tuple containing `Result` objects for each thread,
        where each result represents the success or error outcome of the thread.
    """
    # Threads are started lazily (on their first result()/join()). Without this loop, each
    # thread would only be started after the previous one finished, running them one by one.
    # start() ignores duplicate starts, so threads that are already running are unaffected.
    for pending in threads:
        pending.start()

    return tuple(pending.result(wait=True) for pending in threads)

228 

229 

def join_all_or_raise(*threads: ThreadWithReturn[R]) -> tuple[R, ...]:
    """
    Wait for all threads to complete and retrieve their results, raising exceptions on failure.

    All threads are started up front so they run concurrently; they are then
    joined in the order in which they were passed.

    Args:
        *threads: A variable number of `ThreadWithReturn` instances to join.

    Returns:
        tuple[R, ...]: A tuple containing the successful results of each thread.

    Raises:
        Exception: If any thread raises an exception, it is propagated
            (the first failing thread in argument order wins).
    """
    # Threads are started lazily (on their first result()/join()). Without this loop, each
    # thread would only be started after the previous one finished, running them one by one.
    # start() ignores duplicate starts, so threads that are already running are unaffected.
    for pending in threads:
        pending.start()

    return tuple(pending.join() for pending in threads)

244 

245 

def join_all_unwrap(*threads: ThreadWithReturn[R]) -> tuple[R | None, ...]:
    """
    Wait for all threads to complete and retrieve their results, unwrapping successes or returning None on error.

    All threads are started up front so they run concurrently; the results are
    then collected in the order in which the threads were passed.

    Args:
        *threads: A variable number of `ThreadWithReturn` instances to join.

    Returns:
        tuple[R | None, ...]: A tuple containing the results of each thread, where errors are replaced with None.
    """
    # Threads are started lazily (on their first result()/join()). Without this loop, each
    # thread would only be started after the previous one finished, running them one by one.
    # start() ignores duplicate starts, so threads that are already running are unaffected.
    for pending in threads:
        pending.start()

    return tuple(pending.result(wait=True).unwrap_or(None) for pending in threads)

257 

258 

# Names exported by `from <this module> import *`.
__all__ = [
    # the thread class and the decorator that creates it
    "ThreadWithReturn",
    "thread",
    # helpers to wait for multiple threads at once
    "join_all_or_raise",
    "join_all_results",
    "join_all_unwrap",
]