Coverage for src/threadful/core.py: 100%

76 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-11-19 20:27 +0100

1""" 

2Very simple threading abstraction. 

3""" 

4 

5import contextlib 

6import functools 

7import threading 

8import typing 

9from copy import copy 

10 

11from result import Err, Ok, Result 

12from typing_extensions import Self 

13 

14P = typing.ParamSpec("P") 

15R = typing.TypeVar("R") 

16 

17 

18class ThreadWithReturn(typing.Generic[R], threading.Thread): 

19 """ 

20 Should not be used directly. 

21 

22 Rather use the @thread decorator, 

23 which changes the return type of function() -> T into function() -> ThreadWithReturn[T] 

24 """ 

25 

26 _target: typing.Callable[P, R] 

27 _args: P.args 

28 _kwargs: P.kwargs 

29 _return: R | Exception 

30 _callbacks: list[typing.Callable[[R], R]] 

31 _catch: list[typing.Callable[[Exception | R], Exception | R]] 

32 

33 def __init__(self, target: typing.Callable[P, R], *a: typing.Any, **kw: typing.Any) -> None: 

34 """ 

35 Setup callbacks, otherwise same logic as super. 

36 

37 'target' is explicitly mentioned outside of kw for type hinting. 

38 """ 

39 kw["target"] = target 

40 super().__init__(*a, **kw) 

41 self._callbacks = [] 

42 self._catch = [] 

43 

44 def start(self) -> Self: # type: ignore 

45 """ 

46 Normally, starting multiple times will lead to an error. 

47 

48 This version ignores duplicate starts. 

49 """ 

50 with contextlib.suppress(RuntimeError): 

51 super().start() 

52 return self 

53 

54 def run(self) -> None: 

55 """ 

56 Called in a new thread and handles the calling logic. 

57 """ 

58 if self._target is None: # pragma: no cover 

59 return 

60 

61 try: 

62 result = self._target(*self._args, **self._kwargs) 

63 for callback in self._callbacks: 

64 result = callback(result) 

65 self._return = result 

66 except Exception as _e: 

67 e: Exception | R = _e # make mypy happy 

68 for err_callback in self._catch: 

69 e = err_callback(e) 

70 self._return = e 

71 finally: 

72 # Avoid a refcycle if the thread is running a function with 

73 # an argument that has a member that points to the thread. 

74 self._callbacks.clear() 

75 self._catch.clear() 

76 del self._target, self._args, self._kwargs 

77 # keep self._return for .result() 

78 

79 def result(self) -> "Result[R, Exception | None]": 

80 """ 

81 Get the result value (Ok or Err) from the threaded function. 

82 

83 If the thread is not ready, Err(None) is returned. 

84 """ 

85 self.start() 

86 if self.is_alive(): 

87 # still busy 

88 return Err(None) 

89 else: 

90 result = self._return 

91 if isinstance(result, Exception): 

92 return Err(result) 

93 else: 

94 return Ok(result) 

95 

96 def is_done(self) -> bool: 

97 """ 

98 Returns whether the thread has finished (result or error). 

99 """ 

100 self.start() 

101 return not self.is_alive() 

102 

103 def then(self, callback: typing.Callable[[R], R]) -> Self: 

104 """ 

105 Attach a callback (which runs in the thread as well) on success. 

106 

107 Returns 'self' so you can do .then().then().then(). 

108 """ 

109 new = copy(self) 

110 new._callbacks.append(callback) 

111 return new # builder pattern 

112 

113 def catch(self, callback: typing.Callable[[Exception | R], Exception | R]) -> Self: 

114 """ 

115 Attach a callback (which runs in the thread as well) on error. 

116 

117 You can either return a new Exception or a fallback value. 

118 Returns 'self' so you can do .then().catch().catch(). 

119 """ 

120 new = copy(self) 

121 new._catch.append(callback) 

122 return new 

123 

124 def join(self, timeout: int | float | None = None) -> R: # type: ignore 

125 """ 

126 Enhanced version of thread.join that also returns the value or raises the exception. 

127 """ 

128 self.start() 

129 super().join(timeout) 

130 

131 match self.result(): 

132 case Ok(value): 

133 return value 

134 case Err(exc): 

135 raise exc or Exception("Something went wrong.") 

136 

137 # thread must be ready so Err(None) can't happen 

138 

139 

@typing.overload
def thread(my_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:  # pragma: no cover
    """
    Overload for bare `@thread` usage, where the function itself is passed in.

    This body is never executed (typing.overload discards it at runtime); it only
    spells out the complex return type for readers.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
        """Create (but do not start) a thread that will call `my_function`."""
        # mirrors the real implementation so mypy/ruff are satisfied
        return ThreadWithReturn(target=my_function, args=args, kwargs=kwargs)

    return wrapper

151 

152 

@typing.overload
def thread(
    my_function: None = None,
) -> typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]:  # pragma: no cover
    """
    Overload for `@thread()` usage: called without a function, it yields a decorator.

    This body is never executed (typing.overload discards it at runtime); it only
    spells out the complex return type for readers.
    """

    def decorator(inner_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:
        """Turn `inner_function` into a factory of not-yet-started threads."""

        def build_thread(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
            """Create the thread for a single call of `inner_function`."""
            # mirrors the real implementation so mypy/ruff are satisfied
            return ThreadWithReturn(target=inner_function, args=args, kwargs=kwargs)

        return build_thread

    return decorator

171 

172 

def thread(
    my_function: typing.Callable[P, R] | None = None
) -> (
    typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]
    | typing.Callable[P, ThreadWithReturn[R]]
):
    """
    Decorator that makes a function threaded: calling it returns a thread object.

    Both spellings are supported:

        @thread
        def myfunc():
            ...

        @thread()
        def otherfunc():
            ...

    Calling myfunc() or otherfunc() now returns a ThreadWithReturn instead of running the body.
    Call .result() on it to get Ok(value) / Err(exception) (the Result type from rustedpy/result,
    modelled on Rust's Result); while the thread is still busy this is Err(None).
    Alternatively call .join(), which blocks until the thread is done and then returns the
    value, or raises the exception that was raised inside the thread.
    """
    if my_function is not None:

        @functools.wraps(my_function)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
            # The thread is deliberately not started here. Starting immediately meant callbacks
            # attached afterwards (via .then()/.catch()) could miss a thread that had already
            # finished. Instead it starts on the first .result(), .is_done() or .join().
            return ThreadWithReturn(target=my_function, args=args, kwargs=kwargs)

        return wrapper

    # Used as @thread(): return this same decorator so it receives the function next.
    return thread

209 

210 

# Public API of this module.
__all__ = ["ThreadWithReturn", "thread"]