Coverage for src/threadful/core.py: 100%
84 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-12-04 19:19 +0100
« prev ^ index » next coverage.py v7.4.3, created at 2024-12-04 19:19 +0100
1"""
2Very simple threading abstraction.
3"""
5import contextlib
6import functools
7import threading
8import typing
9from copy import copy
11from result import Err, Ok, Result
12from typing_extensions import Self
14P = typing.ParamSpec("P")
15R = typing.TypeVar("R")
18class ThreadWithReturn(typing.Generic[R], threading.Thread):
19 """
20 Should not be used directly.
22 Rather use the @thread decorator,
23 which changes the return type of function() -> T into function() -> ThreadWithReturn[T]
24 """
26 _target: typing.Callable[P, R]
27 _args: P.args
28 _kwargs: P.kwargs
29 _return: R | Exception
30 _callbacks: list[typing.Callable[[R], R]]
31 _catch: list[typing.Callable[[Exception | R], Exception | R]]
33 def __init__(self, target: typing.Callable[P, R], *a: typing.Any, **kw: typing.Any) -> None:
34 """
35 Setup callbacks, otherwise same logic as super.
37 'target' is explicitly mentioned outside of kw for type hinting.
38 """
39 kw["target"] = target
40 super().__init__(*a, **kw)
41 self._callbacks = []
42 self._catch = []
44 def start(self) -> Self: # type: ignore
45 """
46 Normally, starting multiple times will lead to an error.
48 This version ignores duplicate starts.
49 """
50 with contextlib.suppress(RuntimeError):
51 super().start()
52 return self
54 def run(self) -> None:
55 """
56 Called in a new thread and handles the calling logic.
57 """
58 if self._target is None: # pragma: no cover
59 return
61 try:
62 result = self._target(*self._args, **self._kwargs)
63 for callback in self._callbacks:
64 result = callback(result)
65 self._return = result
66 except Exception as _e:
67 e: Exception | R = _e # make mypy happy
68 for err_callback in self._catch:
69 e = err_callback(e)
70 self._return = e
71 finally:
72 # Avoid a refcycle if the thread is running a function with
73 # an argument that has a member that points to the thread.
74 self._callbacks.clear()
75 self._catch.clear()
76 del self._target, self._args, self._kwargs
77 # keep self._return for .result()
79 def result(self, wait: bool = False) -> "Result[R, Exception | None]":
80 """
81 Get the result value (Ok or Err) from the threaded function.
83 By default, if the thread is not ready, Err(None) is returned.
84 If `wait` is used, this functions like a join() but with a Result.
86 """
87 self.start()
88 if wait:
89 super().join()
91 if self.is_alive():
92 # still busy
93 return Err(None)
94 else:
95 result = self._return
96 if isinstance(result, Exception):
97 return Err(result)
98 else:
99 return Ok(result)
101 def is_done(self) -> bool:
102 """
103 Returns whether the thread has finished (result or error).
104 """
105 self.start()
106 return not self.is_alive()
108 def then(self, callback: typing.Callable[[R], R]) -> Self:
109 """
110 Attach a callback (which runs in the thread as well) on success.
112 Returns 'self' so you can do .then().then().then().
113 """
114 new = copy(self)
115 new._callbacks.append(callback)
116 return new # builder pattern
118 def catch(self, callback: typing.Callable[[Exception | R], Exception | R]) -> Self:
119 """
120 Attach a callback (which runs in the thread as well) on error.
122 You can either return a new Exception or a fallback value.
123 Returns 'self' so you can do .then().catch().catch().
124 """
125 new = copy(self)
126 new._catch.append(callback)
127 return new
129 def join(self, timeout: int | float | None = None) -> R: # type: ignore
130 """
131 Enhanced version of thread.join that also returns the value or raises the exception.
132 """
133 self.start()
134 super().join(timeout)
136 match self.result():
137 case Ok(value):
138 return value
139 case Err(exc):
140 raise exc or Exception("Something went wrong.")
142 # thread must be ready so Err(None) can't happen
@typing.overload
def thread(my_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:  # pragma: no cover
    """
    Typing overload for the bare form: `@thread`.

    Code in this function is never executed, just shown for reference of the complex return type.
    """

    def threaded(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
        """Idem ditto."""
        # code copied for mypy/ruff
        return ThreadWithReturn(target=my_function, args=args, kwargs=kwargs)

    return threaded
@typing.overload
def thread(
    my_function: None = None,
) -> typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]:  # pragma: no cover
    """
    Typing overload for the called form: `@thread()`.

    Code in this function is never executed, just shown for reference of the complex return type.
    """

    def decorator(inner_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:
        """Idem ditto."""

        def threaded(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
            """Idem ditto."""
            # code copied for mypy/ruff
            return ThreadWithReturn(target=inner_function, args=args, kwargs=kwargs)

        return threaded

    return decorator
def thread(
    my_function: typing.Callable[P, R] | None = None,
) -> (
    typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]
    | typing.Callable[P, ThreadWithReturn[R]]
):
    """
    This decorator can be used to automagically make functions threaded!

    Examples:
        @thread
        def myfunc():
            ...

        @thread()
        def otherfunc():
            ...

    myfunc() and otherfunc() now return a custom thread object,
    from which you can get the result value or exception with .result().
    This uses a Result (Ok or Err) type from rustedpy/result (based on the Rust Result type.)
    If the thread is not done yet, it will return Err(None)
    You can also call .join(), which waits (blocking) until the thread is done
    and then returns the return value or raises an exception (if raised in the thread)
    """
    if my_function is not None:

        @functools.wraps(my_function)
        def threaded(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
            # The thread is deliberately NOT started here: if it were, it could already
            # be finished before callbacks are attached, and those would never run.
            # start() happens once a result() is requested or join() is called.
            return ThreadWithReturn(target=my_function, args=args, kwargs=kwargs)

        return threaded

    # used as `@thread()`: hand back the decorator itself, to be applied to the function
    return thread
def join_all_results(*threads: ThreadWithReturn[R]) -> tuple[Result[R, Exception], ...]:
    """
    Wait for all threads to complete and retrieve their results as `Result` objects.

    All threads are started before the first one is waited on, so they run concurrently.

    Args:
        *threads: A variable number of `ThreadWithReturn` instances to join.

    Returns:
        tuple[Result[R, Exception], ...]: A tuple containing `Result` objects for each thread,
        where each result represents the success or error outcome of the thread.
    """
    # Threads are only started on their first result()/join() call. Without this loop,
    # thread N+1 would not start until thread N had completely finished.
    for pending in threads:
        pending.start()  # duplicate starts are ignored by ThreadWithReturn.start()

    return tuple(t.result(wait=True) for t in threads)
def join_all_or_raise(*threads: ThreadWithReturn[R]) -> tuple[R, ...]:
    """
    Wait for all threads to complete and retrieve their results, raising exceptions on failure.

    All threads are started before the first one is waited on, so they run concurrently.
    If a thread fails, its exception is raised right away; the remaining threads
    have already been started and are not waited for.

    Args:
        *threads: A variable number of `ThreadWithReturn` instances to join.

    Returns:
        tuple[R, ...]: A tuple containing the successful results of each thread.

    Raises:
        Exception: If any thread raises an exception, it is propagated.
    """
    # Threads are only started on their first result()/join() call. Without this loop,
    # thread N+1 would not start until thread N had completely finished.
    for pending in threads:
        pending.start()  # duplicate starts are ignored by ThreadWithReturn.start()

    return tuple(t.join() for t in threads)
def join_all_unwrap(*threads: ThreadWithReturn[R]) -> tuple[R | None, ...]:
    """
    Wait for all threads to complete and retrieve their results, unwrapping successes or returning None on error.

    All threads are started before the first one is waited on, so they run concurrently.

    Args:
        *threads: A variable number of `ThreadWithReturn` instances to join.

    Returns:
        tuple[R | None, ...]: A tuple containing the results of each thread, where errors are replaced with None.
    """
    # Threads are only started on their first result()/join() call. Without this loop,
    # thread N+1 would not start until thread N had completely finished.
    for pending in threads:
        pending.start()  # duplicate starts are ignored by ThreadWithReturn.start()

    return tuple(t.result(wait=True).unwrap_or(None) for t in threads)
# Names exported by `from <this module> import *`.
__all__ = [
    # the thread class and the decorator that creates it
    "ThreadWithReturn",
    "thread",
    # helpers to wait on several threads at once
    "join_all_or_raise",
    "join_all_results",
    "join_all_unwrap",
]