muutils.parallel
import multiprocessing
import functools
from typing import (
    Any,
    Callable,
    Iterable,
    Literal,
    Optional,
    Tuple,
    TypeVar,
    Dict,
    List,
    Union,
    Protocol,
)

# for no tqdm fallback
from muutils.spinner import SpinnerContext
from muutils.validate_type import get_fn_allowed_kwargs


InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
# typevars for our iterable and map


class ProgressBarFunction(Protocol):
    "a protocol for a progress bar function"

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...


ProgressBarOption = Literal["tqdm", "spinner", "none", None]
# type for the progress bar option


DEFAULT_PBAR_FN: ProgressBarOption
# default progress bar function: "tqdm" if it can be imported, otherwise "spinner"

try:
    # use tqdm if it's available
    import tqdm  # type: ignore[import-untyped]

    DEFAULT_PBAR_FN = "tqdm"

except ImportError:
    # use the spinner as a fallback
    DEFAULT_PBAR_FN = "spinner"


def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` spinner is shown

    only kwargs accepted by `SpinnerContext.__init__` are forwarded to it. if no
    `message` was given, a tqdm-style `desc` is used as the message, or failing
    that, `"Processing {total} items"` when `total` is given.
    """
    spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs(
        SpinnerContext.__init__
    )
    mapped_kwargs: dict = {
        k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs
    }
    # translate tqdm-style `desc` into the spinner's `message`
    if "desc" in kwargs and "message" not in mapped_kwargs:
        mapped_kwargs["message"] = kwargs["desc"]

    if "message" not in mapped_kwargs and "total" in kwargs:
        mapped_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output


def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """map generic progress bar kwargs to kwargs accepted by `tqdm.tqdm`

    we map kwargs instead of wrapping `tqdm.tqdm` in a function, since the original
    note says the bar seemed to disappear when wrapped. only kwargs accepted by
    `tqdm.tqdm.__init__` are kept. if no `desc` is given, it is derived from
    `message`, or failing that from `total`. requires the module-level `tqdm` import.
    """
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]

        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
    return mapped_kwargs


def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    "fallback to no progress bar: returns `x` unchanged and ignores all kwargs"
    return x


def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `**extra_kwargs`
        additional kwargs merged underneath `pbar_kwargs` -- on key collisions, `pbar_kwargs` wins

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and its kwargs. for `"spinner"` the
        kwargs are already bound into the function and the returned dict is empty

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
     - `ImportError` : if `pbar="tqdm"` but `tqdm` is not installed
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    # build a new dict (never mutate the caller's), explicit `pbar_kwargs` take precedence
    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # import here so a missing tqdm gives a clear ImportError instead of a NameError
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm', 'spinner', or 'none' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # a user-supplied callable: pass the merged kwargs straight to it
        pbar_fn = pbar

    return pbar_fn, pbar_kwargs


# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by `min(len(iterable), multiprocessing.cpu_count())`.
    if only a single process would be used, the function runs in serial instead.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`. if it has no `__len__`
        (e.g. a generator), it is first materialized into a list
     - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `chunksize : Optional[int]`
        chunksize passed to `Pool.imap` / `Pool.imap_unordered`. if `None`, uses
        `max(1, len(iterable) // num_processes)`
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True` use `Pool.imap` (output order matches input order), otherwise
        `Pool.imap_unordered`. serial runs are always ordered
        (defaults to `True`)
     - `use_multiprocess : bool`
        use the `multiprocess` package instead of `multiprocessing` -- mostly useful
        when you need to pickle a lambda
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option, see `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """

    # we need a length for the progress bar and the pool size: materialize unsized iterables
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes requested by the caller
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # validate against what the caller *asked* for, before any fallback to serial below --
    # otherwise `parallel=True` with a single input (or a single CPU) would be rejected
    mp = multiprocessing
    if use_multiprocess:
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # make sure we don't have more processes than inputs, and don't bother with a pool for one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        do_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                do_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # `func` or the progress bar raised (or the user hit Ctrl-C): stop the
        # workers instead of leaking them, then re-raise
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    # return the output as a list
    return output
class
ProgressBarFunction(typing.Protocol):
class ProgressBarFunction(Protocol):
    """structural type for progress-bar callables

    any callable invoked as `fn(iterable, **kwargs)` that returns an iterable matches.
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        """accept an iterable plus arbitrary keyword options and return an iterable"""
a protocol for a progress bar function
ProgressBarFunction(*args, **kwargs)
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder `__init__` for `typing.Protocol` subclasses (appears to be CPython `typing` source).

    Raises `TypeError` when the instance's class is itself a protocol. Otherwise, on the
    first instantiation, replaces the class's `__init__` with the first real `__init__`
    found along its MRO (falling back to `object.__init__`), then calls it.
    """
    cls = type(self)

    # protocol classes themselves may not be instantiated
    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        # (the `else` runs only when the loop finished without `break`)
        cls.__init__ = object.__init__

    # forward this first call's arguments to the newly installed `__init__`
    cls.__init__(self, *args, **kwargs)
ProgressBarOption =
typing.Literal['tqdm', 'spinner', 'none', None]
DEFAULT_PBAR_FN: Literal['tqdm', 'spinner', 'none', None] =
'tqdm'
def
spinner_fn_wrap(x: Iterable, **kwargs) -> List:
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """drain `x` into a list while a `SpinnerContext` spinner is displayed

    keyword arguments that `SpinnerContext.__init__` accepts are forwarded and the
    rest are dropped. when no `message` is forwarded, a tqdm-style `desc` becomes
    the message, or otherwise a `total` produces `"Processing {total} items"`.
    """
    accepted_names: set = get_fn_allowed_kwargs(SpinnerContext.__init__)

    spinner_kwargs: dict = dict()
    for name, value in kwargs.items():
        if name in accepted_names:
            spinner_kwargs[name] = value

    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        collected = [item for item in x]

    return collected
spinner wrapper
def
map_kwargs_for_tqdm(kwargs: dict) -> dict:
def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress bar kwargs into kwargs for `tqdm.tqdm`

    kwargs are mapped rather than wrapping tqdm in a function; the original author
    noted that wrapping seemed to make the bar disappear (unverified). only names
    accepted by `tqdm.tqdm.__init__` survive. when `desc` is absent, it is taken
    from `message`, or otherwise built from `total`.
    """
    accepted_names: set = get_fn_allowed_kwargs(tqdm.tqdm.__init__)

    result: dict = dict()
    for name in kwargs:
        if name in accepted_names:
            result[name] = kwargs[name]

    # an explicit `desc` always wins; nothing to derive
    if "desc" in kwargs:
        return result

    if "message" in kwargs:
        result["desc"] = kwargs["message"]
    elif "total" in kwargs:
        result["desc"] = f"Processing {kwargs.get('total')} items"
    return result
map kwargs for tqdm (we can't wrap tqdm in a function because the pbar disappears?)
def
no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """identity "progress bar": hand back `x` itself, ignoring every keyword argument"""
    passthrough: Iterable = x
    return passthrough
fallback to no progress bar
def
set_up_progress_bar_fn( pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]], pbar_kwargs: Optional[Dict[str, Any]] = None, **extra_kwargs) -> Tuple[ProgressBarFunction, dict]:
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `**extra_kwargs`
        additional kwargs merged underneath `pbar_kwargs` -- on key collisions, `pbar_kwargs` wins

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and its kwargs. for `"spinner"` the
        kwargs are already bound into the function and the returned dict is empty

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
     - `ImportError` : if `pbar="tqdm"` but `tqdm` is not installed
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    # build a new dict (never mutate the caller's), explicit `pbar_kwargs` take precedence
    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # import here so a missing tqdm gives a clear ImportError instead of a NameError
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm', 'spinner', or 'none' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # a user-supplied callable: pass the merged kwargs straight to it
        pbar_fn = pbar

    return pbar_fn, pbar_kwargs
set up the progress bar function and its kwargs
Parameters:
pbar : Union[ProgressBarFunction, ProgressBarOption]
progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
pbar_kwargs : Optional[Dict[str, Any]]
kwargs passed to the progress bar function (defaults to `None`)
Returns:
Tuple[ProgressBarFunction, dict]
a tuple of the progress bar function and its kwargs
Raises:
ValueError : if `pbar` is not one of the valid options
def
run_maybe_parallel( func: Callable[[~InputType], ~OutputType], iterable: Iterable[~InputType], parallel: Union[bool, int], pbar_kwargs: Optional[Dict[str, Any]] = None, chunksize: Optional[int] = None, keep_ordered: bool = True, use_multiprocess: bool = False, pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]] = 'tqdm') -> List[~OutputType]:
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by `min(len(iterable), multiprocessing.cpu_count())`.
    if only a single process would be used, the function runs in serial instead.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`. if it has no `__len__`
        (e.g. a generator), it is first materialized into a list
     - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `chunksize : Optional[int]`
        chunksize passed to `Pool.imap` / `Pool.imap_unordered`. if `None`, uses
        `max(1, len(iterable) // num_processes)`
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True` use `Pool.imap` (output order matches input order), otherwise
        `Pool.imap_unordered`. serial runs are always ordered
        (defaults to `True`)
     - `use_multiprocess : bool`
        use the `multiprocess` package instead of `multiprocessing` -- mostly useful
        when you need to pickle a lambda
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option, see `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """

    # we need a length for the progress bar and the pool size: materialize unsized iterables
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes requested by the caller
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # validate against what the caller *asked* for, before any fallback to serial below --
    # otherwise `parallel=True` with a single input (or a single CPU) would be rejected
    mp = multiprocessing
    if use_multiprocess:
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # make sure we don't have more processes than inputs, and don't bother with a pool for one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        do_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                do_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # `func` or the progress bar raised (or the user hit Ctrl-C): stop the
        # workers instead of leaking them, then re-raise
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    # return the output as a list
    return output
a function to make it easier to sometimes parallelize an operation
- if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
- if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
- if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`
the maximum number of processes is given by `min(len(iterable), multiprocessing.cpu_count())`
Parameters:
func : Callable[[InputType], OutputType]
function passed to either `map` or `Pool.imap`
iterable : Iterable[InputType]
iterable passed to either `map` or `Pool.imap`
parallel : bool | int
whether to run in parallel, and how many processes to use
pbar_kwargs : Dict[str, Any]
kwargs passed to the progress bar function
Returns:
List[OutputType]
a list of the output of `func` for each element in `iterable`
Raises:
ValueError : if `parallel` is not a boolean or an integer greater than 1
ValueError : if `use_multiprocess=True` and `parallel=False`
ImportError : if `use_multiprocess=True` and `multiprocess` is not available