Coverage for src/threadful/core.py: 100%
76 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-11-19 20:27 +0100
"""
Very simple threading abstraction.

Provides the @thread decorator, which turns a function into a factory of
ThreadWithReturn objects: lazily-started threads whose return value (or raised
exception) can be retrieved with .result() or .join().
"""
5import contextlib
6import functools
7import threading
8import typing
9from copy import copy
11from result import Err, Ok, Result
12from typing_extensions import Self
14P = typing.ParamSpec("P")
15R = typing.TypeVar("R")
18class ThreadWithReturn(typing.Generic[R], threading.Thread):
19 """
20 Should not be used directly.
22 Rather use the @thread decorator,
23 which changes the return type of function() -> T into function() -> ThreadWithReturn[T]
24 """
26 _target: typing.Callable[P, R]
27 _args: P.args
28 _kwargs: P.kwargs
29 _return: R | Exception
30 _callbacks: list[typing.Callable[[R], R]]
31 _catch: list[typing.Callable[[Exception | R], Exception | R]]
33 def __init__(self, target: typing.Callable[P, R], *a: typing.Any, **kw: typing.Any) -> None:
34 """
35 Setup callbacks, otherwise same logic as super.
37 'target' is explicitly mentioned outside of kw for type hinting.
38 """
39 kw["target"] = target
40 super().__init__(*a, **kw)
41 self._callbacks = []
42 self._catch = []
44 def start(self) -> Self: # type: ignore
45 """
46 Normally, starting multiple times will lead to an error.
48 This version ignores duplicate starts.
49 """
50 with contextlib.suppress(RuntimeError):
51 super().start()
52 return self
54 def run(self) -> None:
55 """
56 Called in a new thread and handles the calling logic.
57 """
58 if self._target is None: # pragma: no cover
59 return
61 try:
62 result = self._target(*self._args, **self._kwargs)
63 for callback in self._callbacks:
64 result = callback(result)
65 self._return = result
66 except Exception as _e:
67 e: Exception | R = _e # make mypy happy
68 for err_callback in self._catch:
69 e = err_callback(e)
70 self._return = e
71 finally:
72 # Avoid a refcycle if the thread is running a function with
73 # an argument that has a member that points to the thread.
74 self._callbacks.clear()
75 self._catch.clear()
76 del self._target, self._args, self._kwargs
77 # keep self._return for .result()
79 def result(self) -> "Result[R, Exception | None]":
80 """
81 Get the result value (Ok or Err) from the threaded function.
83 If the thread is not ready, Err(None) is returned.
84 """
85 self.start()
86 if self.is_alive():
87 # still busy
88 return Err(None)
89 else:
90 result = self._return
91 if isinstance(result, Exception):
92 return Err(result)
93 else:
94 return Ok(result)
96 def is_done(self) -> bool:
97 """
98 Returns whether the thread has finished (result or error).
99 """
100 self.start()
101 return not self.is_alive()
103 def then(self, callback: typing.Callable[[R], R]) -> Self:
104 """
105 Attach a callback (which runs in the thread as well) on success.
107 Returns 'self' so you can do .then().then().then().
108 """
109 new = copy(self)
110 new._callbacks.append(callback)
111 return new # builder pattern
113 def catch(self, callback: typing.Callable[[Exception | R], Exception | R]) -> Self:
114 """
115 Attach a callback (which runs in the thread as well) on error.
117 You can either return a new Exception or a fallback value.
118 Returns 'self' so you can do .then().catch().catch().
119 """
120 new = copy(self)
121 new._catch.append(callback)
122 return new
124 def join(self, timeout: int | float | None = None) -> R: # type: ignore
125 """
126 Enhanced version of thread.join that also returns the value or raises the exception.
127 """
128 self.start()
129 super().join(timeout)
131 match self.result():
132 case Ok(value):
133 return value
134 case Err(exc):
135 raise exc or Exception("Something went wrong.")
137 # thread must be ready so Err(None) can't happen
@typing.overload
def thread(my_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:  # pragma: no cover
    """
    Bare-decorator form: ``@thread`` applied directly to a function.

    Code in this function is never executed (typing.overload definitions only inform
    type checkers); it is kept as a readable sketch of the complex return type.
    """

    def _spawn(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
        """Wrap my_function's call arguments in a new, unstarted ThreadWithReturn."""
        return ThreadWithReturn(target=my_function, args=args, kwargs=kwargs)  # code copied for mypy/ruff

    return _spawn
@typing.overload
def thread(
    my_function: None = None,
) -> typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]:  # pragma: no cover
    """
    Decorator-factory form: ``@thread()`` written with parentheses.

    Code in this function is never executed (typing.overload definitions only inform
    type checkers); it is kept as a readable sketch of the complex return type.
    """

    def _decorate(inner_function: typing.Callable[P, R]) -> typing.Callable[P, ThreadWithReturn[R]]:
        """Turn inner_function into a factory of unstarted ThreadWithReturn objects."""

        def _spawn(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
            """Wrap inner_function's call arguments in a new ThreadWithReturn."""
            return ThreadWithReturn(target=inner_function, args=args, kwargs=kwargs)  # code copied for mypy/ruff

        return _spawn

    return _decorate
def thread(
    my_function: typing.Callable[P, R] | None = None
) -> (
    typing.Callable[[typing.Callable[P, R]], typing.Callable[P, ThreadWithReturn[R]]]
    | typing.Callable[P, ThreadWithReturn[R]]
):
    """
    Decorator that makes functions threaded.

    Both spellings are supported:

        @thread
        def myfunc():
            ...

        @thread()
        def otherfunc():
            ...

    Calling myfunc() or otherfunc() no longer runs the function directly; it returns a
    custom thread object instead. Use .result() to get Ok(value) or Err(exception)
    (the Result type from rustedpy/result, modelled on Rust's Result); while the thread
    is still running, .result() returns Err(None). Alternatively, call .join() to block
    until the thread is done and then get the return value, or have the exception that
    was raised in the thread re-raised.
    """
    if my_function is not None:

        @functools.wraps(my_function)
        def launcher(*args: P.args, **kwargs: P.kwargs) -> ThreadWithReturn[R]:
            # The thread is handed back unstarted on purpose: starting it here would let a
            # fast function finish before .then()/.catch() callbacks could be attached.
            # It starts on the first .start(), .result(), .is_done() or .join() call.
            return ThreadWithReturn(target=my_function, args=args, kwargs=kwargs)

        return launcher

    # Parenthesised form `@thread()`: hand back this decorator itself, which Python
    # then applies to the decorated function.
    return thread
# Public API of this module: the decorator and the thread type it produces.
__all__ = ["ThreadWithReturn", "thread"]