Coverage for muutils\parallel.py: 94%
94 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-21 19:08 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-21 19:08 -0700
1import multiprocessing
2import functools
3from typing import (
4 Any,
5 Callable,
6 Iterable,
7 Literal,
8 Optional,
9 Tuple,
10 TypeVar,
11 Dict,
12 List,
13 Union,
14 Protocol,
15)
17# for no tqdm fallback
18from muutils.spinner import SpinnerContext
19from muutils.validate_type import get_fn_allowed_kwargs
# return type of the function being mapped (element type of the output list)
OutputType = TypeVar("OutputType")

# element type of the iterable being mapped over
InputType = TypeVar("InputType")
class ProgressBarFunction(Protocol):
    """structural type describing a progress bar callable

    any callable that takes an iterable plus arbitrary keyword arguments and
    returns an iterable matches this protocol (e.g. `tqdm.tqdm`)
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        "wrap `iterable`, presumably reporting progress as it is consumed"
        ...
# the string (or `None`) options accepted for selecting a progress bar
ProgressBarOption = Literal["tqdm", "spinner", "none", None]

# default progress bar option: "tqdm" when the `tqdm` package can be imported,
# otherwise fall back to the built-in "spinner"
DEFAULT_PBAR_FN: ProgressBarOption
try:
    import tqdm  # type: ignore[import-untyped]
except ImportError:
    DEFAULT_PBAR_FN = "spinner"
else:
    DEFAULT_PBAR_FN = "tqdm"
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` fully while a `SpinnerContext` spinner is shown

    only keyword arguments accepted by `SpinnerContext.__init__` are forwarded.
    if no `message` is passed, it is taken from `desc`, or failing that a
    message is built from `total`.

    # Returns:
     - `List`
        the items of `x`, collected into a list
    """
    allowed: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)
    spinner_kwargs: dict = dict()
    for key, value in kwargs.items():
        if key in allowed:
            spinner_kwargs[key] = value

    # derive a message when none was given: prefer `desc`, then `total`
    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        collected = list(x)

    return collected
def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """filter and map progress bar kwargs so they can be passed to `tqdm.tqdm`

    `tqdm.tqdm` is not wrapped in a helper function the way `spinner_fn_wrap`
    wraps the spinner (the original note says the bar seemed to disappear when
    wrapped), so instead the kwargs are adapted here:

    - only keys accepted by `tqdm.tqdm.__init__` are kept
    - if `desc` is not given, it is taken from `message`, or failing that a
      description is built from `total`

    # Parameters:
     - `kwargs : dict`
        progress bar kwargs, possibly including keys meant for other progress bars

    # Returns:
     - `dict`
        a new dict of kwargs for `tqdm.tqdm` (the input dict is not modified)
    """
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    # fill in a description when the caller did not provide one
    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]
        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"

    return mapped_kwargs
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """no-op progress bar: return `x` itself, ignoring every keyword argument"""
    # kwargs are accepted only so the signature matches the other progress bar functions
    del kwargs
    return x
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `**extra_kwargs`
        additional kwargs (e.g. `total`) merged with `pbar_kwargs`;
        keys present in `pbar_kwargs` take precedence

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and the kwargs to call it with

    # Raises:
     - `ValueError` : if `pbar` is a string other than `"tqdm"`, `"spinner"`, or `"none"`
     - `ImportError` : if `pbar="tqdm"` but the `tqdm` package is not installed
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    # build a new dict (never mutate the caller's); explicit `pbar_kwargs` win over `extra_kwargs`
    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # import here so a missing `tqdm` gives a clear error instead of a NameError
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            # the spinner wrapper receives its kwargs via `partial`, so nothing is left to pass at call time
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # a user-supplied callable: use it as-is and pass the merged kwargs to it
        pbar_fn = pbar

    return pbar_fn, pbar_kwargs
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is always capped at `len(iterable)`, so `parallel=True`
    uses `min(len(iterable), multiprocessing.cpu_count())` processes. if this works
    out to a single process, the function runs in serial instead.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap` / `Pool.imap_unordered`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`. must support `len()`
     - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `chunksize : Optional[int]`
        chunksize for the pool's map. if `None`, uses `max(1, len(iterable) // num_processes)`
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True`, use `imap` so outputs are in input order; otherwise use `imap_unordered`
        (defaults to `True`)
     - `use_multiprocess : bool`
        use the `multiprocess` package instead of `multiprocessing` (mostly useful when you need to pickle a lambda)
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option, resolved via `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # remember what the caller asked for *before* possibly downgrading to serial below,
    # so `use_multiprocess=True, parallel=True` with a single input (or a single CPU) is not rejected
    parallel_requested: bool = bool(parallel)

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = num_processes > 1

    mp: Any = multiprocessing
    if use_multiprocess:
        if not parallel_requested:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    if not run_parallel:
        # serial: plain `map`, wrapped in the progress bar
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`.
    # the `with` block calls `terminate()` on exit, so workers are cleaned up if `func`
    # raises or the run is interrupted; on success we close and join explicitly first
    with mp.Pool(num_processes) as pool:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        map_fn = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                map_fn(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
        pool.close()
        pool.join()

    # return the output as a list
    return output